diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS new file mode 100644 index 000000000000..ef7434efe377 --- /dev/null +++ b/.github/CODEOWNERS @@ -0,0 +1,2 @@ +.github/ @pablo-garay @ko3n1g +Dockerfile.ci @pablo-garay @ko3n1g diff --git a/.github/workflows/_test_template.yml b/.github/workflows/_test_template.yml index 0dbb1d50ee52..4ef6c5a9f9df 100644 --- a/.github/workflows/_test_template.yml +++ b/.github/workflows/_test_template.yml @@ -39,18 +39,13 @@ jobs: outputs: conclusion: ${{ steps.main.conclusion }} log: ${{ steps.main.outputs.log }} - container: - image: nemoci.azurecr.io/nemo_container_${{ github.run_id }} - options: - --device=/dev/nvidia0 - --gpus all - --shm-size=8g - --env TRANSFORMERS_OFFLINE=0 - --env HYDRA_FULL_ERROR=1 - --volume /mnt/datadrive/TestData:/home/TestData + permissions: + actions: write # Required for cancelling workflows steps: - - name: Checkout repository - uses: actions/checkout@v4 + - name: Docker system cleanup + run: | + docker system prune -a --filter "until=48h" --force + - id: main name: Run main script timeout-minutes: ${{ inputs.TIMEOUT }} @@ -59,7 +54,7 @@ jobs: ( set -e - ${{ inputs.SCRIPT }} + docker run --rm --device=/dev/nvidia0 --gpus all --shm-size=8g --env TRANSFORMERS_OFFLINE=0 --env HYDRA_FULL_ERROR=1 --volume /mnt/datadrive/TestData:/home/TestData nemoci.azurecr.io/nemo_container_${{ github.run_id }} bash -c '${{ inputs.SCRIPT }}' ) 2> >(tee err.log) EXIT_CODE=$? @@ -70,6 +65,9 @@ jobs: - uses: "NVIDIA/NeMo/.github/actions/cancel-workflow@main" if: failure() && inputs.IS_OPTIONAL == false + - name: after_script if: always() && inputs.AFTER_SCRIPT != ':' - run: ${{ inputs.AFTER_SCRIPT }} \ No newline at end of file + run: | + docker run --rm --device=/dev/nvidia0 --gpus all --shm-size=8g --env TRANSFORMERS_OFFLINE=0 --env HYDRA_FULL_ERROR=1 --volume /mnt/datadrive/TestData:/home/TestData nemoci.azurecr.io/nemo_container_${{ github.run_id }} bash -c '${{ inputs.AFTER_SCRIPT }}' + \ No newline at end of file diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index d225ee3ab429..253e114c78f3 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -125,6 +125,19 @@ jobs: ## - name: L2: Multimodal Imagen Train # L2: Community LLM Checkpoints tests + L2_Community_LLM_Checkpoints_tests_Bert: + needs: [cicd-test-container-setup] + uses: ./.github/workflows/_test_template.yml + with: + RUNNER: self-hosted-azure + SCRIPT: | + python scripts/checkpoint_converters/convert_bert_hf_to_nemo.py \ + --input_name_or_path /home/TestData/nlp/megatron_ir/sbert/hf_model/bert-base-uncased \ + --output_path /home/TestData/nlp/megatron_ir/sbert/sbert.nemo + AFTER_SCRIPT: | + rm -f /home/TestData/nlp/megatron_ir/sbert/sbert.nemo + rm -rf /home/TestData/nlp/megatron_ir/sbert/model_weights + L2_Community_LLM_Checkpoints_tests_Llama: needs: [cicd-test-container-setup] uses: ./.github/workflows/_test_template.yml @@ -179,7 +192,28 @@ jobs: rm -f /home/TestData/nlp/megatron_gpt/falcon-ci-hf/falcon_ci.nemo AFTER_SCRIPT: | rm -rf /home/TestData/nlp/megatron_gpt/falcon-ci-hf/model_weights - + + # L2: Community llava multimodal Checkpoints tests + L2_Community_vita_Checkpoints_tests_Llama3: + needs: [cicd-test-container-setup] + uses: ./.github/workflows/_test_template.yml + with: + RUNNER: self-hosted-azure + SCRIPT: | + export PYTHONPATH=/home/TestData/multimodal/video_neva/LLaVA:$PYTHONPATH + CUDA_VISIBLE_DEVICES=0 python examples/multimodal/multimodal_llm/neva/convert_llava_to_neva.py \ + --in-file 
/home/TestData/multimodal/video_neva/Llama-3-VILA1.5-8B/llm \ + --mm-projector-ckpt-dir /home/TestData/multimodal/video_neva/Llama-3-VILA1.5-8B/mm_projector \ + --mm-vision-tower /home/TestData/multimodal/video_neva/Llama-3-VILA1.5-8B/vision_tower \ + --tokenizer-model /home/TestData/multimodal/video_neva/vita-tokenizer/ \ + --config-file vita_config.yaml \ + --out-file=/home/TestData/multimodal/video_neva/llama3-ci-hf/llama3_ci.nemo \ + --model-type VITA \ + --conv-template llama_3 + AFTER_SCRIPT: | + rm -f /home/TestData/multimodal/video_neva/llama3-ci-hf/llama3_ci.nemo + rm -rf /home/TestData/multimodal/video_neva/llama3-ci-hf/model_weights + # this test is using a 7B model which is too large for GitHub CI # replace the model in this test with a toy model or move the test # to the nightly CI @@ -235,25 +269,29 @@ jobs: quantization.num_calib_size=8 \ inference.batch_size=2 \ export.inference_tensor_parallel=2 \ + export.sample_output=False \ export.save_path=/home/TestData/nlp/megatron_llama/ci_fp8.qnemo AFTER_SCRIPT: | rm -rf /home/TestData/nlp/megatron_llama/ci_fp8.qnemo - L2_PTQ_Llama2_INT8_SQ: + OPTIONAL_L2_PTQ_Llama2_INT8_SQ: needs: [cicd-test-container-setup] uses: ./.github/workflows/_test_template.yml with: RUNNER: self-hosted-azure + TIMEOUT: 15 SCRIPT: | python examples/nlp/language_modeling/megatron_gpt_ptq.py \ - model.restore_from_path=/home/TestData/nlp/megatron_llama/llama_ci.nemo \ - quantization.calib_dataset=/home/TestData/nlp/test_quantization/test.json \ - quantization.algorithm=int8_sq \ - quantization.num_calib_size=8 \ - inference.batch_size=2 \ - export.save_path=/home/TestData/nlp/megatron_llama/ci_int8_sq.qnemo + model.restore_from_path=/home/TestData/nlp/megatron_llama/llama_ci.nemo \ + quantization.calib_dataset=/home/TestData/nlp/test_quantization/test.json \ + quantization.algorithm=int8_sq \ + quantization.num_calib_size=8 \ + inference.batch_size=2 \ + export.sample_output=False \ + export.save_path=/home/TestData/nlp/megatron_llama/ci_int8_sq.qnemo AFTER_SCRIPT: | rm -rf /home/TestData/nlp/megatron_llama/ci_int8_sq.qnemo + IS_OPTIONAL: true # TODO: investigate int4_awq stuck issues and restore the test #L2_PTQ_Llama2_INT4_AWQ: @@ -288,44 +326,42 @@ jobs: #- uses: "NVIDIA/NeMo/.github/actions/cancel-workflow@main" # if: "failure()" - L2_QAT_Llama2_INT4: - needs: [cicd-test-container-setup] - runs-on: self-hosted-azure - timeout-minutes: 10 - container: - image: nemoci.azurecr.io/nemo_container_${{ github.run_id }} - options: - # --user 0:128 - --device=/dev/nvidia0 - --gpus all - --shm-size=8g - --env TRANSFORMERS_OFFLINE=0 - --env HYDRA_FULL_ERROR=1 - --volume /mnt/datadrive/TestData:/home/TestData - steps: - - name: Checkout repository - uses: actions/checkout@v4 - - run: | - python examples/nlp/language_modeling/tuning/megatron_gpt_qat.py \ - quantization.algorithm=int4 \ - quantization.num_calib_size=8 \ - trainer.devices=1 \ - trainer.num_nodes=1 \ - trainer.max_steps=4 \ - trainer.val_check_interval=4 \ - +trainer.limit_val_batches=2 \ - exp_manager.explicit_log_dir=llama2_qat_results \ - model.restore_from_path=/home/TestData/nlp/megatron_llama/llama_ci.nemo \ - model.tensor_model_parallel_size=1 \ - model.pipeline_model_parallel_size=1 \ - model.global_batch_size=2 \ - model.data.train_ds.file_names=[/home/TestData/nlp/megatron_sft/quarel.jsonl] \ - model.data.train_ds.concat_sampling_probabilities=[1.0] \ - model.data.validation_ds.file_names=[/home/TestData/nlp/megatron_sft/quarel.jsonl] - - rm -rf llama2_qat_results - - uses: 
"NVIDIA/NeMo/.github/actions/cancel-workflow@main" - if: "failure()" + # OPTIONAL_L2_QAT_Llama2_INT4: + # needs: [cicd-test-container-setup] + # runs-on: self-hosted-azure + # timeout-minutes: 10 + # container: + # image: nemoci.azurecr.io/nemo_container_${{ github.run_id }} + # options: + # # --user 0:128 + # --device=/dev/nvidia0 + # --gpus all + # --shm-size=8g + # --env TRANSFORMERS_OFFLINE=0 + # --env HYDRA_FULL_ERROR=1 + # --volume /mnt/datadrive/TestData:/home/TestData + # steps: + # - name: Checkout repository + # uses: actions/checkout@v4 + # - run: | + # python examples/nlp/language_modeling/tuning/megatron_gpt_qat.py \ + # quantization.algorithm=int4 \ + # quantization.num_calib_size=8 \ + # trainer.devices=1 \ + # trainer.num_nodes=1 \ + # trainer.max_steps=4 \ + # trainer.val_check_interval=4 \ + # +trainer.limit_val_batches=2 \ + # exp_manager.explicit_log_dir=llama2_qat_results \ + # model.restore_from_path=/home/TestData/nlp/megatron_llama/llama_ci.nemo \ + # model.tensor_model_parallel_size=1 \ + # model.pipeline_model_parallel_size=1 \ + # model.global_batch_size=2 \ + # model.data.train_ds.file_names=[/home/TestData/nlp/megatron_sft/quarel.jsonl] \ + # model.data.train_ds.concat_sampling_probabilities=[1.0] \ + # model.data.validation_ds.file_names=[/home/TestData/nlp/megatron_sft/quarel.jsonl] + + # rm -rf llama2_qat_results # L2: ASR dev run ASR_dev_run_Speech_to_Text: @@ -788,7 +824,7 @@ jobs: with: RUNNER: self-hosted-azure SCRIPT: | - pytest tests/collections/asr/decoding/rnnt_alignments_check.py --durations=-1 + pytest tests/collections/asr/decoding/rnnt_alignments_check.py --durations=-1 --with_downloads # L2: Segmentation Tool L2_Segmentation_Tool_Parallel_ctc_segmentation_test_L2_Eng_CitriNet_with_wav: @@ -1846,288 +1882,248 @@ jobs: # } L2_Megatron_Bert_Pretraining_and_Resume_Training_with_Pipeline_Parallelism: - needs: [cicd-test-container-setup] - runs-on: self-hosted-azure - timeout-minutes: 10 - container: - image: nemoci.azurecr.io/nemo_container_${{ github.run_id }} - options: - # --user 0:128 - --device=/dev/nvidia0 - --gpus all - --shm-size=8g - --env TRANSFORMERS_OFFLINE=0 - --env HYDRA_FULL_ERROR=1 - --volume /mnt/datadrive/TestData:/home/TestData - steps: - - name: Checkout repository - uses: actions/checkout@v4 - - run: | - NVTE_FUSED_ATTN=0 NVTE_FLASH_ATTN=0 python examples/nlp/language_modeling/megatron_bert_pretraining.py \ - trainer.devices=2 \ - trainer.accelerator=gpu \ - trainer.log_every_n_steps=1 \ - trainer.val_check_interval=10 \ - trainer.limit_val_batches=2 \ - trainer.accumulate_grad_batches=1 \ - trainer.max_steps=10 \ - trainer.precision=bf16 \ - model.megatron_amp_O2=True \ - trainer.gradient_clip_val=1.0 \ - exp_manager.exp_dir=examples/nlp/language_modeling/bert_pretrain_results \ - model.tensor_model_parallel_size=2 \ - model.optim.name=fused_adam \ - model.optim.lr=2e-4 \ - model.optim.sched.warmup_steps=2 \ - model.optim.sched.constant_steps=2 \ - model.optim.sched.min_lr=8e-5 \ - model.max_position_embeddings=128 \ - model.encoder_seq_length=128 \ - model.data.seq_length=128 \ - model.tokenizer.vocab_file=/home/TestData/nlp/megatron_bert/data/bert/vocab.txt \ - model.num_layers=8 \ - model.hidden_size=256 \ - model.num_attention_heads=8 \ - model.activations_checkpoint_method=block \ - model.activations_checkpoint_num_layers=1 \ - model.data.data_prefix=[.5,/home/TestData/nlp/megatron_bert/data/bert/simple_wiki_bert_preproc_text_sentence,.5,/home/TestData/nlp/megatron_bert/data/bert/simple_wiki_bert_preproc_text_sentence] \ - 
model.data.index_mapping_dir=examples/nlp/language_modeling/bert_index_mappings - - NVTE_FUSED_ATTN=0 NVTE_FLASH_ATTN=0 python examples/nlp/language_modeling/megatron_bert_pretraining.py \ - trainer.devices=2 \ - trainer.accelerator=gpu \ - trainer.log_every_n_steps=1 \ - trainer.val_check_interval=10 \ - trainer.limit_val_batches=2 \ - trainer.accumulate_grad_batches=1 \ - trainer.max_steps=20 \ - trainer.precision=bf16 \ - model.megatron_amp_O2=True \ - trainer.gradient_clip_val=1.0 \ - exp_manager.exp_dir=examples/nlp/language_modeling/bert_pretrain_results \ - exp_manager.resume_if_exists=True \ - model.tensor_model_parallel_size=2 \ - model.optim.name=fused_adam \ - model.optim.lr=2e-4 \ - model.optim.sched.warmup_steps=2 \ - model.optim.sched.constant_steps=2 \ - model.optim.sched.min_lr=8e-5 \ - model.max_position_embeddings=128 \ - model.encoder_seq_length=128 \ - model.data.seq_length=128 \ - model.tokenizer.vocab_file=/home/TestData/nlp/megatron_bert/data/bert/vocab.txt \ - model.num_layers=8 \ - model.hidden_size=256 \ - model.num_attention_heads=8 \ - model.activations_checkpoint_method=block \ - model.activations_checkpoint_num_layers=1 \ - model.data.data_prefix=[.5,/home/TestData/nlp/megatron_bert/data/bert/simple_wiki_bert_preproc_text_sentence,.5,/home/TestData/nlp/megatron_bert/data/bert/simple_wiki_bert_preproc_text_sentence] \ - model.data.index_mapping_dir=examples/nlp/language_modeling/bert_index_mappings - - L2_Megatron_Bert_Pretraining_and_Resume_Training: - needs: [cicd-test-container-setup] - runs-on: self-hosted-azure - timeout-minutes: 10 - container: - image: nemoci.azurecr.io/nemo_container_${{ github.run_id }} - options: - # --user 0:128 - --device=/dev/nvidia0 - --gpus all - --shm-size=8g - --env TRANSFORMERS_OFFLINE=0 - --env HYDRA_FULL_ERROR=1 - --volume /mnt/datadrive/TestData:/home/TestData - steps: - - name: Checkout repository - uses: actions/checkout@v4 - - run: | - NVTE_FUSED_ATTN=0 NVTE_FLASH_ATTN=0 python examples/nlp/language_modeling/megatron_bert_pretraining.py \ - trainer.devices=2 \ - trainer.accelerator=gpu \ - trainer.log_every_n_steps=1 \ - trainer.val_check_interval=10 \ - trainer.limit_val_batches=2 \ - trainer.accumulate_grad_batches=1 \ - trainer.max_steps=10 \ - trainer.precision=bf16 \ - model.megatron_amp_O2=True \ - trainer.gradient_clip_val=1.0 \ - exp_manager.exp_dir=examples/nlp/language_modeling/bert_pretrain_results \ - model.tensor_model_parallel_size=2 \ - model.optim.name=fused_adam \ - model.optim.lr=2e-4 \ - model.sequence_parallel=True \ - model.optim.sched.warmup_steps=2 \ - model.optim.sched.constant_steps=2 \ - model.optim.sched.min_lr=8e-5 \ - model.max_position_embeddings=128 \ - model.encoder_seq_length=128 \ - model.data.seq_length=128 \ - model.tokenizer.vocab_file=/home/TestData/nlp/megatron_bert/data/bert/vocab.txt \ - model.num_layers=8 \ - model.hidden_size=256 \ - model.num_attention_heads=8 \ - model.activations_checkpoint_method=block \ - model.activations_checkpoint_num_layers=1 \ - model.data.data_prefix=[.5,/home/TestData/nlp/megatron_bert/data/bert/simple_wiki_bert_preproc_text_sentence,.5,/home/TestData/nlp/megatron_bert/data/bert/simple_wiki_bert_preproc_text_sentence] \ - model.data.index_mapping_dir=examples/nlp/language_modeling/bert_index_mappings - - NVTE_FUSED_ATTN=0 NVTE_FLASH_ATTN=0 python examples/nlp/language_modeling/megatron_bert_pretraining.py \ - trainer.devices=2 \ - trainer.accelerator=gpu \ - trainer.log_every_n_steps=1 \ - trainer.val_check_interval=10 \ - trainer.limit_val_batches=2 
\ - trainer.accumulate_grad_batches=1 \ - trainer.max_steps=20 \ - trainer.precision=bf16 \ - model.megatron_amp_O2=True \ - trainer.gradient_clip_val=1.0 \ - exp_manager.exp_dir=examples/nlp/language_modeling/bert_pretrain_results \ - exp_manager.resume_if_exists=True \ - model.tensor_model_parallel_size=2 \ - model.optim.name=fused_adam \ - model.optim.lr=2e-4 \ - model.optim.sched.warmup_steps=2 \ - model.optim.sched.constant_steps=2 \ - model.optim.sched.min_lr=8e-5 \ - model.max_position_embeddings=128 \ - model.encoder_seq_length=128 \ - model.data.seq_length=128 \ - model.tokenizer.vocab_file=/home/TestData/nlp/megatron_bert/data/bert/vocab.txt \ - model.num_layers=8 \ - model.hidden_size=256 \ - model.num_attention_heads=8 \ - model.activations_checkpoint_method=block \ - model.activations_checkpoint_num_layers=1 \ - model.data.data_prefix=[.5,/home/TestData/nlp/megatron_bert/data/bert/simple_wiki_bert_preproc_text_sentence,.5,/home/TestData/nlp/megatron_bert/data/bert/simple_wiki_bert_preproc_text_sentence] \ - model.data.index_mapping_dir=examples/nlp/language_modeling/bert_index_mappings - - rm -rf examples/nlp/language_modeling/bert_pretrain_results - rm -rf examples/nlp/language_modeling/bert_index_mappings - - uses: "NVIDIA/NeMo/.github/actions/cancel-workflow@main" - if: "failure()" - - L2_Megatron_Core_Bert_Pretraining_and_Resume_Training: - needs: [cicd-test-container-setup] - runs-on: self-hosted-azure - timeout-minutes: 10 - container: - image: nemoci.azurecr.io/nemo_container_${{ github.run_id }} - options: - # --user 0:128 - --device=/dev/nvidia0 - --gpus all - --shm-size=8g - --env TRANSFORMERS_OFFLINE=0 - --env HYDRA_FULL_ERROR=1 - --volume /mnt/datadrive/TestData:/home/TestData - steps: - - name: Checkout repository - uses: actions/checkout@v4 - - run: | - NVTE_FLASH_ATTN=0 NVTE_FUSED_ATTN=0 python examples/nlp/language_modeling/megatron_bert_pretraining.py \ - trainer.devices=2 \ - trainer.accelerator=gpu \ - trainer.log_every_n_steps=1 \ - trainer.val_check_interval=10 \ - trainer.limit_val_batches=2 \ - trainer.accumulate_grad_batches=1 \ - trainer.max_steps=10 \ - trainer.gradient_clip_val=1.0 \ - exp_manager.exp_dir=examples/nlp/language_modeling/bert_pretrain_results \ - model.mcore_bert=True \ - model.tensor_model_parallel_size=2 \ - model.optim.name=fused_adam \ - model.optim.lr=2e-4 \ - model.sequence_parallel=True \ - model.optim.sched.warmup_steps=2 \ - model.optim.sched.constant_steps=2 \ - model.optim.sched.min_lr=8e-5 \ - model.max_position_embeddings=128 \ - model.encoder_seq_length=128 \ - model.data.seq_length=128 \ - model.tokenizer.vocab_file=/home/TestData/nlp/megatron_bert/data/bert/vocab.txt \ - model.num_layers=8 \ - model.hidden_size=256 \ - model.num_attention_heads=8 \ - model.activations_checkpoint_method='block' \ - model.activations_checkpoint_num_layers=1 \ - model.data.data_prefix=[.5,/home/TestData/nlp/megatron_bert/data/bert/simple_wiki_bert_preproc_text_sentence,.5,/home/TestData/nlp/megatron_bert/data/bert/simple_wiki_bert_preproc_text_sentence] \ - model.data.index_mapping_dir=examples/nlp/language_modeling/bert_index_mappings - - NVTE_FLASH_ATTN=0 NVTE_FUSED_ATTN=0 python examples/nlp/language_modeling/megatron_bert_pretraining.py \ - trainer.devices=2 \ - trainer.accelerator=gpu \ - trainer.log_every_n_steps=1 \ - trainer.val_check_interval=10 \ - trainer.limit_val_batches=2 \ - trainer.accumulate_grad_batches=1 \ - trainer.max_steps=20 \ - trainer.gradient_clip_val=1.0 \ - 
exp_manager.exp_dir=examples/nlp/language_modeling/bert_pretrain_results \ - exp_manager.resume_if_exists=True \ - model.mcore_bert=True \ - model.tensor_model_parallel_size=2 \ - model.optim.name=fused_adam \ - model.optim.lr=2e-4 \ - model.optim.sched.warmup_steps=2 \ - model.optim.sched.constant_steps=2 \ - model.optim.sched.min_lr=8e-5 \ - model.max_position_embeddings=128 \ - model.encoder_seq_length=128 \ - model.data.seq_length=128 \ - model.tokenizer.vocab_file=/home/TestData/nlp/megatron_bert/data/bert/vocab.txt \ - model.num_layers=8 \ - model.hidden_size=256 \ - model.num_attention_heads=8 \ - model.activations_checkpoint_method='block' \ - model.activations_checkpoint_num_layers=1 \ - model.data.data_prefix=[.5,/home/TestData/nlp/megatron_bert/data/bert/simple_wiki_bert_preproc_text_sentence,.5,/home/TestData/nlp/megatron_bert/data/bert/simple_wiki_bert_preproc_text_sentence] \ - model.data.index_mapping_dir=examples/nlp/language_modeling/bert_index_mappings - - rm -rf examples/nlp/language_modeling/bert_pretrain_results - rm -rf examples/nlp/language_modeling/bert_index_mappings - - uses: "NVIDIA/NeMo/.github/actions/cancel-workflow@main" - if: "failure()" - - L2_Megatron_RETRO_Pretraining_and_Resume_Training: needs: [cicd-test-container-setup] uses: ./.github/workflows/_test_template.yml with: RUNNER: self-hosted-azure SCRIPT: | - python examples/nlp/language_modeling/megatron_retro_pretraining.py \ - trainer.num_nodes=1 \ + NVTE_FUSED_ATTN=0 NVTE_FLASH_ATTN=0 python examples/nlp/language_modeling/megatron_bert_pretraining.py \ trainer.devices=2 \ - trainer.precision=bf16 \ trainer.accelerator=gpu \ - model.data.data_prefix=['none'] \ - exp_manager.exp_dir=examples/nlp/language_modeling/mcore_retro_results \ - model.mcore_gpt=True \ - model.tensor_model_parallel_size=1 \ - model.pipeline_model_parallel_size=1 \ - model.optim.name=distributed_fused_adam \ - model.retro.retro_project_dir=/home/TestData/nlp/megatron_retro/mcore_retro/micro-wiki-core \ - model.data.num_workers=4 \ - model.micro_batch_size=1 \ - model.data.shuffle_documents=False \ - trainer.val_check_interval=30 \ - +trainer.num_sanity_val_steps=0 \ - model.init_method_std=0.023 \ - model.optim.lr=6.0e-4 \ - model.megatron_amp_O2=True \ - model.data.splits_string=\'\"98,2,0\"\' \ - model.data.dataloader_type=cyclic \ - trainer.max_steps=10 - - python examples/nlp/language_modeling/megatron_retro_pretraining.py \ - trainer.num_nodes=1 \ - trainer.devices=2 \ + trainer.log_every_n_steps=1 \ + trainer.val_check_interval=10 \ + trainer.limit_val_batches=2 \ + trainer.accumulate_grad_batches=1 \ + trainer.max_steps=10 \ trainer.precision=bf16 \ - trainer.accelerator=gpu \ - model.data.data_prefix=['none'] \ + model.megatron_amp_O2=True \ + trainer.gradient_clip_val=1.0 \ + exp_manager.exp_dir=examples/nlp/language_modeling/bert_pretrain_results \ + model.tensor_model_parallel_size=2 \ + model.optim.name=fused_adam \ + model.optim.lr=2e-4 \ + model.optim.sched.warmup_steps=2 \ + model.optim.sched.constant_steps=2 \ + model.optim.sched.min_lr=8e-5 \ + model.max_position_embeddings=128 \ + model.encoder_seq_length=128 \ + model.data.seq_length=128 \ + model.tokenizer.vocab_file=/home/TestData/nlp/megatron_bert/data/bert/vocab.txt \ + model.num_layers=8 \ + model.hidden_size=256 \ + model.num_attention_heads=8 \ + model.activations_checkpoint_method=block \ + model.activations_checkpoint_num_layers=1 \ + 
model.data.data_prefix=[.5,/home/TestData/nlp/megatron_bert/data/bert/simple_wiki_bert_preproc_text_sentence,.5,/home/TestData/nlp/megatron_bert/data/bert/simple_wiki_bert_preproc_text_sentence] \ + model.data.index_mapping_dir=examples/nlp/language_modeling/bert_index_mappings + + NVTE_FUSED_ATTN=0 NVTE_FLASH_ATTN=0 python examples/nlp/language_modeling/megatron_bert_pretraining.py \ + trainer.devices=2 \ + trainer.accelerator=gpu \ + trainer.log_every_n_steps=1 \ + trainer.val_check_interval=10 \ + trainer.limit_val_batches=2 \ + trainer.accumulate_grad_batches=1 \ + trainer.max_steps=20 \ + trainer.precision=bf16 \ + model.megatron_amp_O2=True \ + trainer.gradient_clip_val=1.0 \ + exp_manager.exp_dir=examples/nlp/language_modeling/bert_pretrain_results \ + exp_manager.resume_if_exists=True \ + model.tensor_model_parallel_size=2 \ + model.optim.name=fused_adam \ + model.optim.lr=2e-4 \ + model.optim.sched.warmup_steps=2 \ + model.optim.sched.constant_steps=2 \ + model.optim.sched.min_lr=8e-5 \ + model.max_position_embeddings=128 \ + model.encoder_seq_length=128 \ + model.data.seq_length=128 \ + model.tokenizer.vocab_file=/home/TestData/nlp/megatron_bert/data/bert/vocab.txt \ + model.num_layers=8 \ + model.hidden_size=256 \ + model.num_attention_heads=8 \ + model.activations_checkpoint_method=block \ + model.activations_checkpoint_num_layers=1 \ + model.data.data_prefix=[.5,/home/TestData/nlp/megatron_bert/data/bert/simple_wiki_bert_preproc_text_sentence,.5,/home/TestData/nlp/megatron_bert/data/bert/simple_wiki_bert_preproc_text_sentence] \ + model.data.index_mapping_dir=examples/nlp/language_modeling/bert_index_mappings + + L2_Megatron_Bert_Pretraining_and_Resume_Training: + needs: [cicd-test-container-setup] + uses: ./.github/workflows/_test_template.yml + with: + RUNNER: self-hosted-azure + SCRIPT: | + NVTE_FUSED_ATTN=0 NVTE_FLASH_ATTN=0 python examples/nlp/language_modeling/megatron_bert_pretraining.py \ + trainer.devices=2 \ + trainer.accelerator=gpu \ + trainer.log_every_n_steps=1 \ + trainer.val_check_interval=10 \ + trainer.limit_val_batches=2 \ + trainer.accumulate_grad_batches=1 \ + trainer.max_steps=10 \ + trainer.precision=bf16 \ + model.megatron_amp_O2=True \ + trainer.gradient_clip_val=1.0 \ + exp_manager.exp_dir=examples/nlp/language_modeling/bert_pretrain_results \ + model.tensor_model_parallel_size=2 \ + model.optim.name=fused_adam \ + model.optim.lr=2e-4 \ + model.sequence_parallel=True \ + model.optim.sched.warmup_steps=2 \ + model.optim.sched.constant_steps=2 \ + model.optim.sched.min_lr=8e-5 \ + model.max_position_embeddings=128 \ + model.encoder_seq_length=128 \ + model.data.seq_length=128 \ + model.tokenizer.vocab_file=/home/TestData/nlp/megatron_bert/data/bert/vocab.txt \ + model.num_layers=8 \ + model.hidden_size=256 \ + model.num_attention_heads=8 \ + model.activations_checkpoint_method=block \ + model.activations_checkpoint_num_layers=1 \ + model.data.data_prefix=[.5,/home/TestData/nlp/megatron_bert/data/bert/simple_wiki_bert_preproc_text_sentence,.5,/home/TestData/nlp/megatron_bert/data/bert/simple_wiki_bert_preproc_text_sentence] \ + model.data.index_mapping_dir=examples/nlp/language_modeling/bert_index_mappings + + NVTE_FUSED_ATTN=0 NVTE_FLASH_ATTN=0 python examples/nlp/language_modeling/megatron_bert_pretraining.py \ + trainer.devices=2 \ + trainer.accelerator=gpu \ + trainer.log_every_n_steps=1 \ + trainer.val_check_interval=10 \ + trainer.limit_val_batches=2 \ + trainer.accumulate_grad_batches=1 \ + trainer.max_steps=20 \ + trainer.precision=bf16 \ + 
model.megatron_amp_O2=True \ + trainer.gradient_clip_val=1.0 \ + exp_manager.exp_dir=examples/nlp/language_modeling/bert_pretrain_results \ + exp_manager.resume_if_exists=True \ + model.tensor_model_parallel_size=2 \ + model.optim.name=fused_adam \ + model.optim.lr=2e-4 \ + model.optim.sched.warmup_steps=2 \ + model.optim.sched.constant_steps=2 \ + model.optim.sched.min_lr=8e-5 \ + model.max_position_embeddings=128 \ + model.encoder_seq_length=128 \ + model.data.seq_length=128 \ + model.tokenizer.vocab_file=/home/TestData/nlp/megatron_bert/data/bert/vocab.txt \ + model.num_layers=8 \ + model.hidden_size=256 \ + model.num_attention_heads=8 \ + model.activations_checkpoint_method=block \ + model.activations_checkpoint_num_layers=1 \ + model.data.data_prefix=[.5,/home/TestData/nlp/megatron_bert/data/bert/simple_wiki_bert_preproc_text_sentence,.5,/home/TestData/nlp/megatron_bert/data/bert/simple_wiki_bert_preproc_text_sentence] \ + model.data.index_mapping_dir=examples/nlp/language_modeling/bert_index_mappings + AFTER_SCRIPT: | + rm -rf examples/nlp/language_modeling/bert_pretrain_results + rm -rf examples/nlp/language_modeling/bert_index_mappings + + L2_Megatron_Core_Bert_Pretraining_and_Resume_Training: + needs: [cicd-test-container-setup] + uses: ./.github/workflows/_test_template.yml + with: + RUNNER: self-hosted-azure + SCRIPT: | + NVTE_FLASH_ATTN=0 NVTE_FUSED_ATTN=0 python examples/nlp/language_modeling/megatron_bert_pretraining.py \ + trainer.devices=2 \ + trainer.accelerator=gpu \ + trainer.log_every_n_steps=1 \ + trainer.val_check_interval=10 \ + trainer.limit_val_batches=2 \ + trainer.accumulate_grad_batches=1 \ + trainer.max_steps=10 \ + trainer.gradient_clip_val=1.0 \ + exp_manager.exp_dir=examples/nlp/language_modeling/bert_pretrain_results \ + model.mcore_bert=True \ + model.tensor_model_parallel_size=2 \ + model.optim.name=fused_adam \ + model.optim.lr=2e-4 \ + model.sequence_parallel=True \ + model.optim.sched.warmup_steps=2 \ + model.optim.sched.constant_steps=2 \ + model.optim.sched.min_lr=8e-5 \ + model.max_position_embeddings=128 \ + model.encoder_seq_length=128 \ + model.data.seq_length=128 \ + model.tokenizer.vocab_file=/home/TestData/nlp/megatron_bert/data/bert/vocab.txt \ + model.num_layers=8 \ + model.hidden_size=256 \ + model.num_attention_heads=8 \ + model.activations_checkpoint_method='block' \ + model.activations_checkpoint_num_layers=1 \ + model.data.data_prefix=[.5,/home/TestData/nlp/megatron_bert/data/bert/simple_wiki_bert_preproc_text_sentence,.5,/home/TestData/nlp/megatron_bert/data/bert/simple_wiki_bert_preproc_text_sentence] \ + model.data.index_mapping_dir=examples/nlp/language_modeling/bert_index_mappings + + NVTE_FLASH_ATTN=0 NVTE_FUSED_ATTN=0 python examples/nlp/language_modeling/megatron_bert_pretraining.py \ + trainer.devices=2 \ + trainer.accelerator=gpu \ + trainer.log_every_n_steps=1 \ + trainer.val_check_interval=10 \ + trainer.limit_val_batches=2 \ + trainer.accumulate_grad_batches=1 \ + trainer.max_steps=20 \ + trainer.gradient_clip_val=1.0 \ + exp_manager.exp_dir=examples/nlp/language_modeling/bert_pretrain_results \ + exp_manager.resume_if_exists=True \ + model.mcore_bert=True \ + model.tensor_model_parallel_size=2 \ + model.optim.name=fused_adam \ + model.optim.lr=2e-4 \ + model.optim.sched.warmup_steps=2 \ + model.optim.sched.constant_steps=2 \ + model.optim.sched.min_lr=8e-5 \ + model.max_position_embeddings=128 \ + model.encoder_seq_length=128 \ + model.data.seq_length=128 \ + 
model.tokenizer.vocab_file=/home/TestData/nlp/megatron_bert/data/bert/vocab.txt \ + model.num_layers=8 \ + model.hidden_size=256 \ + model.num_attention_heads=8 \ + model.activations_checkpoint_method='block' \ + model.activations_checkpoint_num_layers=1 \ + model.data.data_prefix=[.5,/home/TestData/nlp/megatron_bert/data/bert/simple_wiki_bert_preproc_text_sentence,.5,/home/TestData/nlp/megatron_bert/data/bert/simple_wiki_bert_preproc_text_sentence] \ + model.data.index_mapping_dir=examples/nlp/language_modeling/bert_index_mappings + AFTER_SCRIPT: | + rm -rf examples/nlp/language_modeling/bert_pretrain_results + rm -rf examples/nlp/language_modeling/bert_index_mappings + + L2_Megatron_RETRO_Pretraining_and_Resume_Training: + needs: [cicd-test-container-setup] + uses: ./.github/workflows/_test_template.yml + with: + RUNNER: self-hosted-azure + SCRIPT: | + python examples/nlp/language_modeling/megatron_retro_pretraining.py \ + trainer.num_nodes=1 \ + trainer.devices=2 \ + trainer.precision=bf16 \ + trainer.accelerator=gpu \ + model.data.data_prefix=['none'] \ + exp_manager.exp_dir=examples/nlp/language_modeling/mcore_retro_results \ + model.mcore_gpt=True \ + model.tensor_model_parallel_size=1 \ + model.pipeline_model_parallel_size=1 \ + model.optim.name=distributed_fused_adam \ + model.retro.retro_project_dir=/home/TestData/nlp/megatron_retro/mcore_retro/micro-wiki-core \ + model.data.num_workers=4 \ + model.micro_batch_size=1 \ + model.data.shuffle_documents=False \ + trainer.val_check_interval=30 \ + +trainer.num_sanity_val_steps=0 \ + model.init_method_std=0.023 \ + model.optim.lr=6.0e-4 \ + model.megatron_amp_O2=True \ + model.data.splits_string=\'\"98,2,0\"\' \ + model.data.dataloader_type=cyclic \ + trainer.max_steps=10 + + python examples/nlp/language_modeling/megatron_retro_pretraining.py \ + trainer.num_nodes=1 \ + trainer.devices=2 \ + trainer.precision=bf16 \ + trainer.accelerator=gpu \ + model.data.data_prefix=['none'] \ exp_manager.exp_dir=examples/nlp/language_modeling/mcore_retro_results \ model.mcore_gpt=True \ model.tensor_model_parallel_size=1 \ @@ -2314,65 +2310,37 @@ jobs: L2_RAG_Pipeline_Indexing: needs: [cicd-test-container-setup] - runs-on: self-hosted-azure - timeout-minutes: 10 - container: - image: nemoci.azurecr.io/nemo_container_${{ github.run_id }} - options: - # --user 0:128 - --device=/dev/nvidia0 - --gpus all - --shm-size=8g - --env TRANSFORMERS_OFFLINE=0 - --env HYDRA_FULL_ERROR=1 - --volume /mnt/datadrive/TestData:/home/TestData - steps: - - name: Checkout repository - uses: actions/checkout@v4 - - run: | - python examples/nlp/rag/rag_indexing.py \ - trainer.num_nodes=1 \ - trainer.devices=1 \ - trainer.precision='bf16-mixed' \ - indexing.embedder.model_path='/home/TestData/nlp/rag_pipeline/testing_models/embedders/sbert_nemo.nemo' \ - indexing.embedder.embed_batch_size=128 \ - indexing.data.data_path='/home/TestData/nlp/rag_pipeline/testing_data/corpus_data/sample_data' \ - indexing.data.chunk_size=256 \ - indexing.data.chunk_overlap=10 \ - indexing.index_path='/home/TestData/nlp/rag_pipeline/testing_data/saved_index/sample_index' - - uses: "NVIDIA/NeMo/.github/actions/cancel-workflow@main" - if: "failure()" + uses: ./.github/workflows/_test_template.yml + with: + RUNNER: self-hosted-azure + SCRIPT: | + python examples/nlp/rag/rag_indexing.py \ + trainer.num_nodes=1 \ + trainer.devices=1 \ + trainer.precision='bf16-mixed' \ + indexing.embedder.model_path='/home/TestData/nlp/rag_pipeline/testing_models/embedders/sbert_nemo.nemo' \ + 
indexing.embedder.embed_batch_size=128 \ + indexing.data.data_path='/home/TestData/nlp/rag_pipeline/testing_data/corpus_data/sample_data' \ + indexing.data.chunk_size=256 \ + indexing.data.chunk_overlap=10 \ + indexing.index_path='/home/TestData/nlp/rag_pipeline/testing_data/saved_index/sample_index' L2_RAG_Pipeline_Generating: needs: [cicd-test-container-setup] - runs-on: self-hosted-azure - timeout-minutes: 10 - container: - image: nemoci.azurecr.io/nemo_container_${{ github.run_id }} - options: - # --user 0:128 - --device=/dev/nvidia0 - --gpus all - --shm-size=8g - --env TRANSFORMERS_OFFLINE=0 - --env HYDRA_FULL_ERROR=1 - --volume /mnt/datadrive/TestData:/home/TestData - steps: - - name: Checkout repository - uses: actions/checkout@v4 - - run: | - python examples/nlp/rag/rag_generating.py \ - trainer.devices=1 \ - trainer.precision='bf16-mixed' \ - indexing.embedder.model_path='/home/TestData/nlp/rag_pipeline/testing_models/embedders/sbert_nemo.nemo' \ - indexing.index_path='/home/TestData/nlp/rag_pipeline/testing_data/saved_index/sample_index' \ - generating.llm.model_path='/home/TestData/nlp/rag_pipeline/testing_models/llms/megatron_gpt_125m.nemo' \ - generating.inference.tokens_to_generate=50 \ - generating.inference.greedy=False \ - generating.inference.temperature=1.0 \ - generating.query='Which art schools did I applied to?' - - uses: "NVIDIA/NeMo/.github/actions/cancel-workflow@main" - if: "failure()" + uses: ./.github/workflows/_test_template.yml + with: + RUNNER: self-hosted-azure + SCRIPT: | + python examples/nlp/rag/rag_generating.py \ + trainer.devices=1 \ + trainer.precision='bf16-mixed' \ + indexing.embedder.model_path='/home/TestData/nlp/rag_pipeline/testing_models/embedders/sbert_nemo.nemo' \ + indexing.index_path='/home/TestData/nlp/rag_pipeline/testing_data/saved_index/sample_index' \ + generating.llm.model_path='/home/TestData/nlp/rag_pipeline/testing_models/llms/megatron_gpt_125m.nemo' \ + generating.inference.tokens_to_generate=50 \ + generating.inference.greedy=False \ + generating.inference.temperature=1.0 \ + generating.query='Which art schools did I applied to?' 
L2_BioMegatron_Bert_NER_Task: needs: [cicd-test-container-setup] @@ -2391,7 +2359,7 @@ jobs: L2_Megatron_GPT_Pretraining_and_Resume_Training_TP2: needs: [cicd-test-container-setup] - runs-on: self-hosted-azure + runs-on: self-hosted-azure-gpus-2-h100 timeout-minutes: 10 container: image: nemoci.azurecr.io/nemo_container_${{ github.run_id }} @@ -2403,6 +2371,21 @@ jobs: --env TRANSFORMERS_OFFLINE=0 --env HYDRA_FULL_ERROR=1 --volume /mnt/datadrive/TestData:/home/TestData + env: + # This is to improve p2p overlap on H100 + NVTE_FWD_LAYERNORM_SM_MARGIN: 8 + NVTE_BWD_LAYERNORM_SM_MARGIN: 8 + TORCH_NCCL_AVOID_RECORD_STREAMS: 1 + NCCL_MIN_NCHANNELS: 4 + # TP overlap is not supported in docker environment + #NVTE_UB_SPLIT_RS: 0 + #NVTE_UB_ATOMIC_GEMM_RS: 1 + #NVTE_RS_STRIDED_ATOMIC: 1 + #NVTE_UB_FP8_RS: 1 + # Increase p2p chunksize to 2MB + NCCL_P2P_NET_CHUNKSIZE: 2097152 + # Disable gc when switching to/from validation steps + NEMO_MANUAL_GC_IN_VALIDATION: 0 steps: - name: Checkout repository uses: actions/checkout@v4 @@ -2417,8 +2400,17 @@ jobs: trainer.max_steps=3 \ trainer.gradient_clip_val=1.0 \ exp_manager.exp_dir=examples/nlp/language_modeling/gpt_pretrain_results \ + ++model.transformer_engine=True \ + ++model.fp8=True \ + ++model.fp8_hybrid=True \ + ++model.fp8_amax_history_len=1024 \ + ++model.fp8_amax_compute_algo=max \ + ++model.reduce_amax=True \ + ++model.use_te_rng_tracker=True \ + ++model.name=megatron_gpt_full_te_layer_autocast \ + model.ub_tp_comm_overlap=False \ model.tensor_model_parallel_size=2 \ - model.optim.name=fused_adam \ + model.optim.name=distributed_fused_adam \ model.optim.lr=2e-4 \ model.optim.sched.warmup_steps=1 \ model.optim.sched.constant_steps=1 \ @@ -2452,8 +2444,17 @@ jobs: trainer.gradient_clip_val=1.0 \ exp_manager.exp_dir=examples/nlp/language_modeling/gpt_pretrain_results \ exp_manager.resume_if_exists=True \ + ++model.transformer_engine=True \ + ++model.fp8=True \ + ++model.fp8_hybrid=True \ + ++model.fp8_amax_history_len=1024 \ + ++model.fp8_amax_compute_algo=max \ + ++model.reduce_amax=True \ + ++model.use_te_rng_tracker=True \ + ++model.name=megatron_gpt_full_te_layer_autocast \ + model.ub_tp_comm_overlap=False \ model.tensor_model_parallel_size=2 \ - model.optim.name=fused_adam \ + model.optim.name=distributed_fused_adam \ model.optim.lr=2e-4 \ model.optim.sched.warmup_steps=2 \ model.optim.sched.constant_steps=2 \ @@ -2483,99 +2484,85 @@ jobs: L2_Megatron_GPT_with_Rope_Pretraining_and_Resume_Training_TP2: needs: [cicd-test-container-setup] - runs-on: self-hosted-azure - timeout-minutes: 10 - container: - image: nemoci.azurecr.io/nemo_container_${{ github.run_id }} - options: - # --user 0:128 - --device=/dev/nvidia0 - --gpus all - --shm-size=8g - --env TRANSFORMERS_OFFLINE=0 - --env HYDRA_FULL_ERROR=1 - --volume /mnt/datadrive/TestData:/home/TestData - steps: - - name: Checkout repository - uses: actions/checkout@v4 - - run: | - python examples/nlp/language_modeling/megatron_gpt_pretraining.py \ - trainer.devices=2 \ - trainer.accelerator=gpu \ - trainer.log_every_n_steps=1 \ - trainer.val_check_interval=2 \ - trainer.limit_val_batches=2 \ - trainer.accumulate_grad_batches=1 \ - trainer.max_steps=3 \ - trainer.gradient_clip_val=1.0 \ - exp_manager.exp_dir=examples/nlp/language_modeling/gpt_pretrain_results \ - model.tensor_model_parallel_size=2 \ - model.optim.name=fused_adam \ - model.optim.lr=2e-4 \ - model.optim.sched.warmup_steps=1 \ - model.optim.sched.constant_steps=1 \ - model.optim.sched.min_lr=8e-5 \ - model.max_position_embeddings=128 \ - 
model.encoder_seq_length=128 \ - model.data.seq_length=128 \ - model.position_embedding_type=rope \ - model.rotary_percentage=0.5 \ - model.bias=False \ - model.bias_activation_fusion=False \ - model.bias_dropout_add_fusion=False \ - model.tokenizer.vocab_file=/home/TestData/nlp/megatron_gpt/data/gpt/vocab.json \ - model.tokenizer.merge_file=/home/TestData/nlp/megatron_gpt/data/gpt/merges.txt \ - model.num_layers=8 \ - model.hidden_size=256 \ - model.num_attention_heads=8 \ - model.activations_checkpoint_method=block \ - model.activations_checkpoint_granularity=full \ - model.activations_checkpoint_num_layers=1 \ - model.data.data_prefix=[.5,/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document,.5,/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document] \ - model.data.index_mapping_dir=examples/nlp/language_modeling/gpt_index_mappings - - # commented out to save time on github ci @adithyare - # python examples/nlp/language_modeling/megatron_gpt_pretraining.py \ - # trainer.devices=2 \ - # trainer.accelerator=gpu \ - # trainer.log_every_n_steps=1 \ - # trainer.val_check_interval=2 \ - # trainer.limit_val_batches=1 \ - # trainer.accumulate_grad_batches=1 \ - # trainer.max_steps=6 \ - # trainer.gradient_clip_val=1.0 \ - # exp_manager.exp_dir=examples/nlp/language_modeling/gpt_pretrain_results \ - # exp_manager.resume_if_exists=True \ - # model.tensor_model_parallel_size=2 \ - # model.optim.name=fused_adam \ - # model.optim.lr=2e-4 \ - # model.optim.sched.warmup_steps=2 \ - # model.optim.sched.constant_steps=2 \ - # model.optim.sched.min_lr=8e-5 \ - # model.max_position_embeddings=128 \ - # model.encoder_seq_length=128 \ - # model.data.seq_length=128 \ - # model.position_embedding_type=rope \ - # model.rotary_percentage=0.5 \ - # model.normalization=rmsnorm \ - # model.bias=False \ - # model.bias_activation_fusion=False \ - # model.bias_dropout_add_fusion=False \ - # model.tokenizer.vocab_file=/home/TestData/nlp/megatron_gpt/data/gpt/vocab.json \ - # model.tokenizer.merge_file=/home/TestData/nlp/megatron_gpt/data/gpt/merges.txt \ - # model.num_layers=8 \ - # model.hidden_size=256 \ - # model.num_attention_heads=8 \ - # model.activations_checkpoint_method=block \ - # model.activations_checkpoint_granularity=full \ - # model.activations_checkpoint_num_layers=1 \ - # model.data.data_prefix=[.5,/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document,.5,/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document] \ - # model.data.index_mapping_dir=examples/nlp/language_modeling/gpt_index_mappings" - - rm -rf examples/nlp/language_modeling/gpt_pretrain_results - rm -rf examples/nlp/language_modeling/gpt_index_mappings - - uses: "NVIDIA/NeMo/.github/actions/cancel-workflow@main" - if: "failure()" + uses: ./.github/workflows/_test_template.yml + with: + RUNNER: self-hosted-azure + SCRIPT: | + python examples/nlp/language_modeling/megatron_gpt_pretraining.py \ + trainer.devices=2 \ + trainer.accelerator=gpu \ + trainer.log_every_n_steps=1 \ + trainer.val_check_interval=2 \ + trainer.limit_val_batches=2 \ + trainer.accumulate_grad_batches=1 \ + trainer.max_steps=3 \ + trainer.gradient_clip_val=1.0 \ + exp_manager.exp_dir=examples/nlp/language_modeling/gpt_pretrain_results \ + model.tensor_model_parallel_size=2 \ + model.optim.name=fused_adam \ + model.optim.lr=2e-4 \ + model.optim.sched.warmup_steps=1 \ + model.optim.sched.constant_steps=1 \ + model.optim.sched.min_lr=8e-5 \ + model.max_position_embeddings=128 \ + 
model.encoder_seq_length=128 \ + model.data.seq_length=128 \ + model.position_embedding_type=rope \ + model.rotary_percentage=0.5 \ + model.bias=False \ + model.bias_activation_fusion=False \ + model.bias_dropout_add_fusion=False \ + model.tokenizer.vocab_file=/home/TestData/nlp/megatron_gpt/data/gpt/vocab.json \ + model.tokenizer.merge_file=/home/TestData/nlp/megatron_gpt/data/gpt/merges.txt \ + model.num_layers=8 \ + model.hidden_size=256 \ + model.num_attention_heads=8 \ + model.activations_checkpoint_method=block \ + model.activations_checkpoint_granularity=full \ + model.activations_checkpoint_num_layers=1 \ + model.data.data_prefix=[.5,/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document,.5,/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document] \ + model.data.index_mapping_dir=examples/nlp/language_modeling/gpt_index_mappings + + # commented out to save time on github ci @adithyare + # python examples/nlp/language_modeling/megatron_gpt_pretraining.py \ + # trainer.devices=2 \ + # trainer.accelerator=gpu \ + # trainer.log_every_n_steps=1 \ + # trainer.val_check_interval=2 \ + # trainer.limit_val_batches=1 \ + # trainer.accumulate_grad_batches=1 \ + # trainer.max_steps=6 \ + # trainer.gradient_clip_val=1.0 \ + # exp_manager.exp_dir=examples/nlp/language_modeling/gpt_pretrain_results \ + # exp_manager.resume_if_exists=True \ + # model.tensor_model_parallel_size=2 \ + # model.optim.name=fused_adam \ + # model.optim.lr=2e-4 \ + # model.optim.sched.warmup_steps=2 \ + # model.optim.sched.constant_steps=2 \ + # model.optim.sched.min_lr=8e-5 \ + # model.max_position_embeddings=128 \ + # model.encoder_seq_length=128 \ + # model.data.seq_length=128 \ + # model.position_embedding_type=rope \ + # model.rotary_percentage=0.5 \ + # model.normalization=rmsnorm \ + # model.bias=False \ + # model.bias_activation_fusion=False \ + # model.bias_dropout_add_fusion=False \ + # model.tokenizer.vocab_file=/home/TestData/nlp/megatron_gpt/data/gpt/vocab.json \ + # model.tokenizer.merge_file=/home/TestData/nlp/megatron_gpt/data/gpt/merges.txt \ + # model.num_layers=8 \ + # model.hidden_size=256 \ + # model.num_attention_heads=8 \ + # model.activations_checkpoint_method=block \ + # model.activations_checkpoint_granularity=full \ + # model.activations_checkpoint_num_layers=1 \ + # model.data.data_prefix=[.5,/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document,.5,/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document] \ + # model.data.index_mapping_dir=examples/nlp/language_modeling/gpt_index_mappings" + AFTER_SCRIPT: | + rm -rf examples/nlp/language_modeling/gpt_pretrain_results + rm -rf examples/nlp/language_modeling/gpt_index_mappings # This test requires Ampere but some of the test GPUs are Volta # Need to add a check for compute capability before uncommenting this test @@ -2671,284 +2658,243 @@ jobs: L2_Megatron_GPT_with_ResetLR_Pretraining_and_Resume_Training_TP2: needs: [cicd-test-container-setup] - runs-on: self-hosted-azure - timeout-minutes: 10 - container: - image: nemoci.azurecr.io/nemo_container_${{ github.run_id }} - options: - # --user 0:128 - --device=/dev/nvidia0 - --gpus all - --shm-size=8g - --env TRANSFORMERS_OFFLINE=0 - --env HYDRA_FULL_ERROR=1 - --volume /mnt/datadrive/TestData:/home/TestData - steps: - - name: Checkout repository - uses: actions/checkout@v4 - - run: | - python examples/nlp/language_modeling/megatron_gpt_pretraining.py \ - trainer.devices=2 \ - trainer.accelerator=gpu \ - 
trainer.log_every_n_steps=1 \ - trainer.val_check_interval=3 \ - trainer.limit_val_batches=2 \ - trainer.accumulate_grad_batches=1 \ - trainer.max_steps=3 \ - trainer.precision=bf16 \ - trainer.gradient_clip_val=1.0 \ - exp_manager.exp_dir=examples/nlp/language_modeling/gpt_pretrain_results \ - model.tensor_model_parallel_size=2 \ - model.megatron_amp_O2=True \ - model.optim.name=distributed_fused_adam \ - model.optim.lr=2e-4 \ - model.optim.sched.warmup_steps=2 \ - model.optim.sched.constant_steps=2 \ - model.optim.sched.min_lr=8e-5 \ - model.max_position_embeddings=128 \ - model.encoder_seq_length=128 \ - model.data.seq_length=128 \ - model.tokenizer.vocab_file=/home/TestData/nlp/megatron_gpt/data/gpt/vocab.json \ - model.tokenizer.merge_file=/home/TestData/nlp/megatron_gpt/data/gpt/merges.txt \ - model.num_layers=8 \ - model.hidden_size=256 \ - model.num_attention_heads=8 \ - model.data.data_prefix=[.5,/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document,.5,/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document] \ - model.data.index_mapping_dir=examples/nlp/language_modeling/gpt_index_mappings - - python examples/nlp/language_modeling/megatron_gpt_pretraining.py \ - trainer.devices=2 \ - trainer.accelerator=gpu \ - trainer.log_every_n_steps=1 \ - trainer.val_check_interval=3 \ - trainer.limit_val_batches=2 \ - trainer.accumulate_grad_batches=1 \ - trainer.max_steps=6 \ - trainer.precision=bf16 \ - trainer.gradient_clip_val=1.0 \ - exp_manager.exp_dir=examples/nlp/language_modeling/gpt_pretrain_results \ - exp_manager.resume_if_exists=True \ - model.reset_lr=True \ - model.tensor_model_parallel_size=2 \ - model.megatron_amp_O2=True \ - model.optim.name=distributed_fused_adam \ - model.optim.lr=2e-4 \ - model.optim.sched.warmup_steps=2 \ - model.optim.sched.constant_steps=2 \ - model.optim.sched.min_lr=8e-5 \ - model.max_position_embeddings=128 \ - model.encoder_seq_length=128 \ - model.data.seq_length=128 \ - model.tokenizer.vocab_file=/home/TestData/nlp/megatron_gpt/data/gpt/vocab.json \ - model.tokenizer.merge_file=/home/TestData/nlp/megatron_gpt/data/gpt/merges.txt \ - model.num_layers=8 \ - model.hidden_size=256 \ - model.num_attention_heads=8 \ - model.data.data_prefix=[.5,/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document,.5,/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document] \ - model.data.index_mapping_dir=examples/nlp/language_modeling/gpt_index_mappings - - rm -rf examples/nlp/language_modeling/gpt_pretrain_results - rm -rf examples/nlp/language_modeling/gpt_index_mappings - - uses: "NVIDIA/NeMo/.github/actions/cancel-workflow@main" - if: "failure()" + uses: ./.github/workflows/_test_template.yml + with: + RUNNER: self-hosted-azure + SCRIPT: | + python examples/nlp/language_modeling/megatron_gpt_pretraining.py \ + trainer.devices=2 \ + trainer.accelerator=gpu \ + trainer.log_every_n_steps=1 \ + trainer.val_check_interval=3 \ + trainer.limit_val_batches=2 \ + trainer.accumulate_grad_batches=1 \ + trainer.max_steps=3 \ + trainer.precision=bf16 \ + trainer.gradient_clip_val=1.0 \ + exp_manager.exp_dir=examples/nlp/language_modeling/gpt_pretrain_results \ + model.tensor_model_parallel_size=2 \ + model.megatron_amp_O2=True \ + model.optim.name=distributed_fused_adam \ + model.optim.lr=2e-4 \ + model.optim.sched.warmup_steps=2 \ + model.optim.sched.constant_steps=2 \ + model.optim.sched.min_lr=8e-5 \ + model.max_position_embeddings=128 \ + model.encoder_seq_length=128 \ + 
model.data.seq_length=128 \ + model.tokenizer.vocab_file=/home/TestData/nlp/megatron_gpt/data/gpt/vocab.json \ + model.tokenizer.merge_file=/home/TestData/nlp/megatron_gpt/data/gpt/merges.txt \ + model.num_layers=8 \ + model.hidden_size=256 \ + model.num_attention_heads=8 \ + model.data.data_prefix=[.5,/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document,.5,/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document] \ + model.data.index_mapping_dir=examples/nlp/language_modeling/gpt_index_mappings + + python examples/nlp/language_modeling/megatron_gpt_pretraining.py \ + trainer.devices=2 \ + trainer.accelerator=gpu \ + trainer.log_every_n_steps=1 \ + trainer.val_check_interval=3 \ + trainer.limit_val_batches=2 \ + trainer.accumulate_grad_batches=1 \ + trainer.max_steps=6 \ + trainer.precision=bf16 \ + trainer.gradient_clip_val=1.0 \ + exp_manager.exp_dir=examples/nlp/language_modeling/gpt_pretrain_results \ + exp_manager.resume_if_exists=True \ + model.reset_lr=True \ + model.tensor_model_parallel_size=2 \ + model.megatron_amp_O2=True \ + model.optim.name=distributed_fused_adam \ + model.optim.lr=2e-4 \ + model.optim.sched.warmup_steps=2 \ + model.optim.sched.constant_steps=2 \ + model.optim.sched.min_lr=8e-5 \ + model.max_position_embeddings=128 \ + model.encoder_seq_length=128 \ + model.data.seq_length=128 \ + model.tokenizer.vocab_file=/home/TestData/nlp/megatron_gpt/data/gpt/vocab.json \ + model.tokenizer.merge_file=/home/TestData/nlp/megatron_gpt/data/gpt/merges.txt \ + model.num_layers=8 \ + model.hidden_size=256 \ + model.num_attention_heads=8 \ + model.data.data_prefix=[.5,/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document,.5,/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document] \ + model.data.index_mapping_dir=examples/nlp/language_modeling/gpt_index_mappings + AFTER_SCRIPT: | + rm -rf examples/nlp/language_modeling/gpt_pretrain_results + rm -rf examples/nlp/language_modeling/gpt_index_mappings L2_Megatron_GPT_with_ALiBi_Pretraining_and_Resume_Training_TP2: needs: [cicd-test-container-setup] - runs-on: self-hosted-azure - timeout-minutes: 10 - container: - image: nemoci.azurecr.io/nemo_container_${{ github.run_id }} - options: - # --user 0:128 - --device=/dev/nvidia0 - --gpus all - --shm-size=8g - --env TRANSFORMERS_OFFLINE=0 - --env HYDRA_FULL_ERROR=1 - --volume /mnt/datadrive/TestData:/home/TestData - steps: - - name: Checkout repository - uses: actions/checkout@v4 - - run: | - python examples/nlp/language_modeling/megatron_gpt_pretraining.py \ - trainer.devices=2 \ - trainer.accelerator=gpu \ - trainer.log_every_n_steps=1 \ - trainer.val_check_interval=2 \ - trainer.limit_val_batches=2 \ - trainer.accumulate_grad_batches=1 \ - trainer.max_steps=3 \ - trainer.gradient_clip_val=1.0 \ - exp_manager.exp_dir=examples/nlp/language_modeling/gpt_pretrain_results \ - model.tensor_model_parallel_size=2 \ - model.optim.name=fused_adam \ - model.optim.lr=2e-4 \ - model.optim.sched.warmup_steps=1 \ - model.optim.sched.constant_steps=1 \ - model.optim.sched.min_lr=8e-5 \ - model.max_position_embeddings=128 \ - model.encoder_seq_length=128 \ - model.data.seq_length=128 \ - model.position_embedding_type=alibi \ - model.bias=False \ - model.bias_activation_fusion=False \ - model.bias_dropout_add_fusion=False \ - model.tokenizer.vocab_file=/home/TestData/nlp/megatron_gpt/data/gpt/vocab.json \ - model.tokenizer.merge_file=/home/TestData/nlp/megatron_gpt/data/gpt/merges.txt \ - model.num_layers=8 \ 
- model.hidden_size=256 \ - model.num_attention_heads=8 \ - model.activations_checkpoint_method=block \ - model.activations_checkpoint_granularity=full \ - model.activations_checkpoint_num_layers=1 \ - model.data.data_prefix=[.5,/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document,.5,/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document] \ - model.data.index_mapping_dir=examples/nlp/language_modeling/gpt_index_mappings - - # not testing resume functionality to save time on ci @adithyare - #python examples/nlp/language_modeling/megatron_gpt_pretraining.py \ - #trainer.devices=2 \ - #trainer.accelerator=gpu \ - #trainer.log_every_n_steps=1 \ - #trainer.val_check_interval=2 \ - #trainer.limit_val_batches=1 \ - #trainer.accumulate_grad_batches=1 \ - #trainer.max_steps=6 \ - #trainer.gradient_clip_val=1.0 \ - #exp_manager.exp_dir=examples/nlp/language_modeling/gpt_pretrain_results \ - #exp_manager.resume_if_exists=True \ - #model.tensor_model_parallel_size=2 \ - #model.optim.name=fused_adam \ - #model.optim.lr=2e-4 \ - #model.optim.sched.warmup_steps=2 \ - #model.optim.sched.constant_steps=2 \ - #model.optim.sched.min_lr=8e-5 \ - #model.max_position_embeddings=128 \ - #model.encoder_seq_length=128 \ - #model.data.seq_length=128 \ - #model.position_embedding_type=alibi \ - #model.normalization=rmsnorm \ - #model.bias=False \ - #model.bias_activation_fusion=False \ - #model.bias_dropout_add_fusion=False \ - #model.tokenizer.vocab_file=/home/TestData/nlp/megatron_gpt/data/gpt/vocab.json \ - #model.tokenizer.merge_file=/home/TestData/nlp/megatron_gpt/data/gpt/merges.txt \ - #model.num_layers=8 \ - #model.hidden_size=256 \ - #model.num_attention_heads=8 \ - #model.activations_checkpoint_method=block \ - #model.activations_checkpoint_granularity=full \ - #model.activations_checkpoint_num_layers=1 \ - #model.data.data_prefix=[.5,/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document,.5,/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document] \ - #model.data.index_mapping_dir=examples/nlp/language_modeling/gpt_index_mappings" - - rm -rf examples/nlp/language_modeling/gpt_pretrain_results - rm -rf examples/nlp/language_modeling/gpt_index_mappings - - uses: "NVIDIA/NeMo/.github/actions/cancel-workflow@main" - if: "failure()" + uses: ./.github/workflows/_test_template.yml + with: + RUNNER: self-hosted-azure + SCRIPT: | + python examples/nlp/language_modeling/megatron_gpt_pretraining.py \ + trainer.devices=2 \ + trainer.accelerator=gpu \ + trainer.log_every_n_steps=1 \ + trainer.val_check_interval=2 \ + trainer.limit_val_batches=2 \ + trainer.accumulate_grad_batches=1 \ + trainer.max_steps=3 \ + trainer.gradient_clip_val=1.0 \ + exp_manager.exp_dir=examples/nlp/language_modeling/gpt_pretrain_results \ + model.tensor_model_parallel_size=2 \ + model.optim.name=fused_adam \ + model.optim.lr=2e-4 \ + model.optim.sched.warmup_steps=1 \ + model.optim.sched.constant_steps=1 \ + model.optim.sched.min_lr=8e-5 \ + model.max_position_embeddings=128 \ + model.encoder_seq_length=128 \ + model.data.seq_length=128 \ + model.position_embedding_type=alibi \ + model.bias=False \ + model.bias_activation_fusion=False \ + model.bias_dropout_add_fusion=False \ + model.tokenizer.vocab_file=/home/TestData/nlp/megatron_gpt/data/gpt/vocab.json \ + model.tokenizer.merge_file=/home/TestData/nlp/megatron_gpt/data/gpt/merges.txt \ + model.num_layers=8 \ + model.hidden_size=256 \ + model.num_attention_heads=8 \ + 
model.activations_checkpoint_method=block \ + model.activations_checkpoint_granularity=full \ + model.activations_checkpoint_num_layers=1 \ + model.data.data_prefix=[.5,/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document,.5,/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document] \ + model.data.index_mapping_dir=examples/nlp/language_modeling/gpt_index_mappings + + # not testing resume functionality to save time on ci @adithyare + #python examples/nlp/language_modeling/megatron_gpt_pretraining.py \ + #trainer.devices=2 \ + #trainer.accelerator=gpu \ + #trainer.log_every_n_steps=1 \ + #trainer.val_check_interval=2 \ + #trainer.limit_val_batches=1 \ + #trainer.accumulate_grad_batches=1 \ + #trainer.max_steps=6 \ + #trainer.gradient_clip_val=1.0 \ + #exp_manager.exp_dir=examples/nlp/language_modeling/gpt_pretrain_results \ + #exp_manager.resume_if_exists=True \ + #model.tensor_model_parallel_size=2 \ + #model.optim.name=fused_adam \ + #model.optim.lr=2e-4 \ + #model.optim.sched.warmup_steps=2 \ + #model.optim.sched.constant_steps=2 \ + #model.optim.sched.min_lr=8e-5 \ + #model.max_position_embeddings=128 \ + #model.encoder_seq_length=128 \ + #model.data.seq_length=128 \ + #model.position_embedding_type=alibi \ + #model.normalization=rmsnorm \ + #model.bias=False \ + #model.bias_activation_fusion=False \ + #model.bias_dropout_add_fusion=False \ + #model.tokenizer.vocab_file=/home/TestData/nlp/megatron_gpt/data/gpt/vocab.json \ + #model.tokenizer.merge_file=/home/TestData/nlp/megatron_gpt/data/gpt/merges.txt \ + #model.num_layers=8 \ + #model.hidden_size=256 \ + #model.num_attention_heads=8 \ + #model.activations_checkpoint_method=block \ + #model.activations_checkpoint_granularity=full \ + #model.activations_checkpoint_num_layers=1 \ + #model.data.data_prefix=[.5,/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document,.5,/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document] \ + #model.data.index_mapping_dir=examples/nlp/language_modeling/gpt_index_mappings" + AFTER_SCRIPT: | + rm -rf examples/nlp/language_modeling/gpt_pretrain_results + rm -rf examples/nlp/language_modeling/gpt_index_mappings L2_Megatron_GPT_with_KERPLE_Pretraining_and_Resume_Training_TP2: needs: [cicd-test-container-setup] - runs-on: self-hosted-azure - timeout-minutes: 10 - container: - image: nemoci.azurecr.io/nemo_container_${{ github.run_id }} - options: - # --user 0:128 - --device=/dev/nvidia0 - --gpus all - --shm-size=8g - --env TRANSFORMERS_OFFLINE=0 - --env HYDRA_FULL_ERROR=1 - --volume /mnt/datadrive/TestData:/home/TestData - steps: - - name: Checkout repository - uses: actions/checkout@v4 - - run: | - python examples/nlp/language_modeling/megatron_gpt_pretraining.py \ - trainer.devices=2 \ - trainer.accelerator=gpu \ - trainer.log_every_n_steps=1 \ - trainer.val_check_interval=2 \ - trainer.limit_val_batches=2 \ - trainer.accumulate_grad_batches=1 \ - trainer.max_steps=3 \ - trainer.gradient_clip_val=1.0 \ - exp_manager.exp_dir=examples/nlp/language_modeling/gpt_pretrain_results \ - model.tensor_model_parallel_size=2 \ - model.optim.name=fused_adam \ - model.optim.lr=2e-4 \ - model.optim.sched.warmup_steps=1 \ - model.optim.sched.constant_steps=1 \ - model.optim.sched.min_lr=8e-5 \ - model.max_position_embeddings=128 \ - model.encoder_seq_length=128 \ - model.data.seq_length=128 \ - model.position_embedding_type=kerple \ - model.bias=False \ - model.bias_activation_fusion=False \ - model.bias_dropout_add_fusion=False \ 
- model.tokenizer.vocab_file=/home/TestData/nlp/megatron_gpt/data/gpt/vocab.json \ - model.tokenizer.merge_file=/home/TestData/nlp/megatron_gpt/data/gpt/merges.txt \ - model.num_layers=8 \ - model.hidden_size=256 \ - model.num_attention_heads=8 \ - model.activations_checkpoint_method=block \ - model.activations_checkpoint_granularity=full \ - model.activations_checkpoint_num_layers=1 \ - model.data.data_prefix=[.5,/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document,.5,/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document] \ - model.data.index_mapping_dir=examples/nlp/language_modeling/gpt_index_mappings - - # commented out to save time on github ci @adithyare - #python examples/nlp/language_modeling/megatron_gpt_pretraining.py \ - #trainer.devices=2 \ - #trainer.accelerator=gpu \ - #trainer.log_every_n_steps=1 \ - #trainer.val_check_interval=2 \ - #trainer.limit_val_batches=1 \ - #trainer.accumulate_grad_batches=1 \ - #trainer.max_steps=6 \ - #trainer.precision=16 \ - #trainer.gradient_clip_val=1.0 \ - #exp_manager.exp_dir=examples/nlp/language_modeling/gpt_pretrain_results \ - #exp_manager.resume_if_exists=True \ - #model.tensor_model_parallel_size=2 \ - #model.optim.name=fused_adam \ - #model.optim.lr=2e-4 \ - #model.optim.sched.warmup_steps=2 \ - #model.optim.sched.constant_steps=2 \ - #model.optim.sched.min_lr=8e-5 \ - #model.max_position_embeddings=128 \ - #model.encoder_seq_length=128 \ - #model.data.seq_length=128 \ - #model.position_embedding_type=kerple \ - #model.normalization=rmsnorm \ - #model.bias=False \ - #model.bias_activation_fusion=False \ - #model.bias_dropout_add_fusion=False \ - #model.tokenizer.vocab_file=/home/TestData/nlp/megatron_gpt/data/gpt/vocab.json \ - #model.tokenizer.merge_file=/home/TestData/nlp/megatron_gpt/data/gpt/merges.txt \ - #model.num_layers=8 \ - #model.hidden_size=256 \ - #model.num_attention_heads=8 \ - #model.activations_checkpoint_method=block \ - #model.activations_checkpoint_granularity=full \ - #model.activations_checkpoint_num_layers=1 \ - #model.data.data_prefix=[.5,/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document,.5,/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document] \ - #model.data.index_mapping_dir=examples/nlp/language_modeling/gpt_index_mappings" - - rm -rf examples/nlp/language_modeling/gpt_pretrain_results - rm -rf examples/nlp/language_modeling/gpt_index_mappings - - uses: "NVIDIA/NeMo/.github/actions/cancel-workflow@main" - if: "failure()" + uses: ./.github/workflows/_test_template.yml + with: + RUNNER: self-hosted-azure + SCRIPT: | + python examples/nlp/language_modeling/megatron_gpt_pretraining.py \ + trainer.devices=2 \ + trainer.accelerator=gpu \ + trainer.log_every_n_steps=1 \ + trainer.val_check_interval=2 \ + trainer.limit_val_batches=2 \ + trainer.accumulate_grad_batches=1 \ + trainer.max_steps=3 \ + trainer.gradient_clip_val=1.0 \ + exp_manager.exp_dir=examples/nlp/language_modeling/gpt_pretrain_results \ + model.tensor_model_parallel_size=2 \ + model.optim.name=fused_adam \ + model.optim.lr=2e-4 \ + model.optim.sched.warmup_steps=1 \ + model.optim.sched.constant_steps=1 \ + model.optim.sched.min_lr=8e-5 \ + model.max_position_embeddings=128 \ + model.encoder_seq_length=128 \ + model.data.seq_length=128 \ + model.position_embedding_type=kerple \ + model.bias=False \ + model.bias_activation_fusion=False \ + model.bias_dropout_add_fusion=False \ + 
model.tokenizer.vocab_file=/home/TestData/nlp/megatron_gpt/data/gpt/vocab.json \ + model.tokenizer.merge_file=/home/TestData/nlp/megatron_gpt/data/gpt/merges.txt \ + model.num_layers=8 \ + model.hidden_size=256 \ + model.num_attention_heads=8 \ + model.activations_checkpoint_method=block \ + model.activations_checkpoint_granularity=full \ + model.activations_checkpoint_num_layers=1 \ + model.data.data_prefix=[.5,/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document,.5,/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document] \ + model.data.index_mapping_dir=examples/nlp/language_modeling/gpt_index_mappings + + # commented out to save time on github ci @adithyare + #python examples/nlp/language_modeling/megatron_gpt_pretraining.py \ + #trainer.devices=2 \ + #trainer.accelerator=gpu \ + #trainer.log_every_n_steps=1 \ + #trainer.val_check_interval=2 \ + #trainer.limit_val_batches=1 \ + #trainer.accumulate_grad_batches=1 \ + #trainer.max_steps=6 \ + #trainer.precision=16 \ + #trainer.gradient_clip_val=1.0 \ + #exp_manager.exp_dir=examples/nlp/language_modeling/gpt_pretrain_results \ + #exp_manager.resume_if_exists=True \ + #model.tensor_model_parallel_size=2 \ + #model.optim.name=fused_adam \ + #model.optim.lr=2e-4 \ + #model.optim.sched.warmup_steps=2 \ + #model.optim.sched.constant_steps=2 \ + #model.optim.sched.min_lr=8e-5 \ + #model.max_position_embeddings=128 \ + #model.encoder_seq_length=128 \ + #model.data.seq_length=128 \ + #model.position_embedding_type=kerple \ + #model.normalization=rmsnorm \ + #model.bias=False \ + #model.bias_activation_fusion=False \ + #model.bias_dropout_add_fusion=False \ + #model.tokenizer.vocab_file=/home/TestData/nlp/megatron_gpt/data/gpt/vocab.json \ + #model.tokenizer.merge_file=/home/TestData/nlp/megatron_gpt/data/gpt/merges.txt \ + #model.num_layers=8 \ + #model.hidden_size=256 \ + #model.num_attention_heads=8 \ + #model.activations_checkpoint_method=block \ + #model.activations_checkpoint_granularity=full \ + #model.activations_checkpoint_num_layers=1 \ + #model.data.data_prefix=[.5,/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document,.5,/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document] \ + #model.data.index_mapping_dir=examples/nlp/language_modeling/gpt_index_mappings" + AFTER_SCRIPT: | + rm -rf examples/nlp/language_modeling/gpt_pretrain_results + rm -rf examples/nlp/language_modeling/gpt_index_mappings L2_Megatron_GPT_Pretraining_and_Resume_Training_PP2: needs: [cicd-test-container-setup] uses: ./.github/workflows/_test_template.yml with: - RUNNER: self-hosted-azure + RUNNER: self-hosted-azure-gpus-2-h100 SCRIPT: | python examples/nlp/language_modeling/megatron_gpt_pretraining.py \ trainer.devices=2 \ + trainer.accelerator=gpu \ trainer.log_every_n_steps=1 \ trainer.val_check_interval=2 \ trainer.limit_val_batches=2 \ @@ -2957,6 +2903,15 @@ jobs: trainer.precision=bf16 \ trainer.gradient_clip_val=1.0 \ exp_manager.exp_dir=examples/nlp/language_modeling/gpt_pretrain_results \ + ++model.transformer_engine=True \ + ++model.fp8=True \ + ++model.fp8_hybrid=True \ + ++model.fp8_amax_history_len=1024 \ + ++model.fp8_amax_compute_algo=max \ + ++model.reduce_amax=True \ + ++model.use_te_rng_tracker=True \ + ++model.name=megatron_gpt_full_te_layer_autocast \ + model.ub_tp_comm_overlap=False \ model.pipeline_model_parallel_size=2 \ model.tensor_model_parallel_size=1 \ model.mcore_gpt=True \ @@ -2981,12 +2936,15 @@ jobs: model.hidden_size=256 \ 
model.num_attention_heads=8 \ model.activations_checkpoint_method=block \ + model.activations_checkpoint_granularity=full \ model.activations_checkpoint_num_layers=1 \ + model.data.validation_drop_last=False \ model.data.data_prefix=[.5,/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document,.5,/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document] \ model.data.index_mapping_dir=examples/nlp/language_modeling/gpt_index_mappings python examples/nlp/language_modeling/megatron_gpt_pretraining.py \ trainer.devices=2 \ + trainer.accelerator=gpu \ trainer.log_every_n_steps=1 \ trainer.val_check_interval=2 \ trainer.limit_val_batches=2 \ @@ -2998,6 +2956,15 @@ jobs: model.megatron_amp_O2=True \ exp_manager.exp_dir=examples/nlp/language_modeling/gpt_pretrain_results \ exp_manager.resume_if_exists=True \ + ++model.transformer_engine=True \ + ++model.fp8=True \ + ++model.fp8_hybrid=True \ + ++model.fp8_amax_history_len=1024 \ + ++model.fp8_amax_compute_algo=max \ + ++model.reduce_amax=True \ + ++model.use_te_rng_tracker=True \ + ++model.name=megatron_gpt_full_te_layer_autocast \ + model.ub_tp_comm_overlap=False \ model.pipeline_model_parallel_size=2 \ model.tensor_model_parallel_size=1 \ model.optim.name=distributed_fused_adam \ @@ -3020,7 +2987,9 @@ jobs: model.hidden_size=256 \ model.num_attention_heads=8 \ model.activations_checkpoint_method=block \ + model.activations_checkpoint_granularity=full \ model.activations_checkpoint_num_layers=1 \ + model.data.validation_drop_last=False \ model.data.data_prefix=[.5,/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document,.5,/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document] \ model.data.index_mapping_dir=examples/nlp/language_modeling/gpt_index_mappings AFTER_SCRIPT: | @@ -3097,50 +3066,62 @@ jobs: L2_Megatron_GPT_Finetuning_StarCoder_PP1: needs: [cicd-test-container-setup] - runs-on: self-hosted-azure-gpus-1 - timeout-minutes: 10 - container: - image: nemoci.azurecr.io/nemo_container_${{ github.run_id }} - options: - # --user 0:128 - --device=/dev/nvidia0 - --gpus all - --shm-size=8g - --env TRANSFORMERS_OFFLINE=0 - --env HYDRA_FULL_ERROR=1 - --volume /mnt/datadrive/TestData:/home/TestData - steps: - - name: Checkout repository - uses: actions/checkout@v4 - - run: | - python examples/nlp/language_modeling/tuning/megatron_gpt_finetuning.py \ - trainer.devices=1 \ - trainer.num_nodes=1 \ - trainer.precision=bf16 \ - trainer.max_steps=4 \ - trainer.val_check_interval=4 \ - trainer.enable_checkpointing=False \ - +trainer.limit_val_batches=2 \ - +trainer.limit_test_batches=2 \ - exp_manager.checkpoint_callback_params.save_best_model=False \ - exp_manager.exp_dir=examples/nlp/language_modeling/gpt_sft_results \ - model.peft.peft_scheme=none \ - model.optim.name=distributed_fused_adam \ - model.restore_from_path=/home/TestData/nlp/megatron_gpt/starcoder-ci-nemo/megatron_starcoder_tp1_pp1.nemo \ - model.tensor_model_parallel_size=1 \ - model.pipeline_model_parallel_size=1 \ - model.data.train_ds.file_names=[/home/TestData/nlp/megatron_sft/quarel.jsonl] \ - model.data.train_ds.num_workers=0 \ - model.data.test_ds.file_names=[/home/TestData/nlp/megatron_sft/quarel.jsonl] \ - model.data.validation_ds.num_workers=0 \ - model.data.validation_ds.file_names=[/home/TestData/nlp/megatron_sft/quarel.jsonl] \ - model.data.test_ds.num_workers=0 \ - model.data.train_ds.concat_sampling_probabilities=[1.0] - - rm -rf examples/nlp/language_modeling/gpt_sft_results - - uses: 
"NVIDIA/NeMo/.github/actions/cancel-workflow@main" - if: "failure()" + uses: ./.github/workflows/_test_template.yml + with: + RUNNER: self-hosted-azure-gpus-1 + SCRIPT: | + python examples/nlp/language_modeling/tuning/megatron_gpt_finetuning.py \ + trainer.devices=1 \ + trainer.num_nodes=1 \ + trainer.precision=bf16 \ + trainer.max_steps=4 \ + trainer.val_check_interval=4 \ + trainer.enable_checkpointing=False \ + +trainer.limit_val_batches=2 \ + +trainer.limit_test_batches=2 \ + exp_manager.checkpoint_callback_params.save_best_model=False \ + exp_manager.exp_dir=examples/nlp/language_modeling/gpt_sft_results \ + model.peft.peft_scheme=none \ + model.optim.name=distributed_fused_adam \ + model.restore_from_path=/home/TestData/nlp/megatron_gpt/starcoder-ci-nemo/megatron_starcoder_tp1_pp1.nemo \ + model.tensor_model_parallel_size=1 \ + model.pipeline_model_parallel_size=1 \ + model.data.train_ds.file_names=[/home/TestData/nlp/megatron_sft/quarel.jsonl] \ + model.data.train_ds.num_workers=0 \ + model.data.test_ds.file_names=[/home/TestData/nlp/megatron_sft/quarel.jsonl] \ + model.data.validation_ds.num_workers=0 \ + model.data.validation_ds.file_names=[/home/TestData/nlp/megatron_sft/quarel.jsonl] \ + model.data.test_ds.num_workers=0 \ + model.data.train_ds.concat_sampling_probabilities=[1.0] + AFTER_SCRIPT: | + rm -rf examples/nlp/language_modeling/gpt_sft_results + L2_Megatron_GPT_Reranker: + needs: [cicd-test-container-setup] + uses: ./.github/workflows/_test_template.yml + with: + RUNNER: self-hosted-azure + SCRIPT: | + rm -rf /home/TestData/nlp/megatron_ir/working_dir + + python examples/nlp/information_retrieval/megatron_gpt_reranker_finetuning.py \ + exp_manager.exp_dir='/home/TestData/nlp/megatron_ir/working_dir' \ + model.global_batch_size=4 \ + model.micro_batch_size=4 \ + trainer.devices=1 \ + trainer.num_nodes=1 \ + trainer.max_epochs=null \ + trainer.max_steps=20 \ + trainer.val_check_interval=10 \ + model.restore_from_path='/home/TestData/nlp/megatron_gpt/mcore_45M/megatron_llama.nemo' \ + model.peft.lora_tuning.adapter_dim=8 \ + model.data.validation_ds.file_names=[/home/TestData/nlp/megatron_ir/train.jsonl] \ + model.data.validation_ds.write_embeddings_to_file=True \ + model.data.validation_ds.output_file_path_prefix='/home/TestData/nlp/megatron_ir/working_dir/val_embs' \ + model.data.train_ds.file_names=[/home/TestData/nlp/megatron_ir/train.jsonl] + AFTER_SCRIPT: | + rm -rf /home/TestData/nlp/megatron_ir/working_dir + L2_Megatron_GPT_Embedding: needs: [cicd-test-container-setup] uses: ./.github/workflows/_test_template.yml @@ -3283,6 +3264,62 @@ jobs: AFTER_SCRIPT: | rm -rf /home/TestData/nlp/lora_tuning_tp2 + L2_Megatron_GPT_PEFT_Lora_TP2SP1: + needs: [cicd-test-container-setup] + uses: ./.github/workflows/_test_template.yml + with: + RUNNER: self-hosted-azure-gpus-2-h100 + SCRIPT: | + rm -rf /home/TestData/nlp/lora_tuning_tp2_sp1 + + CUDA_DEVICE_MAX_CONNECTIONS=1 NVTE_FLASH_ATTN=0 NVTE_FUSED_ATTN=1 python examples/nlp/language_modeling/tuning/megatron_gpt_finetuning.py \ + trainer.devices=2 \ + trainer.log_every_n_steps=1 \ + trainer.max_epochs=9999 \ + trainer.max_steps=3 \ + trainer.val_check_interval=3 \ + ++trainer.limit_val_batches=2 \ + trainer.precision=bf16 \ + exp_manager.exp_dir=/home/TestData/nlp/lora_tuning_tp2_sp1 \ + +model.mcore_gpt=True \ + model.pipeline_model_parallel_size=1 \ + model.tensor_model_parallel_size=2 \ + model.sequence_parallel=True \ + model.megatron_amp_O2=True \ + 
model.restore_from_path=/home/TestData/nlp/megatron_gpt/mcore_45M/megatron_llama.nemo \ + +model.fp8=True \ + +model.fp8_params=True \ + +model.fp8_hybrid=True \ + +model.fp8_e4m3=False \ + +model.fp8_interval=1 \ + +model.fp8_margin=0 \ + +model.fp8_amax_history_len=32 \ + +model.fp8_amax_compute_algo=max \ + +model.reduce_amax=False \ + +model.ub_tp_comm_overlap=False \ + +model.tp_comm_overlap_ag=False \ + +model.tp_comm_overlap_rs=False \ + +model.tp_comm_overlap_disable_qkv=True \ + model.peft.peft_scheme='lora' \ + model.peft.lora_tuning.adapter_dim=16 \ + model.peft.lora_tuning.alpha=32 \ + model.peft.lora_tuning.column_init_method="kaiming" \ + +model.peft.lora_tuning.dropout_position='pre' \ + model.peft.lora_tuning.target_modules=['attention'] \ + model.peft.lora_tuning.adapter_dropout=0.1 \ + +model.peft.lora_tuning.a2a_experimental=1 \ + model.answer_only_loss=True \ + model.micro_batch_size=1 \ + model.global_batch_size=1 \ + model.data.train_ds.file_names=[/home/TestData/nlp/megatron_sft/quarel.jsonl] \ + model.data.train_ds.concat_sampling_probabilities=[1.0] \ + model.data.train_ds.num_workers=0 \ + model.data.validation_ds.num_workers=0 \ + model.data.validation_ds.file_names=[/home/TestData/nlp/megatron_sft/quarel.jsonl] \ + model.data.validation_ds.names=[quarel] + AFTER_SCRIPT: | + rm -rf /home/TestData/nlp/lora_tuning_tp2_sp1 + L2_Megatron_GPT_Eval: needs: [cicd-test-container-setup] uses: ./.github/workflows/_test_template.yml @@ -3408,7 +3445,8 @@ jobs: trainer.limit_val_batches=2 \ trainer.accumulate_grad_batches=1 \ trainer.max_steps=10 \ - trainer.precision=16 \ + trainer.precision=bf16 \ + model.megatron_amp_O2=True \ trainer.gradient_clip_val=1.0 \ exp_manager.exp_dir=examples/nlp/language_modeling/t5_pretrain_results \ model.tensor_model_parallel_size=2 \ @@ -3450,7 +3488,8 @@ jobs: trainer.limit_val_batches=2 \ trainer.accumulate_grad_batches=1 \ trainer.max_steps=10 \ - trainer.precision=16 \ + trainer.precision=bf16 \ + model.megatron_amp_O2=True \ trainer.gradient_clip_val=1.0 \ exp_manager.exp_dir=examples/nlp/language_modeling/t5_pretrain_results \ exp_manager.resume_if_exists=True \ @@ -3835,7 +3874,7 @@ jobs: trainer.precision=16 \ trainer.gradient_clip_val=1.0 \ exp_manager.exp_dir=examples/nlp/language_modeling/t5_pretrain_results \ - model.pipeline_model_parallel_split_rank=1 \ + model.pipeline_model_parallel_split_rank=0 \ model.seq_length=256 \ model.encoder.num_layers=4 \ model.decoder.num_layers=1 \ @@ -3948,6 +3987,17 @@ jobs: --prompt 'How do I fix my GPU memory issue? I am seeing out of memory.' \ --tensor_model_parallel_size 1 + L2_Megatron_Core_T5_Eval: + needs: [cicd-test-container-setup] + uses: ./.github/workflows/_test_template.yml + with: + RUNNER: self-hosted-azure + SCRIPT: | + NVTE_FLASH_ATTN=0 NVTE_FUSED_ATTN=0 python examples/nlp/language_modeling/megatron_t5_eval.py \ + --model_file /home/TestData/nlp/megatron_t5/220m/megatron_mcore_t5_220m.nemo \ + --prompt 'How do I fix my GPU memory issue? I am seeing out of memory.' 
\ + --tensor_model_parallel_size 1 + L2_Megatron_BART_Pretraining_and_Resume_Training_TP2: needs: [cicd-test-container-setup] uses: ./.github/workflows/_test_template.yml @@ -4138,6 +4188,57 @@ jobs: AFTER_SCRIPT: | rm -rf /home/TestData/nlp/t5_lora_tuning_tp2 + L2_Megatron_Core_T5_PEFT_Lora_TP2: + needs: [cicd-test-container-setup] + uses: ./.github/workflows/_test_template.yml + with: + RUNNER: self-hosted-azure + SCRIPT: | + rm -rf /home/TestData/nlp/mcore_t5_lora_tuning_tp2 + + NVTE_FUSED_ATTN=0 NVTE_FLASH_ATTN=0 python examples/nlp/language_modeling/tuning/megatron_t5_finetuning.py \ + trainer.devices=2 \ + trainer.log_every_n_steps=1 \ + trainer.max_epochs=9999 \ + trainer.max_steps=3 \ + trainer.val_check_interval=3 \ + ++trainer.limit_val_batches=2 \ + trainer.precision=16 \ + exp_manager.exp_dir=/home/TestData/nlp/mcore_t5_lora_tuning_tp2 \ + model.pipeline_model_parallel_size=1 \ + model.tensor_model_parallel_size=2 \ + model.restore_from_path=/home/TestData/nlp/megatron_t5/220m/megatron_mcore_t5_220m.nemo \ + model.peft.peft_scheme=lora \ + model.answer_only_loss=True \ + model.micro_batch_size=1 \ + model.global_batch_size=1 \ + model.data.train_ds.file_names=[/home/TestData/nlp/megatron_sft/quarel.jsonl] \ + model.data.train_ds.concat_sampling_probabilities=[1.0] \ + model.data.train_ds.num_workers=0 \ + model.data.validation_ds.num_workers=0 \ + model.data.validation_ds.file_names=[/home/TestData/nlp/megatron_sft/quarel.jsonl] \ + model.data.validation_ds.names=[quarel] + + NVTE_FUSED_ATTN=0 NVTE_FLASH_ATTN=0 python examples/nlp/language_modeling/tuning/megatron_t5_generate.py \ + model.restore_from_path=/home/TestData/nlp/megatron_t5/220m/megatron_mcore_t5_220m.nemo \ + model.peft.restore_from_path=/home/TestData/nlp/mcore_t5_lora_tuning_tp2/megatron_t5_peft_lora_tuning/checkpoints/megatron_t5_peft_lora_tuning.nemo \ + model.peft.restore_from_ckpt_name=null \ + model.peft.restore_from_hparams_path=null \ + model.tensor_model_parallel_size=2 \ + trainer.devices=2 \ + model.data.test_ds.file_names=[/home/TestData/nlp/megatron_sft/quarel_4.jsonl] \ + model.data.test_ds.names=[quarel4] \ + model.global_batch_size=1 \ + model.micro_batch_size=1 \ + model.data.test_ds.tokens_to_generate=10 \ + model.data.test_ds.write_predictions_to_file=True \ + model.data.test_ds.output_file_path_prefix=/home/TestData/nlp/mcore_t5_lora_tuning_tp2/out \ + inference.greedy=True \ + inference.repetition_penalty=1.0 \ + inference.outfile_path=/home/TestData/nlp/mcore_t5_lora_tuning_tp2/out.jsonl + AFTER_SCRIPT: | + rm -rf /home/TestData/nlp/mcore_t5_lora_tuning_tp2 + # L2: Megatron Mock Data Generation L2_Megatron_Mock_Data_Generation_MockGPTDataset: needs: [cicd-test-container-setup] @@ -4430,13 +4531,47 @@ jobs: AFTER_SCRIPT: | rm -rf examples/multimodal/text_to_image/sd_train_results + L2_NeMo_2_GPT_Pretraining_no_transformer_engine: + needs: [cicd-test-container-setup] + uses: ./.github/workflows/_test_template.yml + with: + RUNNER: self-hosted-azure + SCRIPT: | + pip uninstall -y apex ## TODO: remove when apex is no longer a dependency + pip uninstall -y transformer_engine + + python examples/llm/megatron_gpt_pretraining.py \ + --devices=2 \ + --max-steps=3 \ + --experiment-dir=examples/llm/gpt_pretrain_results \ + --vocab-path=/home/TestData/nlp/megatron_gpt/data/gpt/vocab.json \ + --merges-path=/home/TestData/nlp/megatron_gpt/data/gpt/merges.txt \ + --data-path=/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document \ + 
--index-mapping-dir=examples/llm/gpt_index_mappings + + python examples/llm/megatron_gpt_pretraining.py \ + --devices=2 \ + --max-steps=6 \ + --experiment-dir=examples/llm/gpt_pretrain_results \ + --vocab-path=/home/TestData/nlp/megatron_gpt/data/gpt/vocab.json \ + --merges-path=/home/TestData/nlp/megatron_gpt/data/gpt/merges.txt \ + --data-path=/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document \ + --index-mapping-dir=examples/llm/gpt_index_mappings + AFTER_SCRIPT: | + rm -rf examples/llm/gpt_pretrain_results + rm -rf examples/llm/gpt_index_mappings + Nemo_CICD_Test: needs: + - gpu-test + - cicd-test-container-setup - L0_Unit_Tests_GPU - L0_Unit_Tests_CPU + - L2_Community_LLM_Checkpoints_tests_Bert - L2_Community_LLM_Checkpoints_tests_Llama - L2_Community_LLM_Checkpoints_tests_StarCoder - L2_Community_LLM_Checkpoints_tests_Falcon + - L2_Community_vita_Checkpoints_tests_Llama3 #- OPTIONAL_L2_Community_LLM_Checkpoints_tests_Baichuan2 - ASR_dev_run_Speech_to_Text - ASR_dev_run_Speech_to_Text_WPE_-_CitriNet @@ -4501,6 +4636,7 @@ jobs: - L2_Megatron_GPT_Embedding - L2_Megatron_GPT_PEFT_Lora_PP2_O2 - L2_Megatron_GPT_PEFT_Lora_TP2_O1 + - L2_Megatron_GPT_PEFT_Lora_TP2SP1 - L2_Megatron_GPT_Eval - L2_Megatron_GPT_Eval_PP2 - L2_Megatron_GPT_SFT_Eval_inference_seq_len_greaterThan_training_seq_len @@ -4514,9 +4650,11 @@ jobs: - L2_Megatron_T5_w_Mixture_of_Expert_Pretraining - L2_Megatron_UL2_Pretraining_and_Resume_Training_TP2 - L2_Megatron_T5_Eval + - L2_Megatron_Core_T5_Eval - L2_Megatron_BART_Pretraining_and_Resume_Training_TP2 - L2_Megatron_BART_Pretraining_and_Resume_Training_PP2 - L2_Megatron_T5_PEFT_Lora_TP2 + - L2_Megatron_Core_T5_PEFT_Lora_TP2 - L2_Megatron_Mock_Data_Generation_MockGPTDataset - L2_Megatron_Mock_Data_Generation_MockT5Dataset - L2_TTS_Fast_dev_runs_1_Tacotron_2 @@ -4527,6 +4665,7 @@ jobs: - L2_TTS_Fast_dev_runs_1_Hifigan - Speech_Checkpoints_tests - L2_Stable_Diffusion_Training + - L2_NeMo_2_GPT_Pretraining_no_transformer_engine if: always() runs-on: ubuntu-latest steps: diff --git a/.github/workflows/config/changelog-config.json b/.github/workflows/config/changelog-config.json index fe18f8ac0681..40c98c0a571b 100644 --- a/.github/workflows/config/changelog-config.json +++ b/.github/workflows/config/changelog-config.json @@ -1,47 +1,47 @@ { "categories": [ { - "title": "## ASR \n\n
Changelog\n\n
\n\n", + "title": "## ASR\n\n
Changelog", "labels": ["asr"], "exclude_labels": ["cherry-pick"] }, { - "title": "## TTS \n\n
Changelog\n\n
\n\n", + "title": "
\n\n## TTS\n\n
Changelog", "labels": ["tts"], "exclude_labels": ["cherry-pick"] }, { - "title": "## NLP / NMT \n\n
Changelog\n\n
\n\n", + "title": "
\n\n## NLP / NMT\n\n
Changelog", "labels": ["nlp", "nmt", "megatron"], "exclude_labels": ["cherry-pick"] }, { - "title": "## Text Normalization / Inverse Text Normalization \n\n
Changelog\n\n
\n\n", + "title": "
\n\n## Text Normalization / Inverse Text Normalization\n\n
Changelog", "labels": ["tn", "itn"], "exclude_labels": ["cherry-pick"] }, { - "title": "## NeMo Tools \n\n
Changelog\n\n
\n\n", + "title": "
\n\n## NeMo Tools\n\n
Changelog", "labels": ["tools"], "exclude_labels": ["cherry-pick"] }, { - "title": "## Export \n\n
Changelog\n\n
\n\n", + "title": "
\n\n## Export\n\n
Changelog", "labels": ["export"], "exclude_labels": ["cherry-pick"] }, { - "title": "## Documentation \n\n
Changelog\n\n
\n\n", + "title": "
\n\n## Documentation\n\n
Changelog", "labels": ["docs"], "exclude_labels": ["cherry-pick"] }, { - "title": "## Bugfixes \n\n
Changelog\n\n
\n\n", + "title": "
\n\n## Bugfixes\n\n
Changelog", "labels": ["bug"], "exclude_labels": ["cherry-pick"] }, { - "title": "## Cherrypick \n\n
Changelog\n\n
\n\n", + "title": "
\n\n## Cherrypick\n\n
Changelog", "labels": ["cherry-pick"], "exclude_labels": ["cherry-pick"] } @@ -50,7 +50,7 @@ "ignore" ], "sort": "ASC", - "template": "\n${{CHANGELOG}}\nUncategorized:\n${{UNCATEGORIZED}}\n\n", + "template": "\n${{CHANGELOG}}
\n\n## Uncategorized:\n\n
Changelog\n\n${{UNCATEGORIZED}}\n
\n", "pr_template": "- ${{TITLE}} by @${{AUTHOR}} :: PR: #${{NUMBER}}", "empty_template": "${{OWNER}}\n${{REPO}}\n${{FROM_TAG}}\n${{TO_TAG}}", "label_extractor": [ diff --git a/.github/workflows/import-test.yml b/.github/workflows/import-test.yml index 6f2f52bfb0ae..3af15294b2a2 100644 --- a/.github/workflows/import-test.yml +++ b/.github/workflows/import-test.yml @@ -12,7 +12,7 @@ jobs: test-asr-imports: runs-on: ubuntu-latest container: - image: pytorch/pytorch:2.0.1-cuda11.7-cudnn8-runtime + image: pytorch/pytorch:2.4.0-cuda11.8-cudnn9-runtime steps: - name: Checkout repo uses: actions/checkout@v2 @@ -43,7 +43,7 @@ jobs: test-tts-imports: runs-on: ubuntu-latest container: - image: pytorch/pytorch:2.0.1-cuda11.7-cudnn8-runtime + image: pytorch/pytorch:2.4.0-cuda11.8-cudnn9-runtime steps: - name: Checkout repo uses: actions/checkout@v2 @@ -70,4 +70,4 @@ jobs: # Run import checks python tests/core_ptl/check_imports.py --domain "tts" # Uninstall NeMo - pip uninstall -y nemo_toolkit \ No newline at end of file + pip uninstall -y nemo_toolkit diff --git a/.github/workflows/mcore-tag-bump-bot.yml b/.github/workflows/mcore-tag-bump-bot.yml new file mode 100644 index 000000000000..13f4059a3a6b --- /dev/null +++ b/.github/workflows/mcore-tag-bump-bot.yml @@ -0,0 +1,59 @@ +# Regularly updates the CI container +name: MCore Tag Bump Bot +on: + workflow_dispatch: + schedule: + - cron: 0 0 * * * + +jobs: + main: + runs-on: ubuntu-latest + environment: main + steps: + - name: Checkout NVIDIA/Megatron-LM + uses: actions/checkout@v4 + with: + repository: NVIDIA/Megatron-LM + ref: main + path: ${{ github.run_id }} + + - name: Get latest mcore commit + id: ref + run: | + cd ${{ github.run_id }} + sha=$(git rev-parse origin/main) + echo "sha=${sha}" >> "$GITHUB_OUTPUT" + echo "short_sha=${sha:0:7}" >> "$GITHUB_OUTPUT" + echo "date=$(date +%F)" >> "$GITHUB_OUTPUT" + + - name: Checkout ${{ github.repository }} + uses: actions/checkout@v4 + with: + path: ${{ github.run_id }} + token: ${{ secrets.PAT }} + + - name: Bump MCORE_TAG + run: | + cd ${{ github.run_id }} + sed -i 's/^ARG MCORE_TAG=.*$/ARG MCORE_TAG=${{ steps.ref.outputs.sha }}/' Dockerfile.ci + + - name: Create Bump PR + uses: peter-evans/create-pull-request@v6 + id: create-pull-request + with: + path: ${{ github.run_id }} + branch: bump-ci-container-${{ steps.ref.outputs.date }} + base: main + title: 'Bump `Dockerfile.ci` (${{ steps.ref.outputs.date }})' + token: ${{ secrets.PAT }} + body: | + 🚀 PR to Bump `Dockerfile.ci`. + + 📝 Please remember the following to-do's before merge: + - [ ] Verify the presubmit CI + + 🙏 Please merge this PR only if the CI workflow completed successfully. + commit-message: "[🤠]: Howdy folks, let's bump `Dockerfile.ci` to ${{ steps.ref.outputs.short_sha }} !" 
+ signoff: true + reviewers: 'pablo-garay' + labels: 'Run CICD' diff --git a/.github/workflows/release-freeze.yml b/.github/workflows/release-freeze.yml new file mode 100644 index 000000000000..f8d037271f36 --- /dev/null +++ b/.github/workflows/release-freeze.yml @@ -0,0 +1,192 @@ +name: "NeMo Code freeze" + +on: + workflow_dispatch: + inputs: + next_version: + description: 'MAJOR.MINOR.PATCH[rcN] (Example: 2.0.0rc1, or 2.1.0)' + required: true + type: string + mcore_version: + description: 'Version of MCore to use (must be a valid git ref)' + required: true + type: string +jobs: + create-release-branch: + runs-on: ubuntu-latest + if: contains(fromJSON('["ko3n1g"]'), github.actor) + environment: + name: main + outputs: + version: ${{ steps.release-branch.outputs.version }} + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + path: ${{ github.run_id }} + fetch-depth: 0 + fetch-tags: true + ref: main + + - name: Get Previous tag + id: previous-tag + # git for-each-ref --sort=-creatordate --format '%(refname)' refs/tags ==> refs/tags/vX.Y.Z in descending order of date + # awk 'FNR == 2 {print substr($1, 11, length($1))}') ==> Selects the 2nd tag from the list, then strips the /refs/tags/ part of the tag + # set-output name=tag_name:: ==> Takes the clean tag vX.Y.Z and sets it to steps.previous_tag.outputs.tag_name + run: | + TAG=$(git for-each-ref --sort=-creatordate --format '%(refname)' refs/tags | awk 'FNR == 2 {print substr($1, 11, length($1))}') + echo "tag-name=$TAG" >> "$GITHUB_OUTPUT" + + - name: Get release branch ref + id: release-branch + run: | + cd ${{ github.run_id }} + + VERSION=$(python -c 'import nemo; print(nemo.__version__)') + echo "Release version r$VERSION" > version + echo "version=$VERSION" >> "$GITHUB_OUTPUT" + + - name: Pin branch name in Notebooks + run: | + cd ${{ github.run_id }} + find tutorials -type f -name "*.ipynb" -exec sed -i "s/BRANCH = 'main'/BRANCH = 'r${{ steps.release-branch.outputs.version }}'/g" {} + + + - name: Pin MCore in Dockerfile + run: | + cd ${{ github.run_id }} + sed -i 's/^ARG MCORE_TAG=.*$/ARG MCORE_TAG=${{ inputs.mcore_version }}/' Dockerfile.ci + + - name: Build Changelog + id: build-changelog + uses: mikepenz/release-changelog-builder-action@v3.3.1 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + with: + # Configuration file is setup with filters for domains + # owner:repo must point to current repo + # fromTag: Auto resolved from historical tag order (previous tag compared to current tag) + # toTag: Current tag reference + configuration: ".github/workflows/config/changelog-config.json" + owner: ${{ github.repository_owner }} + repo: ${{ github.event.repository.name }} + ignorePreReleases: "false" + failOnError: "false" + fromTag: ${{ steps.previous-tag.outputs.tag-name }} + toTag: main + + - name: Append Changelog + run: | + echo "${{ steps.build-changelog.outputs.changelog }}" + + - name: Create Release PR + uses: peter-evans/create-pull-request@v6 + id: create-pull-request + with: + path: ${{ github.run_id }} + branch: r${{ steps.release-branch.outputs.version }} + title: 'Release `${{ steps.release-branch.outputs.version }}`' + body: | + 🚀 PR to release NeMo `${{ steps.release-branch.outputs.version }}`. + + 📝 Please remember the following to-do's before merge: + - [ ] Fill-in the comment `Highlights` + - [ ] Review the comment `Detailed Changelogs` + + 🚨 Please also keep in mind to _not_ delete the headings of the task commits. They are required by the post-merge automation. 
+ + 🙏 Please merge this PR only if the CI workflow completed successfully. + + commit-message: "[🤠]: Howdy folks, let's release NeMo `${{ steps.release-branch.outputs.version }}` !" + signoff: true + assignees: okoenig + labels: 'Run CICD' + + - name: Add Summary comment + uses: peter-evans/create-or-update-comment@v4 + with: + issue-number: ${{ steps.create-pull-request.outputs.pull-request-number }} + body: | + # Highlights + __ + + - name: Add Changelog comment + uses: peter-evans/create-or-update-comment@v4 + with: + issue-number: ${{ steps.create-pull-request.outputs.pull-request-number }} + body: | + # Detailed Changelogs + ${{ steps.build-changelog.outputs.changelog }} + + bump-next-version: + runs-on: ubuntu-latest + needs: [create-release-branch] + environment: + name: main + env: + VERSION_FILE: nemo/package_info.py + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + path: ${{ github.run_id }} + fetch-depth: 0 + fetch-tags: true + ref: main + token: ${{ secrets.PAT }} + + - name: Bump version + id: bump-version + run: | + cd ${{ github.run_id }} + FULL_VERSION_NUM=${{ inputs.next_version }} + VERSION=${FULL_VERSION_NUM%%rc*} + MAJOR=$(echo "$VERSION" | cut -d. -f1) + MINOR=$(echo "$VERSION" | cut -d. -f2) + PATCH=$(echo "$VERSION" | cut -d. -f3) + PRE_RELEASE=${FULL_VERSION_NUM#$VERSION} + + sed -i 's/^MAJOR\s*=\s*[0-9]\+/MAJOR = '$MAJOR'/' $VERSION_FILE + sed -i 's/^MINOR\s*=\s*[0-9]\+/MINOR = '$MINOR'/' $VERSION_FILE + sed -i 's/^PATCH\s*=\s*[0-9]\+/PATCH = '$PATCH'/' $VERSION_FILE + sed -i 's/^PRE_RELEASE\s*=\s*'.*'/PRE_RELEASE = '\'$PRE_RELEASE\''/' $VERSION_FILE + + cat $VERSION_FILE + PRE_RELEASE=$(echo $PRE_RELEASE | tr -d "'") + echo "version=$MAJOR.$MINOR.$PATCH$PRE_RELEASE" >> "$GITHUB_OUTPUT" + + - name: Create Version Bump PR + uses: peter-evans/create-pull-request@v6 + id: create-pull-request + with: + path: ${{ github.run_id }} + branch: bot/chore/version-bump-${{ inputs.next_version }} + title: 'Version bump to `${{ inputs.next_version }}`' + body: | + 🚀 Version bump NeMo toolkit to `${{ inputs.next_version }}` + + commit-message: "[🤠]: Howdy folks, let's bump NeMo `${{ inputs.next_version }}` !" + signoff: true + assignees: okoenig + labels: 'Run CICD' + + notify: + runs-on: ubuntu-latest + needs: [create-release-branch, bump-next-version] + environment: + name: main + steps: + - name: Main + run: | + MESSAGE='{ + "blocks": [ + { + "type": "section", + "text": { + "type": "mrkdwn", + "text": "Releasebot 🤖: NeMo Toolkit has been frozen 🎉 to branch `r${{ needs.create-release-branch.outputs.version }}`" + } + } + ] + }' + + curl -X POST -H "Content-type: application/json" --data "$MESSAGE" ${{ secrets.SLACK_RELEASE_ENDPOINT }} \ No newline at end of file diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 000000000000..3f4c4f3c19de --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,179 @@ +name: "NeMo Code release" + +on: + issue_comment: + types: [created] + +jobs: + main: + if: > + github.event_name == 'issue_comment' && + github.event.issue.pull_request && + startsWith(github.event.comment.body, '/release-please') && + contains(fromJSON('["ko3n1g"]'), github.actor) + runs-on: ubuntu-latest + environment: + name: main + steps: + - name: Update PR issue comment + shell: bash + env: + message: ${{ github.event.comment.body }} + run: | + message="$message + + --- + + Releasebot 🤖: Release processes started... + " + message="${message//$'\n'/
}" + + curl -L \ + -X PATCH \ + -H "Accept: application/vnd.github+json" \ + -H "Authorization: Bearer ${{ secrets.GITHUB_TOKEN }}" \ + -H "X-GitHub-Api-Version: 2022-11-28" \ + https://api.github.com/repos/${{ github.repository }}/issues/comments/${{ github.event.comment.id }} \ + -d '{"body":"'"$message"'"}' + + - name: Get PR number + shell: bash + id: get-pr-num + run: | + PR_URL="${{ github.event.issue.pull_request.url }}" + PR_NUM=${PR_URL##*/} + echo "pr_number=$PR_NUM" >> $GITHUB_OUTPUT + + - name: Get Pull Request Information + uses: actions/github-script@v6 + id: get-pr-branch + with: + result-encoding: string + script: | + const pr = await github.rest.pulls.get({ + owner: context.repo.owner, + repo: context.repo.repo, + pull_number: ${{ steps.get-pr-num.outputs.pr_number }} + }); + console.log('Pull Request Information:', pr.data); + return pr.data.head.ref; + + - name: Checkout repository + uses: actions/checkout@v4 + with: + path: ${{ github.run_id }} + ref: ${{ steps.get-pr-branch.outputs.result }} + + - name: Get version number + id: version-number + run: | + cd ${{ github.run_id }} + VERSION=$(python -c "import nemo; print(nemo.__version__)") + echo "VERSION=$VERSION" >> "$GITHUB_OUTPUT" + + - name: Extract changelog + id: extract-changelog + uses: peter-evans/find-comment@v3 + with: + issue-number: ${{ steps.get-pr-num.outputs.pr_number }} + body-includes: '# Detailed Changelogs' + + - name: Extract summary + id: extract-summary + uses: peter-evans/find-comment@v3 + with: + issue-number: ${{ steps.get-pr-num.outputs.pr_number }} + body-includes: '# Highlights' + + - name: Create Release doc + id: create-release-doc + env: + SUMMARY: ${{ steps.extract-summary.outputs.comment-body }} + CHANGELOG: ${{ steps.extract-changelog.outputs.comment-body }} + run: | + + echo "TITLE<> $GITHUB_ENV + echo "NVIDIA Neural Modules ${{ steps.version-number.outputs.VERSION }}" >> $GITHUB_ENV + echo "EOF" >> $GITHUB_ENV + + echo "BODY<> $GITHUB_ENV + echo "$SUMMARY" >> $GITHUB_ENV + echo "$CHANGELOG" >> $GITHUB_ENV + echo "EOF" >> $GITHUB_ENV + + - name: Create Release + uses: softprops/action-gh-release@v2 + with: + name: ${{ env.TITLE }} + tag_name: ${{ steps.version-number.outputs.VERSION }} + body: ${{ env.BODY }} + + - name: Build, test, and release wheel + env: + TWINE_USERNAME: ${{ secrets.TWINE_USERNAME }} + TWINE_PASSWORD: ${{ secrets.TWINE_PASSWORD }} + run: | + cd ${{ github.run_id }} + python3 -m pip install --upgrade build + python3 -m build + + pip install dist/*.whl + + cd ../ + + INSTALLED_VERSION=$(python -c 'import nemo; print(nemo.__version__)') + EXPECTED_VERSION=${{ steps.version-number.outputs.VERSION }} + + if [[ "$INSTALLED_VERSION" != "$EXPECTED_VERSION" ]]; then + echo 'Wheel has an outdated version, mission abort immediately!' + exit 1 + fi + + echo Proceed with uploading wheel... + cd ${{ github.run_id }} + python3 -m pip install --upgrade twine + python3 -m twine upload --repository pypi dist/* + + - name: Update PR issue comment + shell: bash + env: + message: ${{ github.event.comment.body }} + run: | + message="$message + + --- + + Releasebot 🤖: Release done 🎉 + " + message="${message//$'\n'/
}" + + curl -L \ + -X PATCH \ + -H "Accept: application/vnd.github+json" \ + -H "Authorization: Bearer ${{ secrets.GITHUB_TOKEN }}" \ + -H "X-GitHub-Api-Version: 2022-11-28" \ + https://api.github.com/repos/${{ github.repository }}/issues/comments/${{ github.event.comment.id }} \ + -d '{"body":"'"$message"'"}' + + - name: Close Pull + run: | + cd ${{ github.run_id }} + gh pr close --comment "Releasebot 🤖: Closing PR" "${{ steps.get-pr-num.outputs.pr_number }}" + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + + - name: notify + run: | + MESSAGE='{ + "blocks": [ + { + "type": "section", + "text": { + "type": "mrkdwn", + "text": "Releasebot 🤖: NeMo Toolkit released `${{ steps.version-number.outputs.VERSION }}` 🚀" + } + } + ] + }' + + curl -X POST -H "Content-type: application/json" --data "$MESSAGE" ${{ secrets.SLACK_RELEASE_ENDPOINT }} \ No newline at end of file diff --git a/.gitignore b/.gitignore index 1ff2a92cac64..1aa5ef00de5e 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ *.pkl #*.ipynb output +output_2048 result *.pt tests/data/asr @@ -179,3 +180,4 @@ examples/neural_graphs/*.yml .hydra/ nemo_experiments/ +slurm*.out diff --git a/Dockerfile.ci b/Dockerfile.ci index dd8af593768f..38b82a288a2b 100644 --- a/Dockerfile.ci +++ b/Dockerfile.ci @@ -33,8 +33,8 @@ WORKDIR /workspace # Install NeMo requirements ARG TE_TAG=7d576ed25266a17a7b651f2c12e8498f67e0baea -ARG MODELOPT_VERSION=0.13.0 -ARG MCORE_TAG=0bc3547702464501feefeb5523b7a17e591b21fa +ARG MODELOPT_VERSION=0.15.0 +ARG MCORE_TAG=2fd6e2b74efca73a1f2d27b89bb5419384b4d3bf ARG APEX_TAG=810ffae374a2b9cb4b5c5e28eaeca7d7998fca0c RUN \ --mount=type=bind,source=requirements,target=requirements \ @@ -47,6 +47,7 @@ pip install --no-cache-dir --no-build-isolation --extra-index-url https://pypi.n "megatron_core @ git+https://github.com/NVIDIA/Megatron-LM.git@${MCORE_TAG}" \ "nvidia-modelopt[torch]~=${MODELOPT_VERSION}" \ "apex @ git+https://github.com/NVIDIA/apex.git@${APEX_TAG}" \ +"unstructured==0.14.9" \ "llama-index==0.10.43" \ "onnxscript @ git+https://github.com/microsoft/onnxscript" \ -r tools/ctc_segmentation/requirements.txt \ @@ -68,14 +69,14 @@ git clone https://github.com/state-spaces/mamba.git && \ git checkout v2.0.3 && \ python setup.py install && \ cd .. && \ - rm -rf mamba + rm -rf mamba git clone https://github.com/Dao-AILab/causal-conv1d && \ cd causal-conv1d && \ git checkout v1.2.2.post1 && \ python setup.py install && \ cd .. && \ - rm -rf causal-conv1d + rm -rf causal-conv1d EOF @@ -89,4 +90,3 @@ chmod 777 -R /workspace EOF ENV PYTHONPATH="${PYTHONPATH}:/workspace/Megatron-LM" - diff --git a/Dockerfile b/Dockerfile.speech similarity index 83% rename from Dockerfile rename to Dockerfile.speech index a42ae592a9bd..e7cc670a132d 100644 --- a/Dockerfile +++ b/Dockerfile.speech @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-ARG BASE_IMAGE=nvcr.io/nvidia/pytorch:24.01-py3 +ARG BASE_IMAGE=nvcr.io/nvidia/pytorch:24.02-py3 # build an image that includes only the nemo dependencies, ensures that dependencies # are included first for optimal caching, and useful for building a development @@ -62,23 +62,28 @@ RUN apt-get update && \ rm -rf /var/lib/apt/lists/* WORKDIR /workspace/ + +ARG TE_TAG=7d576ed25266a17a7b651f2c12e8498f67e0baea +ARG MCORE_TAG=338af51452a53982d202e8386db6233adad1ce86 +ARG APEX_TAG=810ffae374a2b9cb4b5c5e28eaeca7d7998fca0c # Install megatron core, this can be removed once 0.3 pip package is released # We leave it here in case we need to work off of a specific commit in main RUN git clone https://github.com/NVIDIA/Megatron-LM.git && \ cd Megatron-LM && \ - git checkout 02871b4df8c69fac687ab6676c4246e936ce92d0 && \ + git checkout ${MCORE_TAG} && \ pip install . # Performance optimizations for distributed optimizer: https://github.com/NVIDIA/apex/pull/1771 RUN git clone https://github.com/NVIDIA/apex.git && \ cd apex && \ - git checkout f058162b215791b15507bb542f22ccfde49c872d && \ - pip install -v --no-build-isolation --disable-pip-version-check --no-cache-dir --config-settings "--build-option=--cpp_ext --cuda_ext --fast_layer_norm --distributed_adam --deprecated_fused_adam" ./ + git checkout ${APEX_TAG} && \ + pip install -v --no-build-isolation --disable-pip-version-check --no-cache-dir \ + --config-settings "--build-option=--cpp_ext --cuda_ext --fast_layer_norm --distributed_adam --deprecated_fused_adam" ./ # Transformer Engine 1.2.0 RUN git clone https://github.com/NVIDIA/TransformerEngine.git && \ cd TransformerEngine && \ - git fetch origin da30634a6c9ccdbb6c587b6c93b1860e4b038204 && \ + git fetch origin ${TE_TAG} && \ git checkout FETCH_HEAD && \ git submodule init && git submodule update && \ NVTE_FRAMEWORK=pytorch NVTE_WITH_USERBUFFERS=1 MPI_HOME=/usr/local/mpi pip install . @@ -126,7 +131,9 @@ RUN INSTALL_MSG=$(/bin/bash /tmp/nemo/scripts/installers/install_k2.sh); INSTALL WORKDIR /tmp/nemo ENV LHOTSE_REQUIRE_TORCHAUDIO=0 COPY requirements . -RUN for f in $(ls requirements*.txt); do pip3 install --disable-pip-version-check --no-cache-dir -r $f; done +# exclude requirements_vllm.txt, since `vllm==0.5.x` breaks the container due to hardcoded requirements `torch==2.3.0` +RUN for f in $(ls requirements*.txt | grep -v 'requirements_vllm.txt'); do \ + pip3 install --disable-pip-version-check --no-cache-dir -r $f; done # install flash attention RUN pip install flash-attn @@ -151,7 +158,12 @@ RUN /usr/bin/test -n "$NEMO_VERSION" && \ RUN --mount=from=nemo-src,target=/tmp/nemo,rw cd /tmp/nemo && pip install ".[all]" # Check install -RUN python -c "import nemo.collections.nlp as nemo_nlp" && \ +# NB: adjusting LD_LIBRARY_PATH (only here, should not be persistent!) 
is a temporary hack +# to avoid failure if CUDA is unavailable (`docker build` does not expose GPUs) +# The error is raised in NeMo Core, and the main reason is reinstalled Transformer-Engine; +RUN export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:${CUDA_HOME}/compat/lib.real && \ + python -c "import nemo.collections.asr as nemo_asr" && \ + python -c "import nemo.collections.nlp as nemo_nlp" && \ python -c "import nemo.collections.tts as nemo_tts" && \ python -c "import nemo_text_processing.text_normalization as text_normalization" diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 000000000000..cfcd6ee939cb --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1 @@ +include requirements/* \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 000000000000..9b019d3ac175 --- /dev/null +++ b/README.md @@ -0,0 +1,650 @@ +[![Project Status: Active -- The project has reached a stable, usable state and is being actively developed.](http://www.repostatus.org/badges/latest/active.svg)](http://www.repostatus.org/#active) +[![Documentation](https://readthedocs.com/projects/nvidia-nemo/badge/?version=main)](https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/) +[![CodeQL](https://github.com/nvidia/nemo/actions/workflows/codeql.yml/badge.svg?branch=main&event=push)](https://github.com/nvidia/nemo/actions/workflows/codeql.yml) +[![NeMo core license and license for collections in this repo](https://img.shields.io/badge/License-Apache%202.0-brightgreen.svg)](https://github.com/NVIDIA/NeMo/blob/master/LICENSE) +[![Release version](https://badge.fury.io/py/nemo-toolkit.svg)](https://badge.fury.io/py/nemo-toolkit) +[![Python version](https://img.shields.io/pypi/pyversions/nemo-toolkit.svg)](https://badge.fury.io/py/nemo-toolkit) +[![PyPi total downloads](https://static.pepy.tech/personalized-badge/nemo-toolkit?period=total&units=international_system&left_color=grey&right_color=brightgreen&left_text=downloads)](https://pepy.tech/project/nemo-toolkit) +[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) + +# **NVIDIA NeMo Framework** + +## Latest News + + +
+### Large Language Models and Multimodal Models
+
+- **New Llama 3.1 Support** (2024-07-23): The NeMo Framework now supports training and customizing the Llama 3.1 collection of LLMs from Meta.
+- **Accelerate your Generative AI Distributed Training Workloads with the NVIDIA NeMo Framework on Amazon EKS** (2024-07-16): NVIDIA NeMo Framework now runs distributed training workloads on an Amazon Elastic Kubernetes Service (Amazon EKS) cluster. For step-by-step instructions on creating an EKS cluster and running distributed training workloads with NeMo, see the GitHub repository here.
+- **NVIDIA NeMo Accelerates LLM Innovation with Hybrid State Space Model Support** (2024-06-17): NVIDIA NeMo and Megatron Core now support pre-training and fine-tuning of state space models (SSMs). NeMo also supports training models based on the Griffin architecture as described by Google DeepMind.
+- **NVIDIA releases 340B base, instruct, and reward models pretrained on a total of 9T tokens** (2024-06-18): See documentation and tutorials for SFT, PEFT, and PTQ with Nemotron 340B in the NeMo Framework User Guide.
+- **NVIDIA sets new generative AI performance and scale records in MLPerf Training v4.0** (2024-06-12): Using the NVIDIA NeMo Framework and NVIDIA Hopper GPUs, NVIDIA scaled to 11,616 H100 GPUs and achieved near-linear performance scaling on LLM pretraining. NVIDIA also achieved the highest LLM fine-tuning performance and raised the bar for text-to-image training.
+- **Accelerate your generative AI journey with NVIDIA NeMo Framework on GKE** (2024-03-16): An end-to-end walkthrough to train generative AI models on Google Kubernetes Engine (GKE) using the NVIDIA NeMo Framework is available at https://github.com/GoogleCloudPlatform/nvidia-nemo-on-gke. The walkthrough includes detailed instructions on how to set up a Google Cloud project and pre-train a GPT model using the NeMo Framework.
+
+### Speech Recognition
+
+- **New Standard for Speech Recognition and Translation from the NVIDIA NeMo Canary Model** (2024-04-18): The NeMo team just released Canary, a multilingual model that transcribes speech in English, Spanish, German, and French with punctuation and capitalization. Canary also provides bi-directional translation between English and the three other supported languages.
+- **Pushing the Boundaries of Speech Recognition with NVIDIA NeMo Parakeet ASR Models** (2024-04-18): NVIDIA NeMo, an end-to-end platform for developing multimodal generative AI models at scale anywhere, on any cloud and on-premises, released the Parakeet family of automatic speech recognition (ASR) models. These state-of-the-art ASR models, developed in collaboration with Suno.ai, transcribe spoken English with exceptional accuracy.
+- **Turbocharge ASR Accuracy and Speed with NVIDIA NeMo Parakeet-TDT** (2024-04-18): NVIDIA NeMo, an end-to-end platform for developing multimodal generative AI models at scale anywhere, on any cloud and on-premises, recently released Parakeet-TDT. This new addition to the NeMo ASR Parakeet model family boasts better accuracy and 64% greater speed than the previously best model, Parakeet-RNNT-1.1B.
+ + +## Introduction + +NVIDIA NeMo Framework is a scalable and cloud-native generative AI +framework built for researchers and PyTorch developers working on Large +Language Models (LLMs), Multimodal Models (MMs), Automatic Speech +Recognition (ASR), Text to Speech (TTS), and Computer Vision (CV) +domains. It is designed to help you efficiently create, customize, and +deploy new generative AI models by leveraging existing code and +pre-trained model checkpoints. + +For technical documentation, please see the [NeMo Framework User +Guide](https://docs.nvidia.com/nemo-framework/user-guide/latest/playbooks/index.html). + +## LLMs and MMs Training, Alignment, and Customization + +All NeMo models are trained with +[Lightning](https://github.com/Lightning-AI/lightning). Training is +automatically scalable to 1000s of GPUs. + +When applicable, NeMo models leverage cutting-edge distributed training +techniques, incorporating [parallelism +strategies](https://docs.nvidia.com/nemo-framework/user-guide/latest/modeloverview.html) +to enable efficient training of very large models. These techniques +include Tensor Parallelism (TP), Pipeline Parallelism (PP), Fully +Sharded Data Parallelism (FSDP), Mixture-of-Experts (MoE), and Mixed +Precision Training with BFloat16 and FP8, as well as others. + +NeMo Transformer-based LLMs and MMs utilize [NVIDIA Transformer +Engine](https://github.com/NVIDIA/TransformerEngine) for FP8 training on +NVIDIA Hopper GPUs, while leveraging [NVIDIA Megatron +Core](https://github.com/NVIDIA/Megatron-LM/tree/main/megatron/core) for +scaling Transformer model training. + +NeMo LLMs can be aligned with state-of-the-art methods such as SteerLM, +Direct Preference Optimization (DPO), and Reinforcement Learning from +Human Feedback (RLHF). See [NVIDIA NeMo +Aligner](https://github.com/NVIDIA/NeMo-Aligner) for more information. + +In addition to supervised fine-tuning (SFT), NeMo also supports the +latest parameter efficient fine-tuning (PEFT) techniques such as LoRA, +P-Tuning, Adapters, and IA3. Refer to the [NeMo Framework User +Guide](https://docs.nvidia.com/nemo-framework/user-guide/latest/sft_peft/index.html) +for the full list of supported models and techniques. + +## LLMs and MMs Deployment and Optimization + +NeMo LLMs and MMs can be deployed and optimized with [NVIDIA NeMo +Microservices](https://developer.nvidia.com/nemo-microservices-early-access). + +## Speech AI + +NeMo ASR and TTS models can be optimized for inference and deployed for +production use cases with [NVIDIA Riva](https://developer.nvidia.com/riva). + +## NeMo Framework Launcher + +[NeMo Framework +Launcher](https://github.com/NVIDIA/NeMo-Megatron-Launcher) is a +cloud-native tool that streamlines the NeMo Framework experience. It is +used for launching end-to-end NeMo Framework training jobs on CSPs and +Slurm clusters. + +The NeMo Framework Launcher includes extensive recipes, scripts, +utilities, and documentation for training NeMo LLMs. It also includes +the NeMo Framework [Autoconfigurator](https://github.com/NVIDIA/NeMo-Megatron-Launcher#53-using-autoconfigurator-to-find-the-optimal-configuration), +which is designed to find the optimal model parallel configuration for +training on a specific cluster. + +To get started quickly with the NeMo Framework Launcher, please see the +[NeMo Framework +Playbooks](https://docs.nvidia.com/nemo-framework/user-guide/latest/playbooks/index.html). +The NeMo Framework Launcher does not currently support ASR and TTS +training, but it will soon. 
+ +## Get Started with NeMo Framework + +Getting started with NeMo Framework is easy. State-of-the-art pretrained +NeMo models are freely available on [Hugging Face +Hub](https://huggingface.co/models?library=nemo&sort=downloads&search=nvidia) +and [NVIDIA +NGC](https://catalog.ngc.nvidia.com/models?query=nemo&orderBy=weightPopularDESC). +These models can be used to generate text or images, transcribe audio, +and synthesize speech in just a few lines of code. + +We have extensive +[tutorials](https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/starthere/tutorials.html) +that can be run on [Google Colab](https://colab.research.google.com) or +with our [NGC NeMo Framework +Container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo). +We also have +[playbooks](https://docs.nvidia.com/nemo-framework/user-guide/latest/playbooks/index.html) +for users who want to train NeMo models with the NeMo Framework +Launcher. + +For advanced users who want to train NeMo models from scratch or +fine-tune existing NeMo models, we have a full suite of [example +scripts](https://github.com/NVIDIA/NeMo/tree/main/examples) that support +multi-GPU/multi-node training. + +## Key Features + +- [Large Language Models](nemo/collections/nlp/README.md) +- [Multimodal](nemo/collections/multimodal/README.md) +- [Automatic Speech Recognition](nemo/collections/asr/README.md) +- [Text to Speech](nemo/collections/tts/README.md) +- [Computer Vision](nemo/collections/vision/README.md) + +## Requirements + +- Python 3.10 or above +- Pytorch 1.13.1 or above +- NVIDIA GPU (if you intend to do model training) + +## Developer Documentation + +| Version | Status | Description | +| ------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------ | +| Latest | [![Documentation Status](https://readthedocs.com/projects/nvidia-nemo/badge/?version=main)](https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/) | [Documentation of the latest (i.e. main) branch.](https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/) | +| Stable | [![Documentation Status](https://readthedocs.com/projects/nvidia-nemo/badge/?version=stable)](https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/) | [Documentation of the stable (i.e. most recent release)](https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/) | + +## Install NeMo Framework + +The NeMo Framework can be installed in a variety of ways, depending on +your needs. Depending on the domain, you may find one of the following +installation methods more suitable. + +- Conda / Pip - Refer to [Conda](#conda) and [Pip](#pip) for + installation instructions. + - This is the recommended method for ASR and TTS domains. + - When using a Nvidia PyTorch container as the base, this is the + recommended method for all domains. +- Docker Containers - Refer to [Docker containers](#docker-containers) + for installation instructions. + - NeMo Framework container - + [nvcr.io/nvidia/nemo:24.05]{.title-ref} +- LLMs and MMs Dependencies - Refer to [LLMs and MMs + Dependencies](#install-llms-and-mms-dependencies) for installation + instructions. 
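+Whichever method you choose, a quick way to confirm the install afterwards is to check that the collections import cleanly. The snippet below is a minimal sketch that mirrors the import check used in the project's own Dockerfiles; skip the domains you did not install.
+
+```bash
+# Post-install smoke test (minimal sketch): each import should exit silently.
+python -c "import nemo.collections.asr as nemo_asr"
+python -c "import nemo.collections.nlp as nemo_nlp"
+python -c "import nemo.collections.tts as nemo_tts"
+```
+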
+ +**Important: We strongly recommended that you start with a base NVIDIA +PyTorch container: nvcr.io/nvidia/pytorch:24.02-py3.** + +### Conda + +Install NeMo in a fresh Conda environment: + +```bash +conda create --name nemo python==3.10.12 +conda activate nemo +``` + +Install PyTorch using their +[configurator](https://pytorch.org/get-started/locally/): + +```bash +conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia +``` + +The command to install PyTorch may depend on your system. Use the +configurator linked above to find the right command for your system. + +Then, install NeMo via Pip or from Source. We do not provide NeMo on the +conda-forge or any other Conda channel. + +### Pip + +To install the nemo_toolkit, use the following installation method: + +```bash +apt-get update && apt-get install -y libsndfile1 ffmpeg +pip install Cython packaging +pip install nemo_toolkit['all'] +``` + +Depending on the shell used, you may need to use the +`"nemo_toolkit[all]"` specifier instead in the above command. + +### Pip from a Specific Domain + +To install a specific domain of NeMo, you must first install the +nemo_toolkit using the instructions listed above. Then, you run the +following domain-specific commands: + +```bash +pip install nemo_toolkit['asr'] +pip install nemo_toolkit['nlp'] +pip install nemo_toolkit['tts'] +pip install nemo_toolkit['vision'] +pip install nemo_toolkit['multimodal'] +``` + +### Pip from a Source Branch + +If you want to work with a specific version of NeMo from a particular +GitHub branch (e.g main), use the following installation method: + +```bash +apt-get update && apt-get install -y libsndfile1 ffmpeg +pip install Cython packaging +python -m pip install git+https://github.com/NVIDIA/NeMo.git@{BRANCH}#egg=nemo_toolkit[all] +``` + +### Build from Source + +If you want to clone the NeMo GitHub repository and contribute to NeMo +open-source development work, use the following installation method: + +```bash +apt-get update && apt-get install -y libsndfile1 ffmpeg +git clone https://github.com/NVIDIA/NeMo +cd NeMo +./reinstall.sh +``` + +If you only want the toolkit without the additional Conda-based +dependencies, you can replace `reinstall.sh` with `pip install -e .` +when your PWD is the root of the NeMo repository. + +### Mac Computers with Apple Silicon + +To install NeMo on Mac computers with the Apple M-Series GPU, you need +to create a new Conda environment, install PyTorch 2.0 or higher, and +then install the nemo_toolkit. + +**Important: This method is only applicable to the ASR domain.** + +Run the following code: + +```shell +# [optional] install mecab using Homebrew, to use sacrebleu for NLP collection +# you can install Homebrew here: https://brew.sh +brew install mecab + +# [optional] install pynini using Conda, to use text normalization +conda install -c conda-forge pynini + +# install Cython manually +pip install cython packaging + +# clone the repo and install in development mode +git clone https://github.com/NVIDIA/NeMo +cd NeMo +pip install 'nemo_toolkit[all]' + +# Note that only the ASR toolkit is guaranteed to work on MacBook - so for MacBook use pip install 'nemo_toolkit[asr]' +``` + +### Windows Computers + +To install the Windows Subsystem for Linux (WSL), run the following code +in PowerShell: + +```shell +wsl --install +# [note] If you run wsl --install and see the WSL help text, it means WSL is already installed. 
+```
+
+To learn more about installing WSL, refer to [Microsoft's official
+documentation](https://learn.microsoft.com/en-us/windows/wsl/install).
+
+After installing your Linux distribution with WSL, two options are
+available:
+
+**Option 1:** Open the distribution (Ubuntu by default) from the Start
+menu and follow the instructions.
+
+**Option 2:** Launch the Terminal application. Download it from
+[Microsoft's Windows Terminal
+page](https://learn.microsoft.com/en-us/windows/terminal) if not
+installed.
+
+Next, follow the instructions for Linux systems, as provided above. For
+example:
+
+```bash
+apt-get update && apt-get install -y libsndfile1 ffmpeg
+git clone https://github.com/NVIDIA/NeMo
+cd NeMo
+./reinstall.sh
+```
+
+### RNNT
+
+For optimal performance of a Recurrent Neural Network Transducer (RNNT),
+install the Numba package from Conda.
+
+Run the following code:
+
+```bash
+conda remove numba
+pip uninstall numba
+conda install -c conda-forge numba
+```
+
+## Install LLMs and MMs Dependencies
+
+If you work with the LLM and MM domains, three additional dependencies
+are required: NVIDIA Apex, NVIDIA Transformer Engine, and NVIDIA
+Megatron Core. When working with the `main` branch, these
+dependencies may require a recent commit.
+
+The most recent working versions of these dependencies are here:
+
+```bash
+export apex_commit=810ffae374a2b9cb4b5c5e28eaeca7d7998fca0c
+export te_commit=bfe21c3d68b0a9951e5716fb520045db53419c5e
+export mcore_commit=02871b4df8c69fac687ab6676c4246e936ce92d0
+export nv_pytorch_tag=24.02-py3
+```
+
+When using a released version of NeMo, please refer to the [Software
+Component
+Versions](https://docs.nvidia.com/nemo-framework/user-guide/latest/softwarecomponentversions.html)
+for the correct versions.
+
+### PyTorch Container
+
+We recommend that you start with a base NVIDIA PyTorch container:
+nvcr.io/nvidia/pytorch:24.02-py3.
+
+If starting with a base NVIDIA PyTorch container, you must first launch
+the container:
+
+```bash
+docker run \
+  --gpus all \
+  -it \
+  --rm \
+  --shm-size=16g \
+  --ulimit memlock=-1 \
+  --ulimit stack=67108864 \
+  nvcr.io/nvidia/pytorch:$nv_pytorch_tag
+```
+
+Next, you need to install the dependencies.
+
+### Apex
+
+NVIDIA Apex is required for LLM and MM domains. Although Apex is
+pre-installed in the NVIDIA PyTorch container, you may need to update it
+to a newer version.
+
+To install Apex, run the following code:
+
+```bash
+git clone https://github.com/NVIDIA/apex.git
+cd apex
+git checkout $apex_commit
+pip install . -v --no-build-isolation --disable-pip-version-check --no-cache-dir --config-settings "--build-option=--cpp_ext --cuda_ext --fast_layer_norm --distributed_adam --deprecated_fused_adam --group_norm"
+```
+
+When attempting to install Apex separately from the NVIDIA PyTorch
+container, you might encounter an error if the CUDA version on your
+system is different from the one used to compile PyTorch. To bypass this
+error, you can comment out the relevant line in the setup file located
+in the Apex repository on GitHub here:
+https://github.com/NVIDIA/apex/blob/master/setup.py#L32.
+
+cuda-nvprof is needed to install Apex. The version should match the CUDA
+version that you are using.
+
+To install cuda-nvprof, run the following code:
+
+```bash
+conda install -c nvidia cuda-nvprof=11.8
+```
+
+Finally, install the `packaging` package:
+
+```bash
+pip install packaging
+```
+
+To install the most recent versions of Apex locally, it might be
+necessary to remove the `pyproject.toml` file from the Apex
+directory.
+
+### Transformer Engine
+
+NVIDIA Transformer Engine is required for LLM and MM domains. Although
+the Transformer Engine is pre-installed in the NVIDIA PyTorch container,
+you may need to update it to a newer version.
+
+The Transformer Engine facilitates training with FP8 precision on NVIDIA
+Hopper GPUs and introduces many enhancements for the training of
+Transformer-based models. Refer to [Transformer Engine](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/installation.html)
+for more information.
+
+To install Transformer Engine, run the following code:
+
+```bash
+git clone https://github.com/NVIDIA/TransformerEngine.git && \
+cd TransformerEngine && \
+git checkout $te_commit && \
+git submodule init && git submodule update && \
+NVTE_FRAMEWORK=pytorch NVTE_WITH_USERBUFFERS=1 MPI_HOME=/usr/local/mpi pip install .
+```
+
+Transformer Engine requires PyTorch to be built with at least CUDA 11.8.
+
+### Megatron Core
+
+Megatron Core is required for LLM and MM domains. Megatron Core is a
+library for scaling large Transformer-based models. NeMo LLMs and MMs
+leverage Megatron Core for model parallelism, transformer architectures,
+and optimized PyTorch datasets.
+
+To install Megatron Core, run the following code:
+
+```bash
+git clone https://github.com/NVIDIA/Megatron-LM.git && \
+cd Megatron-LM && \
+git checkout $mcore_commit && \
+pip install . && \
+cd megatron/core/datasets && \
+make
+```
+
+## NeMo Text Processing
+
+NeMo Text Processing, specifically Inverse Text Normalization, is now a
+separate repository. It is located here:
+https://github.com/NVIDIA/NeMo-text-processing.
+
+## Docker Containers
+
+NeMo containers are launched concurrently with NeMo version updates.
+NeMo Framework now supports LLMs, MMs, ASR, and TTS in a single
+consolidated Docker container. You can find additional information about
+released containers on the [NeMo releases
+page](https://github.com/NVIDIA/NeMo/releases).
+
+To use a pre-built container, run the following code:
+
+```bash
+docker pull nvcr.io/nvidia/nemo:24.05
+```
+
+To build a NeMo container with Dockerfile from a branch, run the
+following code:
+
+```bash
+DOCKER_BUILDKIT=1 docker build -f Dockerfile -t nemo:latest .
+```
+
+If you choose to work with the main branch, we recommend using NVIDIA's
+PyTorch container version 23.10-py3 and then installing from GitHub.
+
+```bash
+docker run --gpus all -it --rm -v :/NeMo --shm-size=8g \
+-p 8888:8888 -p 6006:6006 --ulimit memlock=-1 --ulimit \
+stack=67108864 --device=/dev/snd nvcr.io/nvidia/pytorch:23.10-py3
+```
+
+## Future Work
+
+The NeMo Framework Launcher does not currently support ASR and TTS
+training, but it will soon.
+
+## Discussions Board
+
+FAQ can be found on the NeMo [Discussions
+board](https://github.com/NVIDIA/NeMo/discussions). You are welcome to
+ask questions or start discussions on the board.
+
+## Contribute to NeMo
+
+We welcome community contributions! Please refer to
+[CONTRIBUTING.md](https://github.com/NVIDIA/NeMo/blob/stable/CONTRIBUTING.md)
+for the process.
+
+## Publications
+
+We provide an ever-growing list of
+[publications](https://nvidia.github.io/NeMo/publications/) that utilize
+the NeMo Framework.
+
+To contribute an article to the collection, please submit a pull request
+to the `gh-pages-src` branch of this repository. For detailed
+information, please consult the README located at the [gh-pages-src
+branch](https://github.com/NVIDIA/NeMo/tree/gh-pages-src#readme).
+
+## Blogs
+
+### Large Language Models and Multimodal Models
+
+- **Bria Builds Responsible Generative AI for Enterprises Using NVIDIA NeMo, Picasso**
+  (2024/03/06)
+
+  Bria, a Tel Aviv startup at the forefront of visual generative AI for enterprises, now leverages the NVIDIA NeMo Framework.
+  The Bria.ai platform uses reference implementations from the NeMo Multimodal collection, trained on NVIDIA Tensor Core GPUs, to enable high-throughput and low-latency image generation.
+  Bria has also adopted NVIDIA Picasso, a foundry for visual generative AI models, to run inference.
+
+- **New NVIDIA NeMo Framework Features and NVIDIA H200**
+  (2023/12/06)
+
+  NVIDIA NeMo Framework now includes several optimizations and enhancements, including:
+  1) Fully Sharded Data Parallelism (FSDP) to improve the efficiency of training large-scale AI models,
+  2) Mix of Experts (MoE)-based LLM architectures with expert parallelism for efficient LLM training at scale,
+  3) Reinforcement Learning from Human Feedback (RLHF) with TensorRT-LLM for inference stage acceleration, and
+  4) up to 4.2x speedups for Llama 2 pre-training on NVIDIA H200 Tensor Core GPUs.
+
+  (Figure: H200-NeMo-performance)
+
+- **NVIDIA now powers training for Amazon Titan Foundation models**
+  (2023/11/28)
+
+  NVIDIA NeMo Framework now empowers the Amazon Titan foundation models (FM) with efficient training of large language models (LLMs).
+  The Titan FMs form the basis of Amazon’s generative AI service, Amazon Bedrock.
+  The NeMo Framework provides a versatile framework for building, customizing, and running LLMs.
+
+ + +## Licenses + +- [NeMo GitHub Apache 2.0 + license](https://github.com/NVIDIA/NeMo?tab=Apache-2.0-1-ov-file#readme) +- NeMo is licensed under the [NVIDIA AI PRODUCT + AGREEMENT](https://www.nvidia.com/en-us/data-center/products/nvidia-ai-enterprise/eula/). + By pulling and using the container, you accept the terms and + conditions of this license. diff --git a/README.rst b/README.rst deleted file mode 100644 index e24ce6f05a36..000000000000 --- a/README.rst +++ /dev/null @@ -1,584 +0,0 @@ - -|status| |documentation| |codeql| |license| |pypi| |pyversion| |downloads| |black| - -.. |status| image:: http://www.repostatus.org/badges/latest/active.svg - :target: http://www.repostatus.org/#active - :alt: Project Status: Active – The project has reached a stable, usable state and is being actively developed. - -.. |documentation| image:: https://readthedocs.com/projects/nvidia-nemo/badge/?version=main - :alt: Documentation - :target: https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/ - -.. |license| image:: https://img.shields.io/badge/License-Apache%202.0-brightgreen.svg - :target: https://github.com/NVIDIA/NeMo/blob/master/LICENSE - :alt: NeMo core license and license for collections in this repo - -.. |pypi| image:: https://badge.fury.io/py/nemo-toolkit.svg - :target: https://badge.fury.io/py/nemo-toolkit - :alt: Release version - -.. |pyversion| image:: https://img.shields.io/pypi/pyversions/nemo-toolkit.svg - :target: https://badge.fury.io/py/nemo-toolkit - :alt: Python version - -.. |downloads| image:: https://static.pepy.tech/personalized-badge/nemo-toolkit?period=total&units=international_system&left_color=grey&right_color=brightgreen&left_text=downloads - :target: https://pepy.tech/project/nemo-toolkit - :alt: PyPi total downloads - -.. |codeql| image:: https://github.com/nvidia/nemo/actions/workflows/codeql.yml/badge.svg?branch=main&event=push - :target: https://github.com/nvidia/nemo/actions/workflows/codeql.yml - :alt: CodeQL - -.. |black| image:: https://img.shields.io/badge/code%20style-black-000000.svg - :target: https://github.com/psf/black - :alt: Code style: black - -.. _main-readme: - -**NVIDIA NeMo Framework** -========================= - -Latest News ------------ - -.. raw:: html - -
- Large Language Models and Multimodal -
- - - NVIDIA releases 340B base, instruct, and reward models pretrained on a total of 9T tokens. - (2024-06-18) - - See documentation and tutorials for SFT, PEFT, and PTQ with - - Nemotron 340B - - in the NeMo Framework User Guide. -

-
- -
- - - NVIDIA sets new generative AI performance and scale records in MLPerf Training v4.0 - (2024/06/12) - - - Using NVIDIA NeMo Framework and NVIDIA Hopper GPUs NVIDIA was able to scale to 11,616 H100 GPUs and achieve near-linear performance scaling on LLM pretraining. - NVIDIA also achieved the highest LLM fine-tuning performance and raised the bar for text-to-image training. -

-
- -
- - - Accelerate your generative AI journey with NVIDIA NeMo Framework on GKE - (2024/03/16) - - - An end-to-end walkthrough to train generative AI models on the Google Kubernetes Engine (GKE) using the NVIDIA NeMo Framework is available at https://github.com/GoogleCloudPlatform/nvidia-nemo-on-gke. - The walkthrough includes detailed instructions on how to set up a Google Cloud Project and pre-train a GPT model using the NeMo Framework. -

-
- -
- - - Bria Builds Responsible Generative AI for Enterprises Using NVIDIA NeMo, Picasso - (2024/03/06) - - - Bria, a Tel Aviv startup at the forefront of visual generative AI for enterprises now leverages the NVIDIA NeMo Framework. - The Bria.ai platform uses reference implementations from the NeMo Multimodal collection, trained on NVIDIA Tensor Core GPUs, to enable high-throughput and low-latency image generation. - Bria has also adopted NVIDIA Picasso, a foundry for visual generative AI models, to run inference. -

-
- -
- - - New NVIDIA NeMo Framework Features and NVIDIA H200 - (2023/12/06) - - - NVIDIA NeMo Framework now includes several optimizations and enhancements, - including: - 1) Fully Sharded Data Parallelism (FSDP) to improve the efficiency of training large-scale AI models, - 2) Mix of Experts (MoE)-based LLM architectures with expert parallelism for efficient LLM training at scale, - 3) Reinforcement Learning from Human Feedback (RLHF) with TensorRT-LLM for inference stage acceleration, and - 4) up to 4.2x speedups for Llama 2 pre-training on NVIDIA H200 Tensor Core GPUs. -

- - H200-NeMo-performance -

-
- -
- - - NVIDIA now powers training for Amazon Titan Foundation models - (2023/11/28) - - - NVIDIA NeMo Framework now empowers the Amazon Titan foundation models (FM) with efficient training of large language models (LLMs). - The Titan FMs form the basis of Amazon’s generative AI service, Amazon Bedrock. - The NeMo Framework provides a versatile framework for building, customizing, and running LLMs. -

-
- -
- -
- Speech Recognition -
- - - New Standard for Speech Recognition and Translation from the NVIDIA NeMo Canary Model - (2024/04/18) - - - The NeMo team just released Canary, a multilingual model that transcribes speech in English, Spanish, German, and French with punctuation and capitalization. - Canary also provides bi-directional translation, between English and the three other supported languages. -

-
- -
- - - Pushing the Boundaries of Speech Recognition with NVIDIA NeMo Parakeet ASR Models - (2024/04/18) - - - NVIDIA NeMo, an end-to-end platform for the development of multimodal generative AI models at scale anywhere—on any cloud and on-premises—released the Parakeet family of automatic speech recognition (ASR) models. - These state-of-the-art ASR models, developed in collaboration with Suno.ai, transcribe spoken English with exceptional accuracy. -

-
- -
- - - Turbocharge ASR Accuracy and Speed with NVIDIA NeMo Parakeet-TDT - (2024/04/18) - - - NVIDIA NeMo, an end-to-end platform for developing multimodal generative AI models at scale anywhere—on any cloud and on-premises—recently released Parakeet-TDT. - This new addition to the  NeMo ASR Parakeet model family boasts better accuracy and 64% greater speed over the previously best model, Parakeet-RNNT-1.1B. -

-
- -
- - - - -Introduction ------------- - -NVIDIA NeMo Framework is a scalable and cloud-native generative AI framework built for researchers and PyTorch developers working on Large Language Models (LLMs), Multimodal Models (MMs), Automatic Speech Recognition (ASR), Text to Speech (TTS), and Computer Vision (CV) domains. It is designed to help you efficiently create, customize, and deploy new generative AI models by leveraging existing code and pre-trained model checkpoints. - -For technical documentation, please see the `NeMo Framework User Guide `_. - -LLMs and MMs Training, Alignment, and Customization ---------------------------------------------------- - -All NeMo models are trained with `Lightning `_. -Training is automatically scalable to 1000s of GPUs. - -When applicable, NeMo models leverage cutting-edge distributed training techniques, incorporating `parallelism strategies `_ to enable efficient training of very large models. These techniques include Tensor Parallelism (TP), Pipeline Parallelism (PP), Fully Sharded Data Parallelism (FSDP), Mixture-of-Experts (MoE), and Mixed Precision Training with BFloat16 and FP8, as well as others. - -NeMo Transformer-based LLMs and MMs utilize `NVIDIA Transformer Engine `_ for FP8 training on NVIDIA Hopper GPUs, while leveraging `NVIDIA Megatron Core `_ for scaling Transformer model training. - -NeMo LLMs can be aligned with state-of-the-art methods such as SteerLM, Direct Preference Optimization (DPO), and Reinforcement Learning from Human Feedback (RLHF). See `NVIDIA NeMo Aligner `_ for more information. - -In addition to supervised fine-tuning (SFT), NeMo also supports the latest parameter efficient fine-tuning (PEFT) techniques such as LoRA, P-Tuning, Adapters, and IA3. Refer to the `NeMo Framework User Guide `_ for the full list of supported models and techniques. - -LLMs and MMs Deployment and Optimization ----------------------------------------- - -NeMo LLMs and MMs can be deployed and optimized with `NVIDIA NeMo Microservices `_. - -Speech AI ---------- - -NeMo ASR and TTS models can be optimized for inference and deployed for production use cases with `NVIDIA Riva `_. - -NeMo Framework Launcher ------------------------ - -`NeMo Framework Launcher `_ is a cloud-native tool that streamlines the NeMo Framework experience. It is used for launching end-to-end NeMo Framework training jobs on CSPs and Slurm clusters. - -The NeMo Framework Launcher includes extensive recipes, scripts, utilities, and documentation for training NeMo LLMs. It also includes the NeMo Framework `Autoconfigurator `_, which is designed to find the optimal model parallel configuration for training on a specific cluster. - -To get started quickly with the NeMo Framework Launcher, please see the `NeMo Framework Playbooks `_. The NeMo Framework Launcher does not currently support ASR and TTS training, but it will soon. - -Get Started with NeMo Framework -------------------------------- - -Getting started with NeMo Framework is easy. State-of-the-art pretrained NeMo models are freely available on `Hugging Face Hub `_ and -`NVIDIA NGC `_. -These models can be used to generate text or images, transcribe audio, and synthesize speech in just a few lines of code. - -We have extensive `tutorials `_ that -can be run on `Google Colab `_ or with our `NGC NeMo Framework Container `_. We also have `playbooks `_ for users who want to train NeMo models with the NeMo Framework Launcher. 
- -For advanced users who want to train NeMo models from scratch or fine-tune existing NeMo models, we have a full suite of `example scripts `_ that support multi-GPU/multi-node training. - -Key Features ------------- - -* `Large Language Models `_ -* `Multimodal `_ -* `Automatic Speech Recognition `_ -* `Text to Speech `_ -* `Computer Vision `_ - -Requirements ------------- - -* Python 3.10 or above -* Pytorch 1.13.1 or above -* NVIDIA GPU (if you intend to do model training) - -Developer Documentation ------------------------ - -.. |main| image:: https://readthedocs.com/projects/nvidia-nemo/badge/?version=main - :alt: Documentation Status - :scale: 100% - :target: https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/ - -.. |stable| image:: https://readthedocs.com/projects/nvidia-nemo/badge/?version=stable - :alt: Documentation Status - :scale: 100% - :target: https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/stable/ - -+---------+-------------+------------------------------------------------------------------------------------------------------------------------------------------+ -| Version | Status | Description | -+=========+=============+==========================================================================================================================================+ -| Latest | |main| | `Documentation of the latest (i.e. main) branch. `_ | -+---------+-------------+------------------------------------------------------------------------------------------------------------------------------------------+ -| Stable | |stable| | `Documentation of the stable (i.e. most recent release) branch. `_ | -+---------+-------------+------------------------------------------------------------------------------------------------------------------------------------------+ - -Install NeMo Framework ----------------------- - -The NeMo Framework can be installed in a variety of ways, depending on your needs. Depending on the domain, you may find one of the following installation methods more suitable. - -* Conda / Pip - Refer to `Conda <#conda>`_ and `Pip <#pip>`_ for installation instructions. - - * This is the recommended method for ASR and TTS domains. - * When using a Nvidia PyTorch container as the base, this is the recommended method for all domains. - -* Docker Containers - Refer to `Docker containers <#docker-containers>`_ for installation instructions. - - * NeMo Framework container - `nvcr.io/nvidia/nemo:24.05` - -* LLMs and MMs Dependencies - Refer to `LLMs and MMs Dependencies <#install-llms-and-mms-dependencies>`_ for installation instructions. - -**Important: We strongly recommended that you start with a base NVIDIA PyTorch container: nvcr.io/nvidia/pytorch:24.02-py3.** - -Conda -^^^^^ - -Install NeMo in a fresh Conda environment: - -.. code-block:: bash - - conda create --name nemo python==3.10.12 - conda activate nemo - -Install PyTorch using their `configurator `_: - -.. code-block:: bash - - conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia - -The command to install PyTorch may depend on your system. Use the configurator linked above to find the right command for your system. - -Then, install NeMo via Pip or from Source. We do not provide NeMo on the conda-forge or any other Conda channel. - -Pip -^^^ - -To install the nemo_toolkit, use the following installation method: - -.. 
code-block:: bash - - apt-get update && apt-get install -y libsndfile1 ffmpeg - pip install Cython packaging - pip install nemo_toolkit['all'] - -Depending on the shell used, you may need to use the ``"nemo_toolkit[all]"`` specifier instead in the above command. - -Pip from a Specific Domain -^^^^^^^^^^^^^^^^^^^^^^^^^^ - -To install a specific domain of NeMo, you must first install the nemo_toolkit using the instructions listed above. Then, you run the following domain-specific commands: - -.. code-block:: bash - - pip install nemo_toolkit['asr'] - pip install nemo_toolkit['nlp'] - pip install nemo_toolkit['tts'] - pip install nemo_toolkit['vision'] - pip install nemo_toolkit['multimodal'] - -Pip from a Source Branch -^^^^^^^^^^^^^^^^^^^^^^^^ - -If you want to work with a specific version of NeMo from a particular GitHub branch (e.g main), use the following installation method: - -.. code-block:: bash - - apt-get update && apt-get install -y libsndfile1 ffmpeg - pip install Cython packaging - python -m pip install git+https://github.com/NVIDIA/NeMo.git@{BRANCH}#egg=nemo_toolkit[all] - - -Build from Source -^^^^^^^^^^^^^^^^^ - -If you want to clone the NeMo GitHub repository and contribute to NeMo open-source development work, use the following installation method: - -.. code-block:: bash - - apt-get update && apt-get install -y libsndfile1 ffmpeg - git clone https://github.com/NVIDIA/NeMo - cd NeMo - ./reinstall.sh - -If you only want the toolkit without the additional Conda-based dependencies, you can replace ``reinstall.sh`` with ``pip install -e .`` when your PWD is the root of the NeMo repository. - -Mac Computers with Apple Silicon -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -To install NeMo on Mac computers with the Apple M-Series GPU, you need to create a new Conda environment, install PyTorch 2.0 or higher, and then install the nemo_toolkit. - -**Important: This method is only applicable to the ASR domain.** - -Run the following code: - -.. code-block:: shell - - # [optional] install mecab using Homebrew, to use sacrebleu for NLP collection - # you can install Homebrew here: https://brew.sh - brew install mecab - - # [optional] install pynini using Conda, to use text normalization - conda install -c conda-forge pynini - - # install Cython manually - pip install cython packaging - - # clone the repo and install in development mode - git clone https://github.com/NVIDIA/NeMo - cd NeMo - pip install 'nemo_toolkit[all]' - - # Note that only the ASR toolkit is guaranteed to work on MacBook - so for MacBook use pip install 'nemo_toolkit[asr]' - -Windows Computers -^^^^^^^^^^^^^^^^^ - -To install the Windows Subsystem for Linux (WSL), run the following code in PowerShell: - -.. code-block:: shell - - wsl --install - # [note] If you run wsl --install and see the WSL help text, it means WSL is already installed. - -To learn more about installing WSL, refer to `Microsoft's official documentation `_. - -After installing your Linux distribution with WSL, two options are available: - -**Option 1:** Open the distribution (Ubuntu by default) from the Start menu and follow the instructions. - -**Option 2:** Launch the Terminal application. Download it from `Microsoft's Windows Terminal page `_ if not installed. - -Next, follow the instructions for Linux systems, as provided above. For example: - -.. 
code-block:: bash - - apt-get update && apt-get install -y libsndfile1 ffmpeg - git clone https://github.com/NVIDIA/NeMo - cd NeMo - ./reinstall.sh - -RNNT -^^^^ - -For optimal performance of a Recurrent Neural Network Transducer (RNNT), install the Numba package from Conda. - -Run the following code: - -.. code-block:: bash - - conda remove numba - pip uninstall numba - conda install -c conda-forge numba - -Install LLMs and MMs Dependencies ---------------------------------- - -If you work with the LLM and MM domains, three additional dependencies are required: NVIDIA Apex, NVIDIA Transformer Engine, and NVIDIA Megatron Core. When working with the `main` branch, these dependencies may require a recent commit. - -The most recent working versions of these dependencies are here: - -.. code-block:: bash - - export apex_commit=810ffae374a2b9cb4b5c5e28eaeca7d7998fca0c - export te_commit=bfe21c3d68b0a9951e5716fb520045db53419c5e - export mcore_commit=02871b4df8c69fac687ab6676c4246e936ce92d0 - export nv_pytorch_tag=24.02-py3 - -When using a released version of NeMo, please refer to the `Software Component Versions `_ for the correct versions. - -PyTorch Container -^^^^^^^^^^^^^^^^^ - -We recommended that you start with a base NVIDIA PyTorch container: nvcr.io/nvidia/pytorch:24.02-py3. - -If starting with a base NVIDIA PyTorch container, you must first launch the container: - -.. code-block:: bash - - docker run \ - --gpus all \ - -it \ - --rm \ - --shm-size=16g \ - --ulimit memlock=-1 \ - --ulimit stack=67108864 \ - nvcr.io/nvidia/pytorch:$nv_pytorch_tag - -Next, you need to install the dependencies. - -Apex -^^^^ - -NVIDIA Apex is required for LLM and MM domains. Although Apex is pre-installed in the NVIDIA PyTorch container, you may need to update it to a newer version. - -To install Apex, run the following code: - -.. code-block:: bash - - git clone https://github.com/NVIDIA/apex.git - cd apex - git checkout $apex_commit - pip install . -v --no-build-isolation --disable-pip-version-check --no-cache-dir --config-settings "--build-option=--cpp_ext --cuda_ext --fast_layer_norm --distributed_adam --deprecated_fused_adam --group_norm" - -When attempting to install Apex separately from the NVIDIA PyTorch container, you might encounter an error if the CUDA version on your system is different from the one used to compile PyTorch. To bypass this error, you can comment out the relevant line in the setup file located in the Apex repository on GitHub here: https://github.com/NVIDIA/apex/blob/master/setup.py#L32. - -cuda-nvprof is needed to install Apex. The version should match the CUDA version that you are using. - -To install cuda-nvprof, run the following code: - -.. code-block:: bash - - conda install -c nvidia cuda-nvprof=11.8 - -Finally, install the packaging: - -.. code-block:: bash - - pip install packaging - -To install the most recent versions of Apex locally, it might be necessary to remove the `pyproject.toml` file from the Apex directory. - -Transformer Engine -^^^^^^^^^^^^^^^^^^ - -NVIDIA Transformer Engine is required for LLM and MM domains. Although the Transformer Engine is pre-installed in the NVIDIA PyTorch container, you may need to update it to a newer version. - -The Transformer Engine facilitates training with FP8 precision on NVIDIA Hopper GPUs and introduces many enhancements for the training of Transformer-based models. Refer to `Transformer Enginer `_ for information. - -To install Transformer Engine, run the following code: - -.. 
code-block:: bash - - git clone https://github.com/NVIDIA/TransformerEngine.git && \ - cd TransformerEngine && \ - git checkout $te_commit && \ - git submodule init && git submodule update && \ - NVTE_FRAMEWORK=pytorch NVTE_WITH_USERBUFFERS=1 MPI_HOME=/usr/local/mpi pip install . - -Transformer Engine requires PyTorch to be built with at least CUDA 11.8. - -Megatron Core -^^^^^^^^^^^^^ - -Megatron Core is required for LLM and MM domains. Megatron Core is a library for scaling large Transformer-based models. NeMo LLMs and MMs leverage Megatron Core for model parallelism, transformer architectures, and optimized PyTorch datasets. - -To install Megatron Core, run the following code: - -.. code-block:: bash - - git clone https://github.com/NVIDIA/Megatron-LM.git && \ - cd Megatron-LM && \ - git checkout $mcore_commit && \ - pip install . && \ - cd megatron/core/datasets && \ - make - -NeMo Text Processing --------------------- - -NeMo Text Processing, specifically Inverse Text Normalization, is now a separate repository. It is located here: `https://github.com/NVIDIA/NeMo-text-processing `_. - -Docker Containers ------------------ - -NeMo containers are launched concurrently with NeMo version updates. NeMo Framework now supports LLMs, MMs, ASR, and TTS in a single consolidated Docker container. You can find additional information about released containers on the `NeMo releases page `_. - -To use a pre-built container, run the following code: - -.. code-block:: bash - - docker pull nvcr.io/nvidia/nemo:24.05 - -To build a nemo container with Dockerfile from a branch, run the following code: - -.. code-block:: bash - - DOCKER_BUILDKIT=1 docker build -f Dockerfile -t nemo:latest - -If you choose to work with the main branch, we recommend using NVIDIA's PyTorch container version 23.10-py3 and then installing from GitHub. - -.. code-block:: bash - - docker run --gpus all -it --rm -v :/NeMo --shm-size=8g \ - -p 8888:8888 -p 6006:6006 --ulimit memlock=-1 --ulimit \ - stack=67108864 --device=/dev/snd nvcr.io/nvidia/pytorch:23.10-py3 - - -Future Work ------------ - -The NeMo Framework Launcher does not currently support ASR and TTS training, but it will soon. - -Discussions Board ------------------ - -FAQ can be found on the NeMo `Discussions board `_. You are welcome to ask questions or start discussions on the board. - -Contribute to NeMo ------------------- - -We welcome community contributions! Please refer to `CONTRIBUTING.md `_ for the process. - -Publications ------------------- - -We provide an ever-growing list of `publications `_ that utilize the NeMo Framework. - -To contribute an article to the collection, please submit a pull request to the ``gh-pages-src`` branch of this repository. For detailed information, please consult the README located at the `gh-pages-src branch `_. - -Licenses --------- - -* `NeMo GitHub Apache 2.0 license `__ - -* NeMo is licensed under the `NVIDIA AI PRODUCT AGREEMENT `__. By pulling and using the container, you accept the terms and conditions of this license. diff --git a/docs/source/asr/speech_intent_slot/api.rst b/docs/source/asr/speech_intent_slot/api.rst index d45f24f807f6..4a45715f78f7 100644 --- a/docs/source/asr/speech_intent_slot/api.rst +++ b/docs/source/asr/speech_intent_slot/api.rst @@ -15,10 +15,10 @@ Mixins .. autoclass:: nemo.collections.asr.parts.mixins.ASRModuleMixin :show-inheritance: :members: - :no-index: + :noindex: .. 
autoclass:: nemo.collections.asr.parts.mixins.ASRBPEMixin :show-inheritance: :members: - :no-index: + :noindex: diff --git a/docs/source/asr/ssl/api.rst b/docs/source/asr/ssl/api.rst index 8e6f83986032..77614e9ad5e3 100644 --- a/docs/source/asr/ssl/api.rst +++ b/docs/source/asr/ssl/api.rst @@ -15,12 +15,12 @@ Mixins .. autoclass:: nemo.collections.asr.parts.mixins.mixins.ASRModuleMixin :show-inheritance: :members: - :no-index: + :noindex: .. autoclass:: nemo.core.classes.mixins.access_mixins.AccessMixin :show-inheritance: :members: - :no-index: + :noindex: diff --git a/docs/source/ckpt_converters/convert_mlm.rst b/docs/source/checkpoints/convert_mlm.rst similarity index 100% rename from docs/source/ckpt_converters/convert_mlm.rst rename to docs/source/checkpoints/convert_mlm.rst diff --git a/docs/source/ckpt_converters/dev_guide.rst b/docs/source/checkpoints/dev_guide.rst similarity index 100% rename from docs/source/ckpt_converters/dev_guide.rst rename to docs/source/checkpoints/dev_guide.rst diff --git a/docs/source/checkpoints/dist_ckpt.rst b/docs/source/checkpoints/dist_ckpt.rst new file mode 100644 index 000000000000..31c89f64b55e --- /dev/null +++ b/docs/source/checkpoints/dist_ckpt.rst @@ -0,0 +1,427 @@ +Distributed Checkpoints +======================= + +This guide provides details about the distributed checkpoints format from Megatron Core. + + +Introduction +------------ + +Model parallel training requires parallelism-aware checkpointing. +Megatron Core provides a checkpointing library capable of handling all types of parallelisms used in LLM training. +Although the distributed checkpointing library is targeted at the Megatron Core model, it can also be used with other models, as long as proper integration is implemented. + +The library provides two main entrypoints: ``dist_checkpointing.save`` and ``dist_checkpointing.load`` which are meant to replace the ``torch.save`` and ``torch.load`` in the regular checkpointing flow. +Apart from that, it provides a mechanism to define how different types of local tensors should be combined and split in the global checkpoint. + + +Basic Sharding +-------------- + +The main way to define the relationship of a plain, local PyTorch tensor to tensors on other ranks is by wrapping it in a ``ShardedTensor`` class. +This allows to express the fact that a given local tensor is part of a larger *grid* of tensors of a given shape at a given offset. +Instead of saving a simple state dict with ``torch.Tensor``, we save a *sharded* state dict with ``dist_checkpointing.ShardedTensor``. + +Example: assume we have a tensor (composed of 128 elements) divided equally across the whole workload which we want to save and load with different number of ranks. + +.. 
code-block:: python + + from pathlib import Path + + import torch + + from megatron.core import dist_checkpointing + + # Setup + ckpt_root = Path('/tmp/checkpoints') + native_ckpt_root = ckpt_root / 'native' + native_ckpt_root.mkdir(exist_ok=True, parents=True) + dist_ckpt_root = ckpt_root / 'dist_ckpt' + dist_ckpt_root.mkdir(exist_ok=True, parents=True) + + torch.distributed.init_process_group() + world_size = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + + # Local tensor to save + assert 128 % world_size == 0 + num_elems_per_rank = 128 // world_size + local_ten = torch.arange(start=num_elems_per_rank * rank, + end=num_elems_per_rank * (rank + 1)) + + # Native checkpoint save + state_dict = { + 'weight': local_ten + } + torch.save(state_dict, native_ckpt_root / f'ckpt_{rank}.pt') + + # Distributed checkpoint save + # `(0, rank, world_size)` describes that `weight` ShardedTensor is sharded into `world_size` pieces + # along the 0th dimension and `local_ten` is the shard at position `rank`. + # Together, all shards implicitly form a "global" `torch.arange(128)` tensor. + sharded_state_dict = { + 'weight': dist_checkpointing.ShardedTensor.from_rank_offsets('weight', local_ten, (0, rank, world_size)) + } + dist_checkpointing.save(sharded_state_dict, dist_ckpt_root) + +During load, the distributed checkpoint can be easily read even if the job size changes (contrary to native checkpoints that require the same number of ranks). +The main difference with wrt. ``torch.load`` is that the user has to provide the definition of the sharded state dict that needs to be loaded. + +.. code-block:: python + + from pathlib import Path + + import torch + + from megatron.core import dist_checkpointing + + ckpt_root = Path('/tmp/checkpoints') + dist_ckpt_root = ckpt_root / 'dist_ckpt' + + torch.distributed.init_process_group() + world_size = torch.distributed.get_world_size() + rank = torch.distributed.get_rank() + assert 128 % world_size == 0 + num_elems_per_rank = 128 // world_size + + # Local tensor to load + local_ten = torch.empty(num_elems_per_rank) + sharded_state_dict = { + 'weight': dist_checkpointing.ShardedTensor.from_rank_offsets('weight', local_ten, (0, rank, world_size)) + } + loaded_state_dict = dist_checkpointing.load(sharded_state_dict, dist_ckpt_root) + expected_local_ten = torch.arange(start=num_elems_per_rank * rank, end=num_elems_per_rank * (rank + 1)) + assert torch.all(loaded_state_dict['weight'] == expected_local_ten) + + # With torch.save and torch.load, we would have to load all files that contain + # parts of the desired tensor in new configuration and concatenate appropriate fragments. + # For some distributed checkpoint backends this is actually what happens underneath. + + +Supported Entities +------------------ +The distributed checkpointing library supports saving and loading of different objects in different configurations. + +A sharded state dict is a (possibly nested) Python dictionary or list with the following elements: + +1. ShardedBase + a. ShardedTensor + b. ShardedObject + c. ShardedTensorFactory +2. LocalNonpersitentObject +3. Arbitrary object + + +ShardedBase +^^^^^^^^^^^ +ShardedBase is the base class for expressing any kind of sharding. +Each sharded entity must be uniquely identified by its ``key``, carry some ``data`` to be saved or loaded, and define ``replica_id`` which helps identify data redundancy. + +Note that the ``key`` doesn't have to (and usually doesn't) correspond to the key in the state dict. 
+The key in the state dict is ephemeral, while the ``ShardedTensor.key`` is used to identify the tensor in the checkpoint.
+
+In the following example, the state dict to be loaded contains different keys than the saved one.
+What matters is that the ``ShardedTensor.key`` values are equivalent (``tensor-A``):
+
+.. code-block:: python
+
+    import torch
+
+    from megatron.core import dist_checkpointing
+
+    # Checkpoint saved with some key in the state dict that is eventually ignored
+    model = ...
+    ckpt_dir = ...
+    sharded_state_dict = {
+        'ignored': dist_checkpointing.ShardedTensor('tensor-A', ...)
+    }
+    dist_checkpointing.save(sharded_state_dict, ckpt_dir)
+
+    # During loading, all that matters is the ShardedTensor.key.
+    sharded_state_dict = {
+        'different-key': dist_checkpointing.ShardedTensor('tensor-A', ...)
+    }
+    loaded_state_dict = dist_checkpointing.load(sharded_state_dict, ckpt_dir)
+    assert 'ignored' not in loaded_state_dict
+    assert 'tensor-A' not in loaded_state_dict
+    assert isinstance(loaded_state_dict['different-key'], torch.Tensor)
+
+    # The key in the state dict is important only for the subsequent `model.load_state_dict`
+    # that usually happens after `dist_checkpointing.load` - the state dict must have
+    # the structure and keys corresponding to the model structure and submodule names.
+    model.load_state_dict(loaded_state_dict)
+
+ShardedTensor
+^^^^^^^^^^^^^
+``ShardedTensor`` is the primary use case for distributed checkpointing - tensor sharding.
+It defines how PyTorch tensors are distributed across the workload.
+See the `Tensors transformations`_ section for more details on ShardedTensors.
+
+ShardedObject
+^^^^^^^^^^^^^
+Sometimes there is a need to save arbitrary objects across the ranks.
+ShardedObject allows structuring those objects into arrays of objects with a fixed ``global_shape`` and saving/loading parts of the arrays on specific ranks.
+
+ShardedTensorFactory
+^^^^^^^^^^^^^^^^^^^^
+The ShardedTensorFactory class defers tensor transformations until they are actually saved.
+A factory can expand a tensor into an arbitrary sub state dict (including all supported entities listed above).
+The need for such deferral will be explained in the `Tensors transformations`_ section.
+
+LocalNonpersistentObject
+^^^^^^^^^^^^^^^^^^^^^^^^
+LocalNonpersistentObject is a simple wrapper indicating that the object wrapped with this class should end up in the final loaded state dict during loading.
+During saving, such objects are ignored.
+
+Arbitrary Object
+^^^^^^^^^^^^^^^^
+All objects other than dicts, lists, and instances of the classes listed above are treated as "common" objects.
+
+During saving, all such objects in the sharded state dict passed to ``dist_checkpointing.save`` are assumed to be duplicated across ranks. Therefore, they are saved only by a single coordinator rank (rank 0).
+
+During loading, all such objects in the sharded state dict passed to ``dist_checkpointing.load`` are simply ignored - the loaded state dict contains only "common" objects that were actually saved in the checkpoint.
+
+
+
+
+Entry Points
+------------
+There are several useful user entry points for checkpoint saving and loading.
+
+dist_checkpointing.save
+^^^^^^^^^^^^^^^^^^^^^^^
+The ``dist_checkpointing.save`` function is the only entry point for checkpoint saving.
+It requires providing a sharded state dict to save and saving strategies for handling different entities (see `Save and load strategies`_ for detailed explanation).
+The sharded state dict is processed in the following way (see also ``save`` function `documentation `_): + +1. The ShardedTensorFactories are applied. +2. The LocalNonPersistentObjects are extracted from the sharded state dict and ignored. +3. The ShardedBase objects are extracted. +4. All other objects are treated as "common" and saved according to a sharded strategy (see `Save and load strategies`_). +5. All ShardedObjects are extracted from point (3) objects and saved with a common strategy (see `Save and load strategies`_). +6. All ShardedTensors are saved. +7. The ``metadata.json`` file with backend and version metadata is saved to the checkpoint directory. + +dist_checkpointing.load +^^^^^^^^^^^^^^^^^^^^^^^ +The ``dist_checkpointing.load`` function is the main entry point for checkpoint loading. +It requires providing a sharded state dict (in order to implicitly define mappings between local tensors and checkpoint tensors) and loading strategies. +In practice, the same sharded state dict can be usually used for both saving and loading (the sharded state dict for loading will just contain tensors with uninitialized data). + +When the sharded state dict is provided as input, it is processed in the following way (see also ``load`` function `documentation `_): + +1. The "common" state dict is loaded from the checkpoint. This forms the base of the resulting state dict. +2. The ShardedTensorFactories from the input sharded state dict are applied. +3. The LocalNonPersistentObjects are extracted from the input sharded state dict, unwrapped and added to the resulting state dict. +4. The ShardedObjects are extracted and loaded from the checkpoint into the resulting state dict. +5. The ShardedTensors are extracted and loaded from the checkpoint into the resulting state dict. +6. Factory merges are applied (see `Optimizers`_ for explanation). + +This results in a *regular* state dict with plain tensors that can be further processed by the application (which usually means running ``model.load_state_dict(state_dict)``). + + +dist_checkpointing.load_common_state_dict +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +The ``dist_checkpointing.load_common_state_dict`` function is an entry point that allows loading only the “common” part of the checkpoints. +Most of the checkpoint config and metadata can be loaded with this method, which allows skipping data loading in order to take decisions regarding checkpoint config, version, etc. + +dist_checkpointing.load_tensors_metadata +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +The ``dist_checkpointing.load_tensors_metadata`` function is an entry point that allows reading all ShardedTensors metadata from the checkpoint without loading any data. +The result is a sharded state dict with trivial sharding (every tensor is sharded into one big shard). + +dist_checkpointing.load_plain_tensors +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +The ``dist_checkpointing.load_plain_tensors`` function is an entry point that allows reading sharded tensors stored in the checkpoint without any sharding (as plain tensors). +This function is simply a composition of ``load_tensors_metadata`` and ``save``. + +Save and Load Strategies +------------------------ +There are multiple ways to save a sharded state dict into a serialized checkpoint. They can be provided by the user as saving and loading strategies (e.g. ``TorchDistLoadShardedStrategy`` and ``TorchDistSaveShardedStrategy`` as shown below). + +There are four types of strategies: + +1. Saving strategy for ShardedTensors +2. 
Saving strategy for "common" data +3. Loading strategy for ShardedTensors +4. Loading strategy for "common" data + +Additionally, ShardedObjects are handled with either "sharded" or "common" strategy depending on its capabilities (``can_handle_sharded_objects`` property). + +Each saving strategy is associated with a ``backend`` and a ``version``. +Each loading strategy can be associated with multiple values of ``backend`` and ``version`` it can load. +For a given backend and version, the composition of every saving and loading strategy **must be functionally equivalent**. +Strategies are the main way to introduce optimizations to the saving and loading algorithm without altering the checkpoint format. + +In the following example, the "fully parallel" wrappers modify the saving and loading *algorithm*, but the underlying checkpoint *format* (and ``backend`` in consequence) stays the same. +It makes the ``basic_save_load`` and ``fully_parallel_save_load`` functions equivalent: + +.. code-block:: python + + from megatron.core import dist_checkpointing + from megatron.core.dist_checkpointing.strategies.torch import ( + TorchDistLoadShardedStrategy, + TorchDistSaveShardedStrategy + ) + from megatron.core.dist_checkpointing.strategies.fully_parallel import ( + FullyParallelLoadStrategyWrapper, + FullyParallelSaveStrategyWrapper + ) + + # Base save and load strategies defining a regular (non-parallel) save + base_save_strategy = TorchDistSaveShardedStrategy('torch_dist', 1) + base_load_strategy = TorchDistLoadShardedStrategy() + + def basic_save_load(sharded_state_dict, ckpt_dir): + """ Save and load using some basic strategies. """ + dist_checkpointing.save(sharded_state_dict, ckpt_dir, base_save_strategy) + return dist_checkpointing.load(sharded_state_dict, ckpt_dir, base_load_strategy) + + + def fully_parallel_save_load(sharded_state_dict): + """ Save and load using basic strategies wrapped with parallelization strategies. """ + fully_parallel_save_strategy = FullyParallelSaveStrategyWrapper(base_save_strategy) + # "fully parallel" wrapper modifies the saving strategy, but not the underlying format + assert fully_parallel_save_strategy.backend == base_save_strategy.backend == 'torch_dist' + fully_parallel_load_strategy = FullyParallelLoadStrategyWrapper(base_load_strategy) + dist_checkpointing.save(sharded_state_dict, ckpt_dir, fully_parallel_save_strategy) + return dist_checkpointing.load(sharded_state_dict, ckpt_dir, fully_parallel_load_strategy) + + +The ``dist_checkpointing`` package provides default strategies for some sharded backends, so it's enough to specify a tuple ``(backend, version)`` as a saving strategy. +Backends and versions are stored in a ``metadata.json`` file inside the checkpoint so that the loading strategy can be determined automatically (provided that there exists a default loading strategy for a given backend and version). + +For "sharded" strategies, currently the backends supported by default are based on `PyTorch Distributed`_ format (``torch_dist`` backend) and `Zarr`_ format (``zarr`` backend). +Additionally, as shown in the example above, some wrappers are provided that enable it to parallelize the save and load across the whole workload (assuming some data duplication). + +For "common" strategies, currently the only supported one is ``torch`` which saves "common" data into a ``common.pt`` file. 
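+
+As an illustration, the Basic Sharding example above can rely entirely on those defaults, passing only the ``(backend, version)`` tuple when saving.
+The sketch below reuses ``sharded_state_dict`` and ``dist_ckpt_root`` from that example; the exact set of registered default strategies depends on the Megatron Core version installed.
+
+.. code-block:: python
+
+    from megatron.core import dist_checkpointing
+
+    # Save: only (backend, version) is given; the matching default saving strategy is resolved
+    # internally and recorded in metadata.json inside the checkpoint directory.
+    dist_checkpointing.save(sharded_state_dict, dist_ckpt_root, ('torch_dist', 1))
+
+    # Load: no strategy is given; a default loading strategy is chosen based on the backend
+    # and version stored in metadata.json.
+    loaded_state_dict = dist_checkpointing.load(sharded_state_dict, dist_ckpt_root)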
+ +PyTorch Distributed +^^^^^^^^^^^^^^^^^^^ +The PyTorch Distributed based checkpoint format uses the ``torch.distributed.checkpoint`` package in order to serialize the checkpoints to storage. +The Megatron Core sharded state dicts are translated into ``torch.distributed.SharedTensor`` and then ``torch.distributed.checkpoint`` primitives are used to serialize such state dicts. +Even though Megatron Core provides several saving optimizations, the underlying checkpoint can still be read with native `PyTorch loading methods `_. +Note that the checkpoint still follows the ``dist_checkpointing`` package format by providing additional ``common.pt`` and ``metadata.json`` files described above. + +PyTorch Distributed is a recommended checkpoint format. + +Zarr +^^^^ +The Zarr based checkpoint format uses the `Zarr `__ library in order to serialize the checkpoints to storage. +This format is deprecated and it's recommended to transition to the ``torch_dist`` format (using this `converter script `_). + +Optimizers +---------- +The Optimizers module provides helper tools to the user to simplify constructing ShardedTensors for optimizer states. +The ShardedTensors that define local-to-sharded tensors mapping for model parameters should be reused for optimizer states to avoid code duplication. + +To this end, the ``dist_checkpointing.optimizers.get_param_id_to_sharded_param_map`` function can build a mapping between optimizer params ids and model ShardedTensors. +This mapping can be used by the ``dist_checkpointing.optimizers.optim_state_to_sharding_state`` function or application code (for non-standard use cases) to construct optimizer sharded state dict with ShardedTensors. +This should support most optimizer cases, but some of them might require custom sharded state dict creation. +A good example is a Distributed Optimizer which flattens the parameters - see `Tensors transformations`_ section for more details. + +Note: In order to reuse model SharderTensors to create optimizer ShardedTensors, the model **SharderTensors must wrap model parameters**, not just tensors +(obtaining a state dict with model parameters can be achieved by passing ``keep_vars=True`` to the model ``state_dict`` function). +Otherwise the correspondence between model ShardedTensors and optimizer states is impossible to recreate. +This is the reason for introducing ShardedTensorFactories - we have to register the original model parameter as ``ShardedTensorFactories.data`` and apply any subsequent transformations as a factory function in order to make sure that the same transformation can be applied to the optimizer states. +Even if the model parameters transformations are complex, in most cases the optimizer state dict is easy to recreate based on the model ShardedTensors and ShardedTensorFactories, +e.g. `FP32Optimizer.sharded_state_dict `_ is just a matter of two generic ``get_param_id_to_sharded_param_map`` and ``optim_state_to_sharding_state`` function calls regardless of the model parameters complexity. + + +Tensors Transformations +----------------------- +The ShardedTensor API enables the declaration of basic transformations that should be performed during saving and loading. + +Shape Mismatch +^^^^^^^^^^^^^^ +The ``allow_shape_mismatch`` flag relaxes the requirement of matching global tensor shapes during loading. +Extra padding is filled with zeros or stripped depending on the mismatch kind. +This is useful for layers like embedding which might be padded according to parallelism for performance reasons. 
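+
+A minimal sketch of declaring such a tensor is shown below; the variable names and the checkpoint key are illustrative, and it is assumed that ``from_rank_offsets`` forwards the ``allow_shape_mismatch`` flag in the Megatron Core version you use.
+
+.. code-block:: python
+
+    from megatron.core import dist_checkpointing
+
+    # Vocabulary-parallel embedding padded to a convenient multiple of the tensor-parallel size;
+    # the global shape stored in the checkpoint may therefore differ from the padded run-time shape.
+    sharded_embedding = dist_checkpointing.ShardedTensor.from_rank_offsets(
+        'embedding.word_embeddings.weight',  # illustrative checkpoint key
+        padded_local_embedding,              # local (possibly padded) shard of the embedding
+        (0, tp_rank, tp_size),               # sharded along the vocabulary dimension
+        allow_shape_mismatch=True,           # extra padding is zero-filled or stripped on load
+    )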
+ +Flattening +^^^^^^^^^^ +The ``flattened_range`` attribute declares that ``ShardedTensor.data`` represents a slice of a flattened model parameter. +This corresponds to a transformation used in Distributed Optimizers which flattens the data and shards it along the data-parallel domain. + +Extra flattening comes with an efficiency challenge during checkpoint resharding. +Since flattening is applied after the global tensors is sharded into the grid of local chunks, loading after resharding requires accessing incontiguous data fragments. +An example solution for that is implemented in the `resharding `_ module and involves saving the flattened tensor with a different global shape than the original one. + +Example: For a global tensor ``[[0, 1, 2, 3, 4, 5], [6, 7, 8, 9, 10, 11]]`` with sharding by TP (tensor-parallel) over the second axis, here are the local shards if TP=2: + +.. list-table:: + :widths: 50 50 + :header-rows: 1 + + * - Rank + - Local shards + * - 0 + - ``[[0, 1, 2], [6, 7, 8]]`` + * - 1 + - ``[[3, 4, 5], [9, 10, 11]]`` + +After flattening and sharding by DP=3 (which would happen in the Megatron Core Distributed Optimizer), the resulting local shards are as follows: + +.. list-table:: + :widths: 50 50 + :header-rows: 1 + + * - Rank + - Local shards + * - 0 + - ``[0, 1]`` + * - 2 + - ``[2, 6]`` + * - 4 + - ``[7, 8]`` + * - 1 + - ``[3, 4]`` + * - 3 + - ``[5, 9]`` + * - 5 + - ``[10, 11]`` + +After sharding by TP=6 and flattening and sharding by DP=1, the resulting local shards are as follows: + + +.. list-table:: + :widths: 50 50 + :header-rows: 1 + + * - Rank + - Local shards + * - 0 + - ``[0, 6]`` + * - 1 + - ``[1, 7]`` + * - 2 + - ``[2, 8]`` + * - 3 + - ``[3, 9]`` + * - 4 + - ``[4, 10]`` + * - 5 + - ``[5, 11]`` + + +Arbitrary Transformations +^^^^^^^^^^^^^^^^^^^^^^^^^ +The way to apply arbitrary transformations to the tensors during saving and loading is with ShardedTensorFactory. +It defines such a transformation as a function that can be reapplied to any ShardedTensor (in particular, a ShardedTensor representing optimizer states). +Such "build" function is also tied to a "merge" function that can apply an inverse transformation during loading. + +If handling an optimizer state is not required, such a transformation could be also applied directly during sharded state dict creation. +In order to apply such transformation both to model and optimizer parameters in a consistent manner, it's necessary to encode them as factory functions (with original model parameter as the ``data`` input so that the optimizer params can be properly mapped to model ShardedTensors). + +Note that implementing some transformations might be challenging or impossible while supporting flattening for a Distributed Optimizer case. +For example, if the model weights are supposed to be transposed in the checkpoint, it's almost impossible to implement a performant factory function that is capable of transposing a flattened and sliced tensor. This is because the flattening and slicing should happen in the transposed dimension. + +Application Integration +----------------------- +The ``dist_checkpointing`` package provides all general mechanisms for saving arbitrary distributed checkpoints. +The only thing required from the application side is preparing a sharded state dict with ShardedTensors, ShardedObjects, etc. 
(representing the sharding of the data employed by the application) +and using the ``dist_checkpointing.save`` and ``dist_checkpointing.load`` entrypoints as replacements for ``torch.save`` and ``torch.load``. + +In Megatron Core, the sharded state dictionary preparation is already implemented in a ``sharded_state_dict`` method which creates the sharded state dicts in a composable way. +For other applications (e.g. with simpler types of supported parallelisms) it might be possible to apply a straightforward conversion from a regular model state dict into a sharded state dict. + diff --git a/docs/source/checkpoints/intro.rst b/docs/source/checkpoints/intro.rst new file mode 100644 index 000000000000..7c7154d64015 --- /dev/null +++ b/docs/source/checkpoints/intro.rst @@ -0,0 +1,64 @@ +Checkpoints +=========== + + +In this section, we present key functionalities of NVIDIA NeMo related to checkpoint management. + +Understanding Checkpoint Formats +-------------------------------- + +A ``.nemo`` checkpoint is fundamentally a tar file that bundles the model configurations (specified inside a YAML file), model weights (inside a ``.ckpt`` file), and other artifacts like tokenizer models or vocabulary files. This consolidated design streamlines sharing, loading, tuning, evaluating, and inference. + +In contrast, the ``.ckpt`` file, created during PyTorch Lightning training, contains both the model weights and the optimizer states, and is usually used to resume training. + +Sharded Model Weights +--------------------- + +Within ``.nemo`` or ``.ckpt`` checkpoints, the model weights could be saved in either a regular format (one file called ``model_weights.ckpt`` inside model parallelism folders) or a sharded format (a folder called ``model_weights``). + +With sharded model weights, you can save and load the state of your training script with multiple GPUs or nodes more efficiently and avoid the need to change model partitions when you resume tuning with a different model parallelism setup. + +NeMo supports the distributed (sharded) checkpoint format from Megatron Core. Megatron Core supports two checkpoint backends: PyTorch-based (recommended) and Zarr-based (deprecated). +For a detailed explanation check the :doc:`dist_ckpt` guide. + + +Quantized Checkpoints +--------------------- + +NeMo provides a :doc:`Post-Training Quantization <../nlp/quantization>` workflow that allows you to convert regular ``.nemo`` models into a `TensorRT-LLM checkpoint `_, commonly referred to as ``.qnemo`` checkpoints in NeMo. These ``.qnemo`` checkpoints can then be used with the `NVIDIA TensorRT-LLM library `_ for efficient inference. + +A ``.qnemo`` checkpoint, similar to ``.nemo`` checkpoints, is a tar file that bundles the model configuration specified in the ``config.json`` file along with the ``rank{i}.safetensors`` files. These ``.safetensors`` files store the model weights for each rank individually. In addition, a ``tokenizer_config.yaml`` file is saved, containing only the tokenizer section from the original NeMo ``model_config.yaml`` file. This configuration file defines the tokenizer used by the given model. + +When working with large quantized LLMs, it is recommended that you leave the checkpoint uncompressed as a directory rather than a tar file. You can control this behavior by setting the ``compress`` flag when exporting quantized models in `PTQ configuration file `_. + +The following example shows the contents of a quantized model intended to be served using two GPUs (ranks): + +.. 
code-block:: bash + + model-qnemo + ├── config.json + ├── rank0.safetensors + ├── rank1.safetensors + ├── tokenizer.model + └── tokenizer_config.yaml + +Community Checkpoint Converter +----------------------------- +We provide easy-to-use tools that enable users to convert community checkpoints into the NeMo format. These tools facilitate various operations, including resuming training, Supervised Fine-Tuning (SFT), Parameter-Efficient Fine-Tuning (PEFT), and deployment. For detailed instructions and guidelines, please refer to our documentation. + +We offer comprehensive guides to assist both end users and developers: + +- **User Guide**: Detailed steps on how to convert community model checkpoints for further training or deployment within NeMo. For more information, please see our :doc:`user_guide`. + +- **Developer Guide**: Instructions for developers on how to implement converters for community model checkpoints, allowing for broader compatibility and integration within the NeMo ecosystem. For development details, refer to our :doc:`dev_guide`. + +- **Megatron-LM Checkpoint Conversion**: NVIDIA NeMo and NVIDIA Megatron-LM share several foundational technologies. You can convert your GPT-style model checkpoints trained with Megatron-LM into the NeMo Framework using our scripts, see our :doc:`convert_mlm`. + +.. toctree:: + :maxdepth: 1 + :caption: NeMo Checkpoints + + dist_ckpt + user_guide + dev_guide + convert_mlm diff --git a/docs/source/ckpt_converters/user_guide.rst b/docs/source/checkpoints/user_guide.rst similarity index 100% rename from docs/source/ckpt_converters/user_guide.rst rename to docs/source/checkpoints/user_guide.rst diff --git a/docs/source/ckpt_converters/intro.rst b/docs/source/ckpt_converters/intro.rst deleted file mode 100644 index 6d4da83499fa..000000000000 --- a/docs/source/ckpt_converters/intro.rst +++ /dev/null @@ -1,22 +0,0 @@ -Community Checkpoint Converter -============================== - -We provide easy-to-use tools that enable users to convert community checkpoints into the NeMo format. These tools facilitate various operations, including resuming training, Sparse Fine-Tuning (SFT), Parameter-Efficient Fine-Tuning (PEFT), and deployment. For detailed instructions and guidelines, please refer to our documentation. - -We offer comprehensive guides to assist both end users and developers: - -- **User Guide**: Detailed steps on how to convert community model checkpoints for further training or deployment within NeMo. For more information, please see our :doc:`user_guide`. - -- **Developer Guide**: Instructions for developers on how to implement converters for community model checkpoints, allowing for broader compatibility and integration within the NeMo ecosystem. For development details, refer to our :doc:`dev_guide`. - -- **Megatron-LM Checkpoint Conversion**: NVIDIA NeMo and NVIDIA Megatron-LM share several foundational technologies. You can convert your GPT-style model checkpoints trained with Megatron-LM into the NeMo Framework using our scripts, see our :doc:`convert_mlm`. - -Access the user and developer guides directly through the links below: - -.. 
toctree:: - :maxdepth: 1 - :caption: Conversion Guides - - user_guide - dev_guide - convert_mlm diff --git a/docs/source/collections.rst b/docs/source/collections.rst index d4bea503513b..0198ef250ce3 100644 --- a/docs/source/collections.rst +++ b/docs/source/collections.rst @@ -25,6 +25,7 @@ Documentation for the individual collections multimodal/vlm/intro multimodal/text2img/intro multimodal/nerf/intro + mumtimoda/speech_llm/intro .. toctree:: :maxdepth: 1 diff --git a/docs/source/core/adapters/api.rst b/docs/source/core/adapters/api.rst index 8922c72d63eb..dee215ba0ed8 100644 --- a/docs/source/core/adapters/api.rst +++ b/docs/source/core/adapters/api.rst @@ -9,7 +9,7 @@ Core :members: :member-order: bysource :undoc-members: adapter_module_names - :no-index: + :noindex: ----- @@ -18,7 +18,7 @@ Core :members: :member-order: bysource :undoc-members: adapter_module_names - :no-index: + :noindex: ----- @@ -30,7 +30,7 @@ Adapter Networks :show-inheritance: :members: :member-order: bysource - :no-index: + :noindex: ----- @@ -38,7 +38,7 @@ Adapter Networks :show-inheritance: :members: :member-order: bysource - :no-index: + :noindex: ----- @@ -51,7 +51,7 @@ Adapter Strategies :members: :member-order: bysource :undoc-members: adapter_module_names - :no-index: + :noindex: ----- @@ -60,7 +60,7 @@ Adapter Strategies :members: :member-order: bysource :undoc-members: adapter_module_names - :no-index: + :noindex: ----- @@ -69,4 +69,4 @@ Adapter Strategies :members: :member-order: bysource :undoc-members: adapter_module_names - :no-index: + :noindex: diff --git a/docs/source/core/adapters/components.rst b/docs/source/core/adapters/components.rst index d8bed1b23a75..d4b38bc147b2 100644 --- a/docs/source/core/adapters/components.rst +++ b/docs/source/core/adapters/components.rst @@ -28,7 +28,7 @@ Adapter modules represent the functional form of the adapter. We discuss an exam :show-inheritance: :members: :member-order: bysource - :no-index: + :noindex: ----- @@ -36,7 +36,7 @@ Adapter modules represent the functional form of the adapter. We discuss an exam :show-inheritance: :members: :member-order: bysource - :no-index: + :noindex: Insertion Form - Module Adapters @@ -72,7 +72,7 @@ We discuss a simple residual additional connection strategy below - that accepts :members: :member-order: bysource :undoc-members: adapter_module_names - :no-index: + :noindex: ----- @@ -81,7 +81,7 @@ We discuss a simple residual additional connection strategy below - that accepts :members: :member-order: bysource :undoc-members: adapter_module_names - :no-index: + :noindex: ----- diff --git a/docs/source/core/exp_manager.rst b/docs/source/core/exp_manager.rst index ce5f7a9cb087..6daa5070a16e 100644 --- a/docs/source/core/exp_manager.rst +++ b/docs/source/core/exp_manager.rst @@ -193,170 +193,132 @@ and stability. To use EMA, simply set the following via YAML or :class:`~nemo.ut every_n_steps: 1 # How often to update EMA weights validate_original_weights: False # Whether to use original weights for validation calculation or EMA weights -Support for Preemption ----------------------- +.. Support for Preemption + ---------------------- -.. _exp_manager_preemption_support-label: + .. _exp_manager_preemption_support-label: -NeMo adds support for a callback upon preemption while running the models on clusters. The callback takes care of saving the current state of training via the ``.ckpt`` -file followed by a graceful exit from the run. 
The checkpoint saved upon preemption has the ``*last.ckpt`` suffix and replaces the previously saved last checkpoints. -This feature is useful to increase utilization on clusters. -The ``PreemptionCallback`` is enabled by default. To disable it simply add ``create_preemption_callback: False`` under exp_manager in the config YAML file. + NeMo adds support for a callback upon preemption while running the models on clusters. The callback takes care of saving the current state of training via the ``.ckpt`` + file followed by a graceful exit from the run. The checkpoint saved upon preemption has the ``*last.ckpt`` suffix and replaces the previously saved last checkpoints. + This feature is useful to increase utilization on clusters. + The ``PreemptionCallback`` is enabled by default. To disable it simply add ``create_preemption_callback: False`` under exp_manager in the config YAML file. -Stragglers Detection ----------------------- + Stragglers Detection + ---------------------- -.. _exp_manager_straggler_det_support-label: + .. _exp_manager_straggler_det_support-label: -.. note:: - Stragglers Detection feature is included in the optional NeMo resiliency package. + .. note:: + Stragglers Detection feature is included in the optional NeMo resiliency package. -Distributed training can be affected by stragglers, which are slow workers that slow down the overall training process. -NeMo provides a straggler detection feature that can identify slower GPUs. + Distributed training can be affected by stragglers, which are slow workers that slow down the overall training process. + NeMo provides a straggler detection feature that can identify slower GPUs. -This feature is implemented in the ``StragglerDetectionCallback``, which is disabled by default. + This feature is implemented in the ``StragglerDetectionCallback``, which is disabled by default. -The callback computes normalized GPU performance scores, which are scalar values ranging from 0.0 (worst) to 1.0 (best). -A performance score can be interpreted as the ratio of current performance to reference performance. + The callback computes normalized GPU performance scores, which are scalar values ranging from 0.0 (worst) to 1.0 (best). + A performance score can be interpreted as the ratio of current performance to reference performance. -There are two types of performance scores provided by the callback: - - Relative GPU performance score: The best-performing GPU in the workload is used as a reference. - - Individual GPU performance score: The best historical performance of the GPU is used as a reference. + There are two types of performance scores provided by the callback: + - Relative GPU performance score: The best-performing GPU in the workload is used as a reference. + - Individual GPU performance score: The best historical performance of the GPU is used as a reference. -Examples: - - If the relative performance score is 0.5, it means that a GPU is twice slower than the fastest GPU. - - If the individual performance score is 0.5, it means that a GPU is twice slower than its best observed performance. + Examples: + - If the relative performance score is 0.5, it means that a GPU is twice slower than the fastest GPU. + - If the individual performance score is 0.5, it means that a GPU is twice slower than its best observed performance. -If a GPU performance score drops below the specified threshold, it is identified as a straggler. + If a GPU performance score drops below the specified threshold, it is identified as a straggler. 
-To enable straggler detection, add ``create_straggler_detection_callback: True`` under exp_manager in the config YAML file. -You might also want to adjust the callback parameters: + To enable straggler detection, add ``create_straggler_detection_callback: True`` under exp_manager in the config YAML file. + You might also want to adjust the callback parameters: -.. code-block:: yaml - - exp_manager: - ... - create_straggler_detection_callback: True - straggler_detection_callback_params: - report_time_interval: 300 # Interval [seconds] of the straggler check - calc_relative_gpu_perf: True # Calculate relative GPU performance - calc_individual_gpu_perf: True # Calculate individual GPU performance - num_gpu_perf_scores_to_log: 5 # Log 5 best and 5 worst GPU performance scores, even if no stragglers are detected - gpu_relative_perf_threshold: 0.7 # Threshold for relative GPU performance scores - gpu_individual_perf_threshold: 0.7 # Threshold for individual GPU performance scores - stop_if_detected: True # Terminate the workload if stragglers are detected - -Straggler detection might involve inter-rank synchronization, and should be invoked with reasonable frequency (e.g. every few minutes). - -.. _exp_manager_straggler_det_support-label: - -.. note:: - Stragglers Detection feature is included in the optional NeMo resiliency package. + .. code-block:: yaml -Distributed training can be affected by stragglers, which are slow workers that slow down the overall training process. -NeMo provides a straggler detection feature that can identify slower GPUs. + exp_manager: + ... + create_straggler_detection_callback: True + straggler_detection_callback_params: + report_time_interval: 300 # Interval [seconds] of the straggler check + calc_relative_gpu_perf: True # Calculate relative GPU performance + calc_individual_gpu_perf: True # Calculate individual GPU performance + num_gpu_perf_scores_to_log: 5 # Log 5 best and 5 worst GPU performance scores, even if no stragglers are detected + gpu_relative_perf_threshold: 0.7 # Threshold for relative GPU performance scores + gpu_individual_perf_threshold: 0.7 # Threshold for individual GPU performance scores + stop_if_detected: True # Terminate the workload if stragglers are detected -This feature is implemented in the ``StragglerDetectionCallback``, which is disabled by default. + Straggler detection might involve inter-rank synchronization, and should be invoked with reasonable frequency (e.g. every few minutes). -The callback computes normalized GPU performance scores, which are scalar values ranging from 0.0 (worst) to 1.0 (best). -A performance score can be interpreted as the ratio of current performance to reference performance. -There are two types of performance scores provided by the callback: - - Relative GPU performance score: The best-performing GPU in the workload is used as a reference. - - Individual GPU performance score: The best historical performance of the GPU is used as a reference. +.. Fault Tolerance + --------------- -Examples: - - If the relative performance score is 0.5, it means that a GPU is twice slower than the fastest GPU. - - If the individual performance score is 0.5, it means that a GPU is twice slower than its best observed performance. + .. _exp_manager_fault_tolerance_support-label: -If a GPU performance score drops below the specified threshold, it is identified as a straggler. + .. note:: + Fault Tolerance feature is included in the optional NeMo resiliency package. 
-To enable straggler detection, add ``create_straggler_detection_callback: True`` under exp_manager in the config YAML file. -You might also want to adjust the callback parameters: + When training DNN models, faults may occur, hindering the progress of the entire training process. + This is particularly common in distributed, multi-node training scenarios, with many nodes and GPUs involved. -.. code-block:: yaml + NeMo incorporates a fault tolerance mechanism to detect training halts. + In response, it can terminate a hung workload and, if requested, restart it from the last checkpoint. - exp_manager: - ... - create_straggler_detection_callback: True - straggler_detection_callback_params: - report_time_interval: 300 # Interval [seconds] of the straggler check - calc_relative_gpu_perf: True # Calculate relative GPU performance - calc_individual_gpu_perf: True # Calculate individual GPU performance - num_gpu_perf_scores_to_log: 5 # Log 5 best and 5 worst GPU performance scores, even if no stragglers are detected - gpu_relative_perf_threshold: 0.7 # Threshold for relative GPU performance scores - gpu_individual_perf_threshold: 0.7 # Threshold for individual GPU performance scores - stop_if_detected: True # Terminate the workload if stragglers are detected - -Straggler detection might involve inter-rank synchronization, and should be invoked with reasonable frequency (e.g. every few minutes). - -Fault Tolerance ---------------- + Fault tolerance ("FT") relies on a special launcher (``ft_launcher``), which is a modified ``torchrun``. + The FT launcher runs background processes called rank monitors. **You need to use ft_launcher to start + your workload if you are using FT**. I.e., `NeMo-Framework-Launcher `_ + can be used to generate SLURM batch scripts with FT support. -.. _exp_manager_fault_tolerance_support-label: + Each training process (rank) sends `heartbeats` to its monitor during training and validation steps. + If a rank monitor stops receiving `heartbeats`, a training failure is detected. -.. note:: - Fault Tolerance feature is included in the optional NeMo resiliency package. + Fault detection is implemented in the ``FaultToleranceCallback`` and is disabled by default. + To enable it, add a ``create_fault_tolerance_callback: True`` option under ``exp_manager`` in the + config YAML file. Additionally, you can customize FT parameters by adding ``fault_tolerance`` section: -When training DNN models, faults may occur, hindering the progress of the entire training process. -This is particularly common in distributed, multi-node training scenarios, with many nodes and GPUs involved. + .. code-block:: yaml -NeMo incorporates a fault tolerance mechanism to detect training halts. -In response, it can terminate a hung workload and, if requested, restart it from the last checkpoint. + exp_manager: + ... + create_fault_tolerance_callback: True + fault_tolerance: + initial_rank_heartbeat_timeout: 600 # wait for 10 minutes for the initial heartbeat + rank_heartbeat_timeout: 300 # wait for 5 minutes for subsequent heartbeats + calculate_timeouts: True # estimate more accurate timeouts based on observed intervals -Fault tolerance ("FT") relies on a special launcher (``ft_launcher``), which is a modified ``torchrun``. -The FT launcher runs background processes called rank monitors. **You need to use ft_launcher to start -your workload if you are using FT**. I.e., `NeMo-Framework-Launcher `_ -can be used to generate SLURM batch scripts with FT support. 
+ Timeouts for fault detection need to be adjusted for a given workload: + * ``initial_rank_heartbeat_timeout`` should be long enough to allow for workload initialization. + * ``rank_heartbeat_timeout`` should be at least as long as the longest possible interval between steps. -Each training process (rank) sends `heartbeats` to its monitor during training and validation steps. -If a rank monitor stops receiving `heartbeats`, a training failure is detected. + **Importantly, `heartbeats` are not sent during checkpoint loading and saving**, so time for + checkpointing related operations should be taken into account. -Fault detection is implemented in the ``FaultToleranceCallback`` and is disabled by default. -To enable it, add a ``create_fault_tolerance_callback: True`` option under ``exp_manager`` in the -config YAML file. Additionally, you can customize FT parameters by adding ``fault_tolerance`` section: + If ``calculate_timeouts: True`` timeouts will be automatically estimated based on observed intervals. + Estimated timeouts take precedence over timeouts defined in the config file. **Timeouts are estimated + at the end of a training run, when checkpoint loading and saving were observed**. Hence, in a multi-part + training started from scratch, estimated timeouts won't be available during initial two runs. + Estimated timeouts are stored in a separate JSON file. -.. code-block:: yaml + ``max_subsequent_job_failures`` allows for the automatic continuation of training on a SLURM cluster. + This feature requires SLURM job to be scheduled with ``NeMo-Framework-Launcher``. If ``max_subsequent_job_failures`` + value is `>0` continuation job is prescheduled. It will continue the work until ``max_subsequent_job_failures`` + subsequent jobs failed (SLURM job exit code is `!= 0`) or the training is completed successfully + ("end of training" marker file is produced by the ``FaultToleranceCallback``, i.e. due to iters or time limit reached). - exp_manager: - ... - create_fault_tolerance_callback: True - fault_tolerance: - initial_rank_heartbeat_timeout: 600 # wait for 10 minutes for the initial heartbeat - rank_heartbeat_timeout: 300 # wait for 5 minutes for subsequent heartbeats - calculate_timeouts: True # estimate more accurate timeouts based on observed intervals - -Timeouts for fault detection need to be adjusted for a given workload: - * ``initial_rank_heartbeat_timeout`` should be long enough to allow for workload initialization. - * ``rank_heartbeat_timeout`` should be at least as long as the longest possible interval between steps. - -**Importantly, `heartbeats` are not sent during checkpoint loading and saving**, so time for -checkpointing related operations should be taken into account. - -If ``calculate_timeouts: True`` timeouts will be automatically estimated based on observed intervals. -Estimated timeouts take precedence over timeouts defined in the config file. **Timeouts are estimated after -checkpoint loading and saving was observed**. For example, in multi-part training started from scratch, -estimated timeouts won't be available during the first run. Estimated timeouts are stored in the checkpoint. - -``max_subsequent_job_failures`` allows for the automatic continuation of training on a SLURM cluster. -This feature requires SLURM job to be scheduled with ``NeMo-Framework-Launcher``. If ``max_subsequent_job_failures`` -value is `>0` continuation job is prescheduled. 
It will continue the work until ``max_subsequent_job_failures`` -subsequent jobs failed (SLURM job exit code is `!= 0`) or the training is completed successfully -("end of training" marker file is produced by the ``FaultToleranceCallback``, i.e. due to iters or time limit reached). - -All FT configuration items summary: - * ``workload_check_interval`` (float, default=5.0) Periodic workload check interval [seconds] in the workload monitor. - * ``initial_rank_heartbeat_timeout`` (Optional[float], default=60.0 * 60.0) Timeout for the first heartbeat from a rank. - * ``rank_heartbeat_timeout`` (Optional[float], default=45.0 * 60.0) Timeout for subsequent heartbeats from a rank. - * ``calculate_timeouts`` (bool, default=True) Try to calculate ``rank_heartbeat_timeout`` and ``initial_rank_heartbeat_timeout`` - based on the observed heartbeat intervals. - * ``rank_termination_signal`` (signal.Signals, default=signal.SIGKILL) Signal used to terminate the rank when failure is detected. - * ``log_level`` (str, default='INFO') Log level for the FT client and server(rank monitor). - * ``max_rank_restarts`` (int, default=0) Used by FT launcher. Max number of restarts for a rank. - If ``>0`` ranks will be restarted on existing nodes in case of a failure. - * ``max_subsequent_job_failures`` (int, default=0) Used by FT launcher. How many subsequent job failures are allowed until stopping autoresuming. - ``0`` means do not autoresume. - * ``additional_ft_launcher_args`` (str, default='') Additional FT launcher params (for advanced use). + All FT configuration items summary: + * ``workload_check_interval`` (float, default=5.0) Periodic workload check interval [seconds] in the workload monitor. + * ``initial_rank_heartbeat_timeout`` (Optional[float], default=60.0 * 60.0) Timeout [seconds] for the first heartbeat from a rank. + * ``rank_heartbeat_timeout`` (Optional[float], default=45.0 * 60.0) Timeout [seconds] for subsequent heartbeats from a rank. + * ``calculate_timeouts`` (bool, default=True) Try to calculate ``rank_heartbeat_timeout`` and ``initial_rank_heartbeat_timeout`` + based on the observed heartbeat intervals. + * ``safety_factor``: (float, default=5.0) When calculating the timeouts, multiply the maximum observed heartbeat interval + by this factor to obtain the timeout estimate. Can be made smaller for stable environments and larger for unstable ones. + * ``rank_termination_signal`` (signal.Signals, default=signal.SIGKILL) Signal used to terminate the rank when failure is detected. + * ``log_level`` (str, default='INFO') Log level for the FT client and server(rank monitor). + * ``max_rank_restarts`` (int, default=0) Used by FT launcher. Max number of restarts for a rank. + If ``>0`` ranks will be restarted on existing nodes in case of a failure. + * ``max_subsequent_job_failures`` (int, default=0) Used by FT launcher. How many subsequent job failures are allowed until stopping autoresuming. + ``0`` means do not autoresume. + * ``additional_ft_launcher_args`` (str, default='') Additional FT launcher params (for advanced use). .. _nemo_multirun-label: @@ -532,4 +494,4 @@ ExpManagerConfig :show-inheritance: :members: :member-order: bysource - :no-index: + :noindex: diff --git a/docs/source/core/neural_types.rst b/docs/source/core/neural_types.rst index ec7d94336c05..989cc8d998f4 100644 --- a/docs/source/core/neural_types.rst +++ b/docs/source/core/neural_types.rst @@ -24,7 +24,7 @@ Types are implemented in ``nemo.core.neural_types.NeuralType`` class. 
When you i are expected to include both *axes* information and *element type* information. .. autoclass:: nemo.core.neural_types.NeuralType - :no-index: + :noindex: Type Comparison Results ----------------------- @@ -32,7 +32,7 @@ Type Comparison Results When comparing two neural types, the following comparison results are generated. .. autoclass:: nemo.core.neural_types.NeuralTypeComparisonResult - :no-index: + :noindex: Examples -------- @@ -115,7 +115,7 @@ Custom element types It is possible to create user-defined element types to express the semantics of elements in your tensors. To do so, the user will need to inherit and implement abstract methods of the ``nemo.core.neural_types.elements.ElementType`` class .. autoclass:: nemo.core.neural_types.elements.ElementType - :no-index: + :noindex: Note that element types can be parametrized. Consider this example where it distinguishes between audio sampled at 8Khz and 16Khz. diff --git a/docs/source/features/mixed_precision.rst b/docs/source/features/mixed_precision.rst index ba0dfb4e945b..7e1e8c2f05fc 100644 --- a/docs/source/features/mixed_precision.rst +++ b/docs/source/features/mixed_precision.rst @@ -3,11 +3,21 @@ Mixed Precision Training ------------------------ -Mixed precision training significantly enhances computational efficiency by conducting operations in half-precision and fp8 formats, while selectively maintaining minimal data in single-precision to preserve critical information throughout key areas of the network. NeMo now supports FP16, BF16, and FP8 (via Transformer Engine) across most models. Further details will be provided shortly. +Mixed precision training significantly enhances computational efficiency by conducting operations in low-precision format, while selectively maintaining minimal data in single-precision to preserve critical information throughout key areas of the network. NeMo now supports FP16, BF16, and FP8 (via Transformer Engine) across most models. Further details will be provided shortly. -FP8 usage -========= +Half-precision Training +======================= + +NeMo supports half-precision (FP16 and BF16) computation training via Megatron Core and the distributed optimizer. +This training recipe uses half-precision in all layer computation keeping the model states (optimizer states and master parameters) in single-precision. +To avoid repeated data type casting at each layer computation, Megatron Core keeps a separate copy of half-precision parameters that is updated after each optimizer.step. + +Half-precision training is enabled when setting ``precision`` to either of ``fp16-mixed`` or ``bf16-mixed`` along with ``megatron_amp_O2=true``. +The parameter gradients are computed in the same half-precision, and the precision of gradient reduce-scatter across data-parallel GPUs can be set by ``optim.grad_sync_dtype``. + +FP8 Training +============ Overview ^^^^^^^^ diff --git a/docs/source/features/moe.rst b/docs/source/features/moe.rst new file mode 100644 index 000000000000..4c935f9f16a7 --- /dev/null +++ b/docs/source/features/moe.rst @@ -0,0 +1,75 @@ +Mixture of Experts +================== + +Overview +-------- + +NeMo Framework supports Mixture of Experts (MoE) in the feedforward block of the transformer layer. + +MoE is a machine learning technique where multiple specialized models (experts, +usually multi-layer perceptrons) are combined to solve a complex task. 
Each expert +focuses on a specific subtask or domain, while a gating network dynamically activates +the most appropriate expert based on the current input. + + +Use MoE +------- + +To use MoE in the NeMo Framework, adjust the ``num_moe_experts`` parameter in the model configuration: + +1. Set ``num_moe_experts`` to `8` to leverage 8 experts in the MoE module. + + .. code-block:: yaml + + num_moe_experts: 8 # Set MoE to use 8 experts + +2. Set ``moe_router_topk`` to the number of experts you want activated. For example, if you want to process each input with two experts: + + .. code-block:: yaml + + moe_router_topk: 2 # Processes each token using 2 experts. + +Configure MoE-specific Loss Functions +------------------------------------- + +In addition, NeMo provides options to configure MoE-specific loss function. +To balance token distribution across experts: + +1. Set ``moe_router_load_balancing_type`` to specify the load balancing method: + + .. code-block:: yaml + + moe_router_load_balancing_type: aux_loss # to use the auxilary loss, other options include "sinkhorn". + +2. Set ``moe_aux_loss_coeff`` to specify the weight of the auxilary loss. The auxiliary loss is added to encourage distributing tokens equally among all experts. Values in the 1e-2 range are a good start, as follows: + + .. code-block:: yaml + + moe_aux_loss_coeff: 1e-2 # set the aux-loss weight to 1e-2 + +3. Set ``moe_z_loss_coeff`` to specify the weight of the z-loss. A starting value of 1e-3 is recommended, as follows: + + .. code-block:: yaml + + moe_z_loss_coeff: 1e-3 + +Other options include: + +1. ``moe_input_jitter_eps`` adds noise to the input tensor by applying jitter with a specified epsilon value. + +2. ``moe_token_dropping`` enables selectively dropping and padding tokens for each expert to achieve + a specified capacity, similar to GShard, Switch-Transformer, and DeepSpeed-MoE. Briefly, if the number + of tokens routed to an expert exceeds its capacity, then the exceeding tokens are dropped. Note that this is + currently unsupported so should remain False. + +3. ``moe_token_dispatcher_type`` specifies the token dispatcher type, options include 'allgather' and 'alltoall'. + +4. ``moe_per_layer_logging`` enables per-layer logging for MoE, currently support aux-loss and z-loss. + +5. ``moe_expert_capacity_factor`` the capacity factor determines the maximum number of tokens that can be routed to each expert in any MoE layer. None means no token will be dropped. The default is None. + +6. ``moe_pad_expert_input_to_capacity`` if True, pads the input for each expert to match the expert capacity length. It is effective only after the moe_expert_capacity_factor is set. The default setting is False. + +7. ``moe_token_drop_policy`` the policy to drop tokens. Can be either "probs" or "position". If "probs", the tokens with the lowest probabilities will be dropped. If "position", tokens at the end of each batch will be dropped. The default value is "probs". + +8. ``moe_layer_recompute`` if True, checkpointing moe_layer to save activation memory. The default is False. 
diff --git a/docs/source/features/optimizations/activation_recomputation.rst b/docs/source/features/optimizations/activation_recomputation.rst
new file mode 100644
index 000000000000..67de4401a4bc
--- /dev/null
+++ b/docs/source/features/optimizations/activation_recomputation.rst
@@ -0,0 +1,52 @@
+Activation Recomputation
+========================
+
+The input activations of network layers are stored in device memory to compute the gradients in back-propagation.
+These stored activations can easily saturate device memory when training an LLM with a large sequence length or a large micro-batch size.
+Checkpointing a few activations and recomputing the rest is a common technique to reduce device memory requirements.
+
+Transformer Layer Recomputation
+-------------------------------
+
+NeMo supports Transformer layer recomputation, which checkpoints the input of each Transformer layer and recomputes the remaining activations.
+Transformer layer recomputation significantly reduces the activation memory usage.
+However, this approach increases the per-layer computation cost by about 30%, which comes from re-executing the entire forward computation of the layer.
+NeMo also supports partial Transformer layer recomputation, which is beneficial when recomputing only a few Transformer layers is enough to fit the training workload in GPU memory.
+This avoids recomputing the remaining layers.
+
+Transformer layer recomputation is enabled by setting ``activations_checkpoint_granularity=full``.
+The number of Transformer layers to recompute can be set using ``activations_checkpoint_num_layers`` along with ``activations_checkpoint_method=block``.
+If you set ``activations_checkpoint_num_layers`` to the total number of layers, the inputs of all Transformer layers are checkpointed and recomputed.
+When training with pipeline parallelism, ``activations_checkpoint_num_layers`` indicates the number of layers per pipeline stage.
+When virtual pipelining is used, it indicates the number of layers per virtual pipeline stage.
+
+NeMo also supports checkpointing the input to a block of multiple consecutive Transformer layers, meaning that a block of Transformer layers becomes the recomputation granularity.
+This can further save activation memory at the cost of increasing the recomputation buffer memory.
+Thus, it is only beneficial for memory savings when the model has many Transformer layers or when the intermediate layers of a Transformer layer hold relatively small activations.
+This recomputation mode can be enabled by setting ``activations_checkpoint_method=uniform``, and the number of Transformer layers per recomputation block is set using ``activations_checkpoint_num_layers``.
+
+Self-attention Recomputation
+----------------------------
+
+NeMo supports self-attention recomputation, which checkpoints the inputs of each self-attention block and recomputes the intermediate activations.
+This is a cost-efficient recomputation method; it achieves high memory savings at low recomputation cost.
+The intermediate layers of the self-attention block account for the majority of the activation memory.
+This is because the input sizes of the softmax, dropout, and QKV dot-product attention layers scale quadratically with the sequence length.
+However, their recomputation cost is relatively small compared to the linear projection layers, whose cost scales with the square of the hidden size.
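+
+For illustration, a minimal sketch of the Transformer layer recomputation settings described above is shown below (the enclosing ``model`` section and the layer count are illustrative assumptions about the config layout):
+
+.. code-block:: yaml
+
+   model:
+     activations_checkpoint_granularity: full   # full Transformer layer recomputation
+     activations_checkpoint_method: block       # or 'uniform'
+     activations_checkpoint_num_layers: 4       # layers per (virtual) pipeline stage to recompute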
+ +Self-attention recomputation is hard-enabled when using FlashAttention, which is supported in Transformer Engine. +Also, a user can use the self-attention recomputation without FlashAttention by setting ``activations_checkpoint_granularity=selective``. + +Scheme of full and selective checkpointing granularity: + +.. image:: https://github.com/NVIDIA/NeMo/releases/download/v2.0.0rc0/asset-post-activation-recomputation-exampe-2.jpg + :align: center + :alt: activation-recomputation-example-2 + :scale: 50% + +Scheme of uniform and block checkpointing method (full checkpointing granularity): + +.. image:: https://github.com/NVIDIA/NeMo/releases/download/v2.0.0rc0/asset-post-activation-recomputation-exampe-1.jpg + :align: center + :alt: activation-recomputation-example-1 + :scale: 50% \ No newline at end of file diff --git a/docs/source/features/memory_optimizations.rst b/docs/source/features/optimizations/attention_optimizations.rst similarity index 79% rename from docs/source/features/memory_optimizations.rst rename to docs/source/features/optimizations/attention_optimizations.rst index 4d363670fedf..d5ffe3c6fae8 100644 --- a/docs/source/features/memory_optimizations.rst +++ b/docs/source/features/optimizations/attention_optimizations.rst @@ -1,9 +1,5 @@ -Memory Optimizations -==================== - -Parallelism ------------ -Refer to :doc:`Parallelism <./parallelisms>`. +Attention Optimizations +======================= Flash Attention --------------- @@ -32,26 +28,6 @@ To disable Tri Dao flash attention, set the environment variable ``NVTE_FLASH_AT For more details on the Dot Product Attention backends supported in Transformer Engine, please refer to the source code at `Transformer Engine's Attention Mechanism `_. -Activation Recomputation ------------------------- - -Overview -^^^^^^^^ - -Full Activation Recomputation -""""""""""""""""""""""""""""" -The full activation recomputation method recalculates all the intermediate activations during the backward pass of a model's training, instead of storing them during the forward pass. This technique maximizes memory efficiency at the cost of computational overhead, as each activation is recomputed when needed. - -Partial Activation Recomputation -"""""""""""""""""""""""""""""""" -The partial activation recomputation method recomputes only a subset of layers during the backward phase. It is a trade-off between the full recomputation and no recomputation, balancing memory savings with computational efficiency. - -Selective Activation Recomputation -"""""""""""""""""""""""""""""""""" -The selective activation recomputation method reduces memory footprint of activations significantly via smart activation checkpointing. This approach involves selectively storing only crucial activations and recomputing the others as needed. It is particularly useful in large models to minimize memory usage while controlling the computational cost. - -Refer to "Reducing Activation Recomputation in Large Transformer Models" for more details: https://arxiv.org/abs/2205.05198. - Multi-query Attention (MQA) and Grouped-query Attention (GQA) ------------------------------------------------------------- @@ -104,4 +80,4 @@ Implement MQA or GQA NeMo's support for GQA and MQA is enabled through the integration of Megatron Core's Attention mechanism. The underlying implementation details can be explored within the Attention class of Megatron Core, which provides the functional backbone for these advanced attention methods. 
To understand the specific modifications and implementations of MQA and GQA, refer to the source code in the Attention class: -Check implementation details from Attention Class in Megatron Core Repo: https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/attention.py#L49 +Check implementation details from Attention Class in Megatron Core Repo: https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/attention.py#L49. diff --git a/docs/source/features/optimizations/communication_overlap.rst b/docs/source/features/optimizations/communication_overlap.rst new file mode 100644 index 000000000000..0ff93fe80604 --- /dev/null +++ b/docs/source/features/optimizations/communication_overlap.rst @@ -0,0 +1,64 @@ +Communication Overlap +===================== + +Data-parallel Communication Overlap +----------------------------------- + +NeMo supports the overlap of the data-parallel (DP) communications with the computations in LLM training. +NeMo features Distributed Optimizer that distributes optimizer states and the high-precision master parameters across GPUs. This introduces two types of data-parallel communications: reduce-scatter of gradients and all-gather of updated parameters. +The DP communication is chunked by the granularity of a Transformer layer and overlaps each communication chunk with computation. +This overlap method exposes only one DP communication chunk ensuring efficient large-scale LLM training. +When training with pipeline-parallelism, the granularity of DP communication becomes the Transformer layers per virtual pipeline stage. + +DP gradient reduce-scatter and parameter all-gather overlaps are enabled when setting ``overlap_grad_sync=true`` and ``overlap_param_sync=true``, respectively. +The precision of the gradient reduce-scatter is set by ``grad_sync_dtype`` and reduction in bf16 ensures improved performance at large scale training compared to the default precision of fp32. +When training in fp8 computing precision (with ``fp8=true``), setting ``fp8_params=true`` conducts the parameter all-gather in fp8, reducing the all-gather overhead by half. + +Tensor-parallel Communication Overlap +------------------------------------- + +Tensor parallelism, used with the sequence-parallel activation sharding (``sequence_parallel=true``), introduces activation (gradient) all-gather and reduce-scatter as shown in the below figure. +NeMo provides various options to overlap the tensor-parallel (TP) communications with computation. +The TP communication without direct computation dependency are overlapped with the computation in bulk (the linear layer and TP communication pairs in the yellow boxes). +The bulk TP communication is enabled by default. +The other TP communications with direct computation dependency are overlapped in pipelined fashion (the linear layer and TP communication pairs in the red boxes). +The TP communication and computation are chunked and the chunks are overlapped in pipeline. +In the pipelined overlap, the activation (gradient) tensor all-gather is replaced with multiple steps of input P2P ring exchanges, and reduce-scatter is replaced with multiple steps of GEMM output P2P ring exchanges followed by a reduction of the received outputs. +In case of the reduce-scatter overlap, NeMo also provides the option to pipeline-overlap using chunks of reduce-scatter, which exposes one reduce-scatter chunk. + + +.. 
image:: ../../nlp/nemo_megatron/images/tp_comm_overlap.png
+   :align: center
+   :width: 600px
+   :alt: Tensor-parallel communication overlap
+
+The pipelined TP communication overlap is implemented in Transformer Engine and is enabled by setting ``ub_tp_comm_overlap=true``.
+The specific overlap methods can be set via a config dictionary, which is set and passed as a YAML file.
+The individual bulk (wgrad and dgrad), pipelined all-gather, and pipelined reduce-scatter overlaps can be enabled and disabled with ``tp_comm_bulk_wgrad``, ``tp_comm_bulk_dgrad``, ``tp_comm_overlap_ag``, and ``tp_comm_overlap_rs``, respectively.
+
+Pipeline-parallel Communication Overlap
+---------------------------------------
+
+Pipelining introduces P2P activation (gradient) sends and receives between pipeline-parallel (PP) GPUs.
+The PP communication frequency increases when increasing the virtual-pipeline-parallel size because the number of Transformer layers executed per micro-batch decreases.
+This increases the PP communication overhead, which cancels out part of the pipeline-bubble reduction obtained from virtual pipelining.
+NeMo supports the overlap of the PP communications with non-dependent computations in the 1F1B stage (the body of pipelining, where 1X forward and 1X backward micro-batch executions are interleaved).
+The PP communications in pipeline fill and flush are still exposed.
+
+.. image:: ../../nlp/nemo_megatron/images/pp_comm_overlap.png
+   :align: center
+   :width: 600px
+   :alt: Pipeline-parallel communication overlap in 1F1B pipelining phase
+
+The PP communication overlap is enabled when setting ``overlap_p2p_comm=true``. Also, setting ``batch_p2p_comm=false`` uses separate kernels for the send and the receive, which further improves the communication efficiency and GPU resource utilization.
+NeMo supports PP communication overlap only with virtual pipelining, where PP communication becomes the performance bottleneck.
+Please refer to the `GPT3 training config file `_ that uses the PP communication overlap.
+
+Context-parallel Communication Overlap
+--------------------------------------
+
+Context parallelism partitions the activations (gradients) of all layers in the sequence domain. This introduces all-gather and reduce-scatter of activations (gradients) in self-attention forward- and back-propagation.
+NeMo hides the context-parallel (CP) communications under the self-attention computation.
+Like the TP communication overlaps, the CP communications are chunked and then pipeline-overlapped with the self-attention computation, where the all-gather and the reduce-scatter of activations (gradients) are replaced with P2P ring exchanges of data.
+
+The CP communication overlap is enabled by default when context parallelism is used (``context_parallel_size > 1``).
diff --git a/docs/source/features/optimizations/cpu_offloading.rst b/docs/source/features/optimizations/cpu_offloading.rst
new file mode 100644
index 000000000000..cf9d8951bf93
--- /dev/null
+++ b/docs/source/features/optimizations/cpu_offloading.rst
@@ -0,0 +1,19 @@
+CPU Offloading
+==============
+
+Overview
+--------
+
+CPU Offloading in NeMo is a feature that reduces the peak memory usage of the GPU by offloading activations and inactive weights to CPU storage. NeMo supports offloading at the transformer layer level, allowing users to specify the number of transformer layers in their language model that require CPU offloading. During the forward pass, NeMo offloads activations at the optimal time and reloads them as needed during the backward pass.
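+
+As a quick illustration, a minimal configuration sketch using the flags described in the Usage section below might look as follows (the enclosing ``model`` section and the layer count are illustrative assumptions):
+
+.. code-block:: yaml
+
+   model:
+     cpu_offloading: True
+     cpu_offloading_num_layers: 10      # between 0 and (total number of layers - 1)
+     cpu_offloading_activations: True   # offload activations
+     cpu_offloading_weights: True       # offload inactive weights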
+ +Features +-------- +- Supports training models with long sequence lengths by managing activation memory efficiently. +- Enables high batch sizes per GPU by offloading activation memory. +- Overlaps computation with data transfers (Host2Device and Device2Host) during offloading and reloading. + +Usage +----- +- Set cpu_offloading to True to enable CPU offloading. +- Set cpu_offloading_num_layers to a value between 0 and the total number of layers in the model minus one. +- Set cpu_offloading_activations and cpu_offloading_weights based on your needs to offload activations only, weights only, or both. diff --git a/docs/source/features/optimizations/index.rst b/docs/source/features/optimizations/index.rst new file mode 100644 index 000000000000..60f4428f9299 --- /dev/null +++ b/docs/source/features/optimizations/index.rst @@ -0,0 +1,12 @@ +Optimizations +============= + +.. toctree:: + :maxdepth: 1 + + ./attention_optimizations + ./sequence_packing + ./activation_recomputation + ./communication_overlap + ./cpu_offloading + diff --git a/docs/source/features/throughput_optimizations.rst b/docs/source/features/optimizations/sequence_packing.rst similarity index 96% rename from docs/source/features/throughput_optimizations.rst rename to docs/source/features/optimizations/sequence_packing.rst index dfd8b6cf9310..69e45f1e6a12 100644 --- a/docs/source/features/throughput_optimizations.rst +++ b/docs/source/features/optimizations/sequence_packing.rst @@ -1,5 +1,5 @@ -Throughput Optimizations -======================== +Sequence Packing +================ Sequence Packing for SFT/PEFT ----------------------------- @@ -140,11 +140,6 @@ please refer to the documentation below :doc:`../multimodal/mllm/sequence_packing` -Communication Overlap ---------------------- -NeMo leverages Megatron-Core's optimizations to enhance bandwidth utilization and effectively overlap computation with communication. Additional details will be provided soon. - - .. rubric:: Footnotes .. [#f1] Experiments were performed on Llama 7B with Dolly dataset. Actual performance improvement depends on dataset diff --git a/docs/source/features/parallelisms.rst b/docs/source/features/parallelisms.rst index 4cc493f40024..bf327fb18331 100644 --- a/docs/source/features/parallelisms.rst +++ b/docs/source/features/parallelisms.rst @@ -1,56 +1,48 @@ .. _parallelisms: Parallelisms ------------- +============ -NeMo Megatron supports five types of parallelism (which can be mixed together arbitrarily). +NeMo Megatron supports various data- and model-parallel deep learning workload deployment methods (which can be mixed together arbitrarily). Data Parallelism -^^^^^^^^^^^^^^^^ +---------------- -Data Parallelism (DP) creates identical copies of the model across -multiple GPUs. Data batches are distributed between GPUs so that the -GPUs can process them independently. While compute is efficiently -distributed between GPUs, communication is required in order to keep -the model copies consistent with each other. +Data Parallelism (DP) replicates the model across multiple GPUs. +Data batches are evenly distributed between GPUs and the data-parallel GPUs process them independently. +While the computation workload is efficiently distributed across GPUs, inter-GPU communication is required in order to keep the model replicas consistent between training steps. 
Distributed Data Parallelism
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
-Distributed Data Parallelism (DDP) keeps model copies consistent by
-synchronizing parameter gradients before each optimization step. More
-specifically, it sums gradients over all model copies using an
-all-reduce communication collective.
+Distributed Data Parallelism (DDP) keeps the model copies consistent by synchronizing parameter gradients across data-parallel GPUs before each parameter update.
+More specifically, it sums the gradients of all model copies using all-reduce communication collectives.
 
 .. image:: ../nlp/nemo_megatron/images/ddp.gif
     :align: center
     :width: 800px
     :alt: Distributed Data Parallel
 
-Distributed Optimizer (ZeRO-1)
-~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+Distributed Optimizer
+^^^^^^^^^^^^^^^^^^^^^
 
-The ZeRO-1 algorithm keeps model copies consistent by sharding the
-optimizer state between GPUs. During each optimization step, the
-parameter gradients are first summed and sharded (with a
-reduce-scatter collective), each GPU applies an optimization to its
-local shard of the parameters, and the updated parameter shards are
-broadcast to update all of the model copies (with an all-gather
-collective). This approach is attractive for large models since
-sharding the optimizer state can significantly reduce its memory
-footprint on individual GPUs. It also has, in theory, the same
-communication volume as DDP and its communication pattern has more
-opportunities for overlapping with compute.
+Distributed optimizer is a memory-optimized data-parallel deployment method.
+It shards the optimizer states and the high-precision master parameters across data-parallel GPUs instead of replicating them.
+At each optimizer step, each data-parallel GPU updates its own shard of the parameters.
+Since each GPU needs only its own gradient shard, the distributed optimizer performs a reduce-scatter of the parameter gradients instead of an all-reduce.
+Then, the updated parameter shards are all-gathered across data-parallel GPUs.
+This approach significantly reduces the memory requirements of large-scale LLM training.
+Also, when the precision of the gradient is higher than the parameter precision, the split execution of gradient reduce-scatter and parameter all-gather can reduce the total communication volume.
+This split execution of the collectives also increases the amount of computation that can be overlapped with the communication, which improves the overlap opportunity.
 
 Enable Data Parallelism
 ~~~~~~~~~~~~~~~~~~~~~~~
 
-DDP is the default parallelism scheme when NeMo is run on multiple
-GPUs. Enabling other parallelism schemes in the model configuration
-will decrease the size of the DP group, that is the number of
-identical model copies.
+In NeMo, DDP is the default parallel deployment method.
+This means that the total number of GPUs corresponds to the size of the DP group, and training an LLM with model parallelism decreases the size of the DP group.
 
-To enable the distributed optimizer, set
+Currently, NeMo supports optimizer distribution only for the Adam optimizer.
+To enable the distributed Adam optimizer, set
 ``model.optim.name=distributed_fused_adam`` in the model
 configuration. It can be configured with the following options:
 
@@ -80,10 +72,36 @@ The distributed optimizer in NeMo is built on top of `DistributedFusedAdam `_ from Apex.
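+
+For illustration, a minimal sketch of enabling the distributed Adam optimizer in the model config is shown below (the options other than ``name`` are optional and shown only as examples of the knobs discussed in this guide):
+
+.. code-block:: yaml
+
+   model:
+     optim:
+       name: distributed_fused_adam   # enable the distributed Adam optimizer
+       grad_sync_dtype: bf16          # optional: precision of the gradient reduce-scatter
+       overlap_grad_sync: true        # optional: overlap gradient communication with compute
+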
+Fully-Sharded Data Parallelism (FSDP)
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+NeMo supports Fully-Sharded Data Parallelism (FSDP), which shards the parameter gradients and the low-precision parameters used for computation, in addition to the model states that the distributed optimizer shards (optimizer states and high-precision master parameters).
+Since FSDP shards all of the model states, the model state memory decreases linearly with increasing DP size.
+FSDP can be preferable for LLM training with an unbalanced workload between pipeline stages (or Transformer layers) or with a large vocabulary size, where pipelining would cause large computation bubbles due to the workload imbalance.
+Also, FSDP removes the effort of searching for performance-optimal mappings with 3D parallelism (TP/PP/DP) because it has a single parallelization domain.
+
+NeMo uses `PyTorch's FSDP interface `_ to shard LLM model states, which flattens the parameters of each Transformer layer and partitions them across data-parallel GPUs.
+FSDP introduces two collectives across data-parallel GPUs: all-gather of the parameters for computation and reduce-scatter of the parameter gradients.
+The parameter all-gather occurs in both the forward- and back-propagation phases. The gradient reduce-scatter happens only in back-propagation.
+These FSDP communications are overlapped with Transformer layer computations.
+
+Setting ``fsdp=true`` enables FSDP.
+The mixed-precision recipe can be set with the ``precision`` knob, which determines both the computation and communication precisions.
+Also, one can use ``grad_reduce_dtype`` to override the gradient reduction precision specifically.
+
+
+Model Parallelism
+-----------------
+
+Model parallelism (MP) is a distributed model deployment method that partitions the model parameters across GPUs to reduce the per-GPU memory requirements.
+NeMo supports various model-parallel methods, which can be mixed to maximize LLM training performance.
+
 Tensor Parallelism
 ^^^^^^^^^^^^^^^^^^
 
-Tensor Parallelism (TP) is a method for distributing a model's computation across multiple GPUs by splitting tensors into non-overlapping pieces. This allows different parts of the tensor to be processed simultaneously on separate GPUs, enhancing performance and enabling the training of larger models.
+Tensor Parallelism (TP) is a model-parallel partitioning method that distributes the parameter tensor of an individual layer across GPUs.
+On top of reducing the model state memory usage, it also saves activation memory as the per-GPU tensor sizes shrink.
+However, the reduced per-GPU tensor size lowers the per-GPU kernel workload, which increases the relative CPU overhead.
 
 .. image:: ../nlp/nemo_megatron/images/tp.gif
     :align: center
@@ -112,6 +130,16 @@ NeMo integrates Tensor Parallelism through the implementation from Megatron Core
 
 For detailed API usage and additional configurations, consult the `Megatron Core Developer Guide `_.
 
+FSDP with Tensor Parallelism
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+NeMo supports FSDP along with tensor parallelism. This is done by restricting the model state sharding to the data-parallel domain.
+Using FSDP with tensor parallelism can be helpful when the model doesn't have sufficient parallelism to deploy on a large-scale training system with the data-parallel mapping alone; for example, running a model with a global batch size of 1024 on 2048 GPUs.
+Also, tensor parallelism makes FSDP more feasible by reducing the model state size and the activation size per GPU, thus lowering the FSDP communication overhead and the activation memory overhead.
+
+Using both FSDP and TP works by enabling FSDP (``fsdp=true``) and setting ``tensor_model_parallel_size > 1``.
+The user should unset the ``CUDA_DEVICE_MAX_CONNECTIONS`` environment variable, which sets the number of GPU kernel queues, to enable the overlap of the FSDP communication with computation kernels.
+
 Pipeline Parallelism
 ^^^^^^^^^^^^^^^^^^^^
@@ -156,6 +184,40 @@ The NeMo implementation of PP leverages functionalities from Megatron Core. For
 
 For more detailed API usage and configurations related to PP, visit the `Megatron Core Developer Guide `_.
 
+Expert Parallelism
+^^^^^^^^^^^^^^^^^^
+Expert Parallelism (EP) is a type of model parallelism that distributes the experts of an MoE across GPUs.
+Unlike other model-parallel techniques, EP is applied only to the expert layers and thus does not impact the parallel mapping of the remaining layers.
+
+.. image:: ../nlp/nemo_megatron/images/ep.png
+    :align: center
+    :width: 800px
+    :alt: Expert Parallelism
+
+Enable Expert Parallelism
+~~~~~~~~~~~~~~~~~~~~~~~~~
+
+To enable EP, set ``model.expert_model_parallel_size`` to the desired expert parallel size. For example, if the model has six experts (``model.num_moe_experts=6``), then setting ``model.expert_model_parallel_size=3`` results in each GPU processing two experts. The number of experts should be divisible by the expert parallel size.
+
+   .. code-block:: yaml
+
+    expert_model_parallel_size: 3 # Set EP to 3
+
+For further information on configuration, refer to the following documentation: `NeMo Megatron GPT Config `_.
+
+
+Implement Expert Parallelism
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+The NeMo implementation of Expert Parallelism uses functionality from Megatron Core. Please consult the `Megatron Core MoE layer `_ for more MoE implementation details.
+
+
+Activation Partitioning
+-----------------------
+
+In LLM training, a large memory space is needed to store the input activations of the network layers.
+NeMo provides effective activation distribution methods, which are critical when training an LLM with a large sequence length or a large per-GPU micro-batch size.
+
 Sequence Parallelism
 ^^^^^^^^^^^^^^^^^^^^
@@ -185,7 +247,8 @@ The NeMo implementation of Sequence Parallelism utilizes functionality from Mega
 Context Parallelism
 ^^^^^^^^^^^^^^^^^^^
 
-Context Parallelism (CP) is a method for parallelizing the processing of neural network activations across multiple GPUs, focusing on the sequence dimension of the input data. Unlike Sequence Parallelism (SP) that only partitions specific types of activations, CP divides all network activations along the sequence dimension.
+Context Parallelism (CP) is a method for parallelizing the processing of neural network activations across multiple GPUs, partitioning the input tensors in the sequence dimension.
+Unlike Sequence Parallelism (SP), which partitions the activations of specific layers, CP divides the activations of all layers.
 
 Enable Context Parallelism
 ~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -212,34 +275,7 @@ Visit our source code for more insights into the implementation:
 - `Transformer Engine attention modules `_
 
-Expert Parallelism
-^^^^^^^^^^^^^^^^^^
-Expert Parallelism (EP) is a type of model parallelism that distributes experts of an MoE across GPUs.
-
-.. image:: ../nlp/nemo_megatron/images/ep.png
-    :align: center
-    :width: 800px
-    :alt: Expert Parallelism
-
-Enable Expert Parallelism
-~~~~~~~~~~~~~~~~~~~~~~~~~
-
-To enable EP, set ``model.expert_model_parallel_size`` to the desired expert parallel size. 
For example, if the model has six experts (``model.num_moe_experts=6``), then setting ``model.expert_model_parallel_size=3`` results in each GPU processing two experts. The number of experts should be divisible by the expert parallel size. - - .. code-block:: yaml - - expert_model_parallel_size: 3 # Set EP to 3 - -For further information on configuration, refer to the following documentation: `NeMo Megatron GPT Config `_. - - -Implement Expert Parallelism -~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -The NeMo implementation of Expert Parallelism uses functionality from Megatron Core. Please consult the `Megatron Core MoE layer `_ for more MoE implementation details. - - -Parallelism nomenclature +Parallelism Nomenclature ^^^^^^^^^^^^^^^^^^^^^^^^ The following figure illustrates some terms that you may encounter in the NeMo Megatron codebase. diff --git a/docs/source/index.rst b/docs/source/index.rst index f10ae126267b..2f75014b86d1 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -41,6 +41,7 @@ For quick guides and tutorials, see the "Getting started" section below. :titlesonly: starthere/intro + starthere/fundamentals starthere/tutorials For more information, browse the developer docs for your area of interest in the contents section below or on the left sidebar. @@ -53,15 +54,15 @@ For more information, browse the developer docs for your area of interest in the features/mixed_precision features/parallelisms - features/memory_optimizations - features/throughput_optimizations + features/moe + features/optimizations/index .. toctree:: :maxdepth: 1 - :caption: Community Model Converters - :name: CheckpointConverters + :caption: Model Checkpoints + :name: Checkpoints - ckpt_converters/intro + checkpoints/intro .. toctree:: :maxdepth: 1 diff --git a/docs/source/multimodal/api.rst b/docs/source/multimodal/api.rst index 7a9fe2822d07..2ba9978b7640 100644 --- a/docs/source/multimodal/api.rst +++ b/docs/source/multimodal/api.rst @@ -8,7 +8,7 @@ Model Classes :show-inheritance: :no-members: :members: __init__, configure_optimizers - :no-index: + :noindex: .. autoclass:: nemo.collections.multimodal.models.text_to_image.stable_diffusion.ldm.ddpm.MegatronLatentDiffusion diff --git a/docs/source/multimodal/mllm/checkpoint.rst b/docs/source/multimodal/mllm/checkpoint.rst deleted file mode 100644 index d1fe7b651e66..000000000000 --- a/docs/source/multimodal/mllm/checkpoint.rst +++ /dev/null @@ -1,114 +0,0 @@ -Checkpoints -=========== - -In this section, we present four key functionalities of NVIDIA NeMo related to checkpoint management: - -1. **Checkpoint Loading**: Load local ``.nemo`` checkpoint files with the :code:`restore_from()` method. -2. **Partial Checkpoint Conversion**: Convert partially-trained ``.ckpt`` checkpoints to the ``.nemo`` format. -3. **Community Checkpoint Conversion**: Transition checkpoints from community sources, like HuggingFace, into the ``.nemo`` format. -4. **Model Parallelism Adjustment**: Modify model parallelism to efficiently train models that exceed the memory of a single GPU. NeMo employs both tensor (intra-layer) and pipeline (inter-layer) model parallelisms. Dive deeper with "Efficient Large-Scale Language Model Training on GPU Clusters Using Megatron-LM" (`link `_). This tool aids in adjusting model parallelism, accommodating users who need to deploy on larger GPU arrays due to memory constraints. 
- -Understanding Checkpoint Formats --------------------------------- - -A ``.nemo`` checkpoint is fundamentally a tar file that bundles the model configurations (given as a YAML file), model weights, and other pertinent artifacts like tokenizer models or vocabulary files. This consolidated design streamlines sharing, loading, tuning, evaluating, and inference. - -On the other hand, the ``.ckpt`` file is a product of PyTorch Lightning training. It stores model weights and optimizer states, and it's generally used for resuming training. - -Subsequent sections delve into each of the previously listed functionalities, emphasizing the loading of fully trained checkpoints for evaluation or additional fine-tuning. - - -Loading Local Checkpoints -------------------------- - -NeMo inherently saves any model's checkpoints in the ``.nemo`` format. To manually save a model at any stage: - -.. code-block:: python - - model.save_to(.nemo) - -To load a local ``.nemo`` checkpoint: - -.. code-block:: python - - import nemo.collections.multimodal as nemo_multimodal - model = nemo_multimodal.models..restore_from(restore_path="") - -Replace `` with the appropriate MM model class. - -Converting Local Checkpoints ----------------------------- - -The training script only auto-converts the final checkpoint into the ``.nemo`` format. To evaluate intermediate training checkpoints, conversion to ``.nemo`` might be needed. For this: - -.. code-block:: bash - - python -m torch.distributed.launch --nproc_per_node= * \ - examples/multimodal/convert_ckpt_to_nemo.py \ - --checkpoint_folder \ - --checkpoint_name \ - --nemo_file_path \ - --tensor_model_parallel_size \ - --pipeline_model_parallel_size - -Converting Community Checkpoints --------------------------------- - -NeVA Checkpoints -^^^^^^^^^^^^^^^^ - -Currently, the conversion mainly supports LLaVA checkpoints based on "llama-2 chat" checkpoints. As a reference, we'll consider the checkpoint `llava-llama-2-13b-chat-lightning-preview `_. - -After downloading this checkpoint and saving it at ``/path/to/llava-llama-2-13b-chat-lightning-preview``, undertake the following procedures: - -Modifying the Tokenizer -""""""""""""""""""""""" - -NeMo mandates adding specific tokens to the tokenizer model for peak performance. To modify an existing tokenizer located in ``/path/to/llava-llama-2-13b-chat-lightning-preview/tokenizer``, execute the following in the NeMo container: - -.. code-block:: bash - - cd /opt/sentencepiece/src/ - protoc --python_out=/opt/NeMo/scripts/tokenizers/ sentencepiece_model.proto - python /opt/NeMo/scripts/tokenizers/add_special_tokens_to_sentencepiece.py \ - --input_file /path/to/llava-llama-2-13b-chat-lightning-preview/tokenizer.model \ - --output_file /path/to/llava-llama-2-13b-chat-lightning-preview/tokenizer_neva.model \ - --is_userdefined \ - --tokens "" "" "" "" \ - "" "" "" "" - -Checkpoint Conversion -""""""""""""""""""""" - -For conversion: - -.. code-block:: bash - - python examples/multimodal/mllm/neva/convert_hf_llava_to_neva.py \ - --in-file /path/to/llava-llama-2-13b-chat-lightning-preview \ - --out-file /path/to/neva-llava-llama-2-13b-chat-lightning-preview.nemo \ - --tokenizer-model /path/to/llava-llama-2-13b-chat-lightning-preview/tokenizer_add_special.model - --conv-template llama_2 - - -Model Parallelism Adjustment ----------------------------- - -NeVA Checkpoints -^^^^^^^^^^^^^^^^ - -Adjust model parallelism with: - -.. 
code-block:: bash - - python examples/nlp/language_modeling/megatron_change_num_partitions.py \ - --model_file=/path/to/source.nemo \ - --target_file=/path/to/target.nemo \ - --tensor_model_parallel_size=??? \ - --target_tensor_model_parallel_size=??? \ - --pipeline_model_parallel_size=??? \ - --target_pipeline_model_parallel_size=??? \ - --model_class="nemo.collections.multimodal.models.multimodal_llm.neva.neva_model.MegatronNevaModel" \ - --precision=32 \ - --tokenizer_model_path=/path/to/tokenizer.model \ - --tp_conversion_only diff --git a/docs/source/multimodal/mllm/intro.rst b/docs/source/multimodal/mllm/intro.rst index 0e76a9737a0f..c67e47e34537 100644 --- a/docs/source/multimodal/mllm/intro.rst +++ b/docs/source/multimodal/mllm/intro.rst @@ -8,7 +8,21 @@ The endeavor to extend Language Models (LLMs) into multimodal domains by integra datasets configs - checkpoint neva video_neva sequence_packing + + +Speech-agumented Large Language Models (SpeechLLM) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The endeavor to extend Language Models (LLMs) with the ability to understand speech and audio inputs, detailed examples can be found in the `SpeechLLM example `_.. + +.. toctree:: + :maxdepth: 1 + + ../speech_llm/intro + ../speech_llm/datasets + ../speech_llm/configs + ../speech_llm/api + diff --git a/docs/source/multimodal/speech_llm/api.rst b/docs/source/multimodal/speech_llm/api.rst new file mode 100644 index 000000000000..c2415f29c720 --- /dev/null +++ b/docs/source/multimodal/speech_llm/api.rst @@ -0,0 +1,88 @@ +SpeechLLM API +============= + +Model Classes +------------- + +.. autoclass:: nemo.collections.nlp.models.language_modeling.megatron_base_model.MegatronBaseModel + :show-inheritance: + :no-members: + :members: __init__, configure_optimizers + :noindex: + + +.. autoclass:: nemo.collections.multimodal.speech_llm.models.modular_models.ModularAudioGPTModel + :show-inheritance: + :no-members: + :members: __init__, training_step, validation_step, setup, build_train_valid_test_datasets + + +.. autoclass:: nemo.collections.multimodal.speech_llm.models.modular_models.CrossAttendModularAudioGPTModel + :show-inheritance: + :no-members: + :members: __init__, training_step, validation_step, setup, build_train_valid_test_datasets + + +.. autoclass:: nemo.collections.multimodal.speech_llm.models.modular_t5_models.ModularizedAudioT5Model + :show-inheritance: + :no-members: + :members: __init__, training_step, validation_step, setup, build_train_valid_test_datasets + + +.. autoclass:: nemo.collections.multimodal.speech_llm.models.modular_t5_models.DecoderTextPromptModularizedAudioT5Model + :show-inheritance: + :no-members: + :members: __init__, training_step, validation_step, setup, build_train_valid_test_datasets + + + +Modules +------- + +.. autoclass:: nemo.collections.multimodal.speech_llm.modules.perception_modules.AudioPerceptionModule + :show-inheritance: + :no-members: + +.. autoclass:: nemo.collections.multimodal.speech_llm.modules.perception_modules.MultiAudioPerceptionModule + :show-inheritance: + :no-members: + +.. autoclass:: nemo.collections.multimodal.speech_llm.modules.TransformerCrossAttention + :show-inheritance: + :no-members: + + +Dataset Classes +--------------- +.. autoclass:: nemo.collections.multimodal.speech_llm.data.audio_text_dataset.AudioTextDataset + :show-inheritance: + :no-members: + +.. autoclass:: nemo.collections.multimodal.speech_llm.data.audio_text_dataset.TarredAudioTextDataset + :show-inheritance: + :no-members: + +.. 
autoclass:: nemo.collections.multimodal.speech_llm.data.audio_text_dataset.get_tarred_audio_text_dataset_from_config + :show-inheritance: + :no-members: + +.. autoclass:: nemo.collections.multimodal.speech_llm.data.audio_text_dataset.get_audio_text_dataset_from_config + :show-inheritance: + :no-members: + +.. autoclass:: nemo.collections.multimodal.speech_llm.data.lhotse_dataset.LhotseAudioQuestionAnswerDataset + :show-inheritance: + :no-members: + +.. autoclass:: nemo.collections.multimodal.speech_llm.data.build_dataset.build_speechllm_dataset + :show-inheritance: + :no-members: + +.. autoclass:: nemo.collections.multimodal.speech_llm.data.build_dataset.build_speechllm_dataloader + :show-inheritance: + :no-members: + + + + + diff --git a/docs/source/multimodal/speech_llm/configs.rst b/docs/source/multimodal/speech_llm/configs.rst new file mode 100644 index 000000000000..5edd169eed25 --- /dev/null +++ b/docs/source/multimodal/speech_llm/configs.rst @@ -0,0 +1,197 @@ +Common Configuration Files +========================== + +This section provides a detailed overview of the NeMo configuration file setup specific to models within the NeMo SpeechLLM collection. For foundational knowledge about setting up and executing experiments common to all NeMo models, such as the Experiment Manager and PyTorch Lightning trainer parameters, refer to the :doc:`core <../../core/core>` documentation. + +Within the configuration files of the NeMo SpeechLLMs, details concerning dataset(s), augmentation, optimization parameters, and model architectural specifications are central. This page explores each of these aspects. + +Discover exemplary configuration files for all SpeechLLMs in the `config directory of the examples `_. + + +Dataset Configuration +--------------------- + +The dataset configuration is based on the NeMo ASR data configuration and the NLP data configuration + +The configuration file allows setting any initialization parameter accepted by the Dataset class used in the experiment. For a comprehensive list of Datasets and their parameters, visit the `Datasets <./api.html#Datasets>`__ section of the API. + +A typical training configuration is as follows: + +.. code-block:: yaml + + train_ds: + manifest_filepath: ??? # Path to a list of JSONL files corresponding to the source data. + global_batch_size: 4 + micro_batch_size: 2 + shuffle: True + num_workers: 0 + pin_memory: True + max_seq_length: 2048 + min_seq_length: 1 + drop_last: True + concat_sampling_probabilities: null # When providing a list of datasets, this arg defines the sampling probabilities from each dataset when strategy='random' + context_key: 'context' + answer_key: 'answer' + add_eos: True + add_eos: False + end_string: null + add_sep: False + add_bos: False + separate_prompt_and_response_with_newline: False + truncation_field: "context" # Options: ['context', 'answer'] + prompt_template: "Q: {context}\nA: {answer}" # fstring to use for assistant prompt. 
Example: "Q: {input}\nA: {output}" + # ASR configs + sample_rate: 16000 #${model.audio_encoder.preprocessor.sample_rate} + max_duration: 24 # it is set for LibriSpeech, you may need to update it for your dataset + min_duration: 0.1 + # tarred datasets + is_tarred: false + tarred_audio_filepaths: null + shuffle_n: 2048 + # bucketing params + bucketing_strategy: "fully_randomized" + bucketing_batch_size: null + # multi-audio configs + audio_locator: null + + +Key parameters include: + +- ``manifest_filepath``: The path to the dataset in JSON lines format, where each line in the file is a python dictionary. This can either be a single file or a list of files. +- ``global_batch_size``: The global batch size that takes consideration of gradient accumulation, data parallelism. +- ``micro_batch_size``: The micro batch size that fits on each GPU. +- ``shuffle``: Whether to shuffle the dataset. +- ``num_workers``: The number of workers to use for data loading. +- ``pin_memory``: Whether to pin memory for faster data transfer. +- ``max_seq_length``: The maximum sequence length for LLM. +- ``min_seq_length``: The minimum sequence length for LLM. +- ``drop_last``: Whether to drop the last batch if it is smaller than the batch size. +- ``context_key``: The key in the JSON line that corresponds to the context used for LLM input. +- ``answer_key``: The key in the JSON line that corresponds to the answer used for groundtruth. +- ``add_eos``: Whether to add an end-of-sequence token. +- ``add_bos``: Whether to add a beginning-of-sequence token. +- ``add_sep``: Whether to add a separator token. +- ``end_string``: The string to used to trigger end of generation, default to null to use EOS token. +- ``separate_prompt_and_response_with_newline``: Whether to separate the prompt and response with a newline. +- ``truncation_field``: The field to truncate if the sequence length exceeds the maximum sequence length. +- ``prompt_template``: The fstring to use for the LLM prompt, where the context and answer will be formatted. +- ``sample_rate``: The sample rate of the audio data. +- ``max_duration``: The maximum duration of the audio data to be included. +- ``min_duration``: The minimum duration of the audio data to be included. +- ``is_tarred``: Whether the dataset is tarred. +- ``tarred_audio_filepaths``: The path to the tarred audio files. +- ``shuffle_n``: The number of samples to shuffle in tarred datasets, not used for non-tarred datasets. +- ``bucketing_strategy``: The strategy to use for bucketing, options include 'fully_randomized', 'synced_randomized'. +- ``bucketing_batch_size``: The batch size to use for each bucket, if not provided, the micro batch size is used. +- ``audio_locator``: The special string to locate the position of each audio to be put in the text prompt. + + +Trainer Configuration +--------------------- + +This section outlines arguments for the Pytorch Lightning Trainer Object. + +.. code-block:: yaml + + trainer: + devices: 1 # number of GPUs (0 for CPU), or list of the GPUs to use e.g. [0, 1] + num_nodes: 1 + max_epochs: -1 + max_steps: 2500000 # precedence over max_epochs + logger: False # Provided by exp_manager + precision: bf16 # Should be set to 16 for O1 and O2 to enable the AMP. + accelerator: gpu + log_every_n_steps: 5 # Interval of logging. + resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. 
+ num_sanity_val_steps: 10 # number of steps to perform validation steps for sanity check the validation process before starting the training, setting to 0 disables it + enable_checkpointing: False # Provided by exp_manager + accumulate_grad_batches: 1 # do not modify, grad acc is automatic for training megatron models + gradient_clip_val: 1.0 + benchmark: False + enable_model_summary: True + +For a detailed list of arguments, refer to the `Pytorch Lightning Trainer `__ API section. + +Experiment Manager Configurations +--------------------------------- + +The NeMo Experiment Manager provides a streamlined approach to manage various tasks such as logging, saving, and resuming. + +.. code-block:: yaml + + exp_manager: + exp_dir: null # exp_dir for your experiment, if None, defaults to "./nemo_experiments" + name: ${name} + create_wandb_logger: True + wandb_logger_kwargs: # Whether you want exp_manger to create a Wandb logger + name: training-session + project: text2img + group: nemo + resume: True + create_tensorboard_logger: True # Whether you want exp_manger to create a tb logger + create_checkpoint_callback: True # Whether you want exp_manager to create a modelcheckpoint callback + checkpoint_callback_params: + monitor: reduced_train_loss + save_top_k: 5 + every_n_epochs: 0 # Save checkpoint frequency. + every_n_train_steps: 1000 # Mutually exclusive with every_n_epochs. It is recommended to set this if training on large-scale dataset. + filename: '${name}--{reduced_train_loss:.2f}-{step}-{consumed_samples}' + resume_if_exists: True + resume_ignore_no_checkpoint: True + resume_from_checkpoint: ${model.resume_from_checkpoint} + ema: + enable: True + decay: 0.9999 + validate_original_weights: False + every_n_steps: 1 + cpu_offload: False + +Optimizer Configurations +------------------------- + +.. code-block:: yaml + + optim: + name: fused_adam + lr: 0.0001 + eps: 1e-8 + betas: [ 0.9, 0.999 ] + weight_decay: 0.01 + sched: + name: WarmupPolicy + warmup_steps: 10000 + warmup_ratio: null + +The default optimizer used is ``fused_adam``. For details on all supported optimizers, refer to the NeMo user guide. The learning rate scheduler can be specified in the ``optim.sched`` section. + +Model Configurations +-------------------- + +Each configuration file should detail the model architecture used for the experiment. 
+
+The parameters commonly shared across most multimodal language models include:
+
+.. list-table::
+   :widths: 30 15 55
+   :header-rows: 1
+
+   * - Parameter
+     - Datatype
+     - Description
+   * - :code:`micro_batch_size`
+     - int
+     - micro batch size that fits on each GPU
+   * - :code:`global_batch_size`
+     - int
+     - global batch size that takes into account gradient accumulation and data parallelism
+   * - :code:`tensor_model_parallel_size`
+     - int
+     - intra-layer model parallelism
+   * - :code:`pipeline_model_parallel_size`
+     - int
+     - inter-layer model parallelism
+   * - :code:`seed`
+     - int
+     - seed used in training
+
+SALM
+~~~~
+
+For model-specific configurations, refer to `the examples `_.
+
+
+BESTOW
+~~~~~~
+
+For model-specific configurations, refer to `the examples `_.
diff --git a/docs/source/multimodal/speech_llm/datasets.rst b/docs/source/multimodal/speech_llm/datasets.rst
new file mode 100644
index 000000000000..c251213eb3d6
--- /dev/null
+++ b/docs/source/multimodal/speech_llm/datasets.rst
@@ -0,0 +1,109 @@
+SpeechLLM Dataset
+=================
+
+The dataset classes can be found on `NeMo GitHub `_.
+
+
+Input Manifest Format
+---------------------
+
+You'll need to prepare data in the NeMo manifest format, where each line is a Python dictionary with some keys, for example:
+
+.. code-block:: yaml
+
+    {
+        "audio_filepath": "path/to/audio.wav",
+        "offset": 0.0,  # offset of the audio in seconds, this is an optional field
+        "duration": 10.0,  # duration of the audio in seconds, can set to `None` to load the whole audio
+        "context": "what is the transcription of the audio?",  # text prompt for the audio, see below for more details
+        "answer": "the transcription of the audio",  # optional for inference, default to "na" in dataloader
+    }
+
+
+The `context` field in the manifest is optional, and you can put a list of contexts in a context file (one context per line) and then set `++model.data.train_ds.context_file=` to ask the dataloader to randomly pick a context from the file for each audio sample. This is useful for training with multiple prompts for the same task. If neither the `context` field nor `context_file` is provided, the dataloader will use a default context `what does the audio mean?` for all audio samples. During inference, it is recommended to have the `context` field in the manifest.
+
+Customizing the fields to use
+-----------------------------
+
+You can also use other fields in the manifest to replace the `context` and `answer` fields, but you'll also need to change the `prompt_template` to use the new field names.
For example, if you desire to use the new fields `input_text` and `output_text`, you need to set: + +.. code-block:: bash + + ++model.data.train_ds.context_key=input_text \ + ++model.data.train_ds.answer_key=output_text \ + ++model.data.train_ds.prompt_template="'Q: {input_text}\nA: {output_text}'" + +Note that there're single quotes around the prompt template (to avoid hydra errors), and the field names are wrapped in curly braces. + + +Customizing the input format +---------------------------- + +If you would like to use multiple audios, you can set the `audio_filepath` to be a list of audio file paths, and specify the location of each audio by using a special `audio_locator` string in the context. The choice of `audio_locator` should also be passed into the config. For example, if you have a manifest item like this: + +.. code-block:: yaml + + { + "audio_filepath": ["path/to/audio1.wav", "path/to/audio2.wav"], + "context": "what is the transcription of the [audio] and [audio]?", # text prompt for the audio, see below for more details + "answer": "the transcription of the audio1 and audio2", # optional for inference, default to "na" in dataloader + } + + +You can set the `audio_locator` to be `[audio]` in the config: + +.. code-block:: bash + + ++model.data.train_ds.audio_locator='[audio]' + + +By using `audio_locator`, the dataloader will replace the `audio_locator` in the context with the corresponding audio features extracted for each audio. You need to make sure that the number of audio locators in the context matches the number of audio files in the `audio_filepath` field. + + + +Multi-task Training +------------------- + + +In order to use a context file, you can set `++model.data.train_ds.context_file=` in the command line or use multiple context files with `++model.data.train_ds.context_file=[,,...]`. If the number of context files is equal to the number of provided datasets, the dataloader will assigne each context file to a dataset. Otherwise, the dataloader will randomly pick a context file from all provided context files for each audio sample. Using multiple context files is useful for training with multiple tasks, where each task has its own set of prompts. Meanwhile, you can control the weights for different tasks/datasets by using concatentated tarred datasets, where you can assign weights to datasets by: + +.. code-block:: bash + + ++model.data.train_ds.is_tarred=True \ + ++model.data.train_ds.is_concat=True \ + ++model.data.train_ds.manifest_filepath=[/path/to/data1/tarred_audio_manifest.json,/path/to/data2/tarred_audio_manifest.json] \ + ++model.data.train_ds.tarred_audio_filepaths=[/path/to/data1/audio__OP_0..1023_CL_.tar,/path/to/data2/audio__OP_0..1023_CL_.tar] \ + ++model.data.train_ds.concat_sampling_technique='random' \ + ++model.data.train_ds.concat_sampling_probabilities=[0.4,0.6] \ + + + +Use Lhotse Dataloader +--------------------- + +Speech-LLM supports NeMo dataloader and Lhotse dataloader. Most of the Lhotse specific flags can be referred to `Lhotse Dataloader `. +Example config can be referred to `Lhotse Speech-LLM examples `_. + +Lhotse Dataloader also supports using a standalone YAML file to set up the manifest info: + +.. code-block:: bash + + ++model.data.train_ds.input_cfg=$INPUT_CFG_FILE \ + +which points to a $INPUT_CFG_FILE file like the following: + +.. 
code-block:: yaml + + - input_cfg: + - manifest_filepath: manifest1.json + type: nemo + weight: 2.0 + tags: + default_context: "please transcribe the audio" + - manifest_filepath: manifest2.json + type: nemo + weight: 1.0 + tags: + default_context: "please translate English audio to German" + type: group + weight: 0.4 diff --git a/docs/source/multimodal/speech_llm/intro.rst b/docs/source/multimodal/speech_llm/intro.rst new file mode 100644 index 000000000000..55ea13d7d411 --- /dev/null +++ b/docs/source/multimodal/speech_llm/intro.rst @@ -0,0 +1,41 @@ +Speech-agumented Large Language Models (SpeechLLM) +================================================== + +The endeavor to extend Language Models (LLMs) with the ability to understand speech and audio inputs, detailed examples can be found in the `SpeechLLM example `_.. + +.. toctree:: + :maxdepth: 1 + datasets + configs + api + + +In general, there're three main components of a modular SpeechLLM: +- An audio encoder that processes the input audio and produces a sequence of audio embeddings. +- A modality adapter that processes the audio embeddings and produces a sequence of embeddings in the same latent space as the token embeddings of a pretrained large language model (LLM). +- A pretrained large language model (LLM) that processes embeddings from the modality adapter as well as token embeddings of input prompt, and produces the text output. The audio embeddings and text token embeddings are concatenated in time dimension before going into the LLM. +- The LLM produces text outputs based on the concatenated input audio and text embedding. + + +Model Architecture +^^^^^^^^^^^^^^^^^^ + +One way to incorporate speech into LLM is to concatenate speech features with the token embeddings of the input text prompt before being fed into the LLM. In this way, the LLM can have direct access to the speech information when generating the output text. + .. image:: https://github.com/NVIDIA/NeMo/releases/download/v1.23.0/salm.png + :align: center + :alt: SALM model + :scale: 50% + + + +Another way is to use cross-attention mechanism, by using text embeddings to attend to speech embeddings to extract task-specific information from the speech embeddings. In order to minimize the computational cost of cross-attention, we add a cross-attention module only before the LLM. + + .. image:: https://github.com/NVIDIA/NeMo/releases/download/v1.23.0/bestow.png + :align: center + :alt: BESTOW model + :scale: 50% + + + + + diff --git a/docs/source/multimodal/vlm/checkpoint.rst b/docs/source/multimodal/vlm/checkpoint.rst index 996d9828f5aa..d984f1453510 100644 --- a/docs/source/multimodal/vlm/checkpoint.rst +++ b/docs/source/multimodal/vlm/checkpoint.rst @@ -35,58 +35,36 @@ To load a local ``.nemo`` checkpoint: Replace `` with the appropriate MM model class. -Converting Local Checkpoints ----------------------------- - -Only the last checkpoint is automatically saved in the ``.nemo`` format. If intermediate training checkpoints evaluation is required, a ``.nemo`` conversion might be necessary. For this, refer to the script at `script `_: - -.. code-block:: python - - python -m torch.distributed.launch --nproc_per_node= * \ - examples/multimodal/convert_ckpt_to_nemo.py \ - --checkpoint_folder \ - --checkpoint_name \ - --nemo_file_path \ - --tensor_model_parallel_size \ - --pipeline_model_parallel_size - Converting Community Checkpoints -------------------------------- CLIP Checkpoints ^^^^^^^^^^^^^^^^ -To migrate community checkpoints: -.. 
code-block:: python +To migrate community checkpoints, use the following command: + +.. code-block:: bash - python examples/multimodal/foundation/clip/convert_external_clip_to_nemo.py \ - --arch=ViT-H-14 \ - --version=laion2b_s32b_b79k \ - --hparams_file=path/to/saved.yaml \ - --nemo_file_path=open_clip.nemo + torchrun --nproc-per-node=1 /opt/NeMo/scripts/checkpoint_converters/convert_clip_hf_to_nemo.py \ + --input_name_or_path=openai/clip-vit-large-patch14 \ + --output_path=openai_clip.nemo \ + --hparams_file=/opt/NeMo/examples/multimodal/vision_language_foundation/clip/conf/megatron_clip_VIT-L-14.yaml Ensure the NeMo hparams file has the correct model architectural parameters, placed at `path/to/saved.yaml`. An example can be found in `examples/multimodal/foundation/clip/conf/megatron_clip_config.yaml`. -For OpenCLIP migrations, provide the architecture (`arch`) and version (`version`) according to the OpenCLIP `model list `_. For Hugging Face conversions, set the version to `huggingface` and the architecture (`arch`) to the specific Hugging Face model identifier, e.g., `yuvalkirstain/PickScore_v1`. +After conversion, you can verify the model with the following command: -Model Parallelism Adjustment ----------------------------- +.. code-block:: bash -CLIP Checkpoints -^^^^^^^^^^^^^^^^ + wget https://upload.wikimedia.org/wikipedia/commons/0/0f/1665_Girl_with_a_Pearl_Earring.jpg + torchrun --nproc-per-node=1 /opt/NeMo/examples/multimodal/vision_language_foundation/clip/megatron_clip_infer.py \ + model.restore_from_path=./openai_clip.nemo \ + image_path=./1665_Girl_with_a_Pearl_Earring.jpg \ + texts='["a dog", "a boy", "a girl"]' -To adjust model parallelism from original model parallelism size to a new model parallelism size (Note: NeMo CLIP currently only supports `pipeline_model_parallel_size=1`): +It should generate a high probability for the "a girl" tag. For example: -.. code-block:: python +.. code-block:: text - python examples/nlp/language_modeling/megatron_change_num_partitions.py \ - --model_file=/path/to/source.nemo \ - --target_file=/path/to/target.nemo \ - --tensor_model_parallel_size=??? \ - --target_tensor_model_parallel_size=??? \ - --pipeline_model_parallel_size=-1 \ - --target_pipeline_model_parallel_size=1 \ - --precision=32 \ - --model_class="nemo.collections.multimodal.models.clip.megatron_clip_models.MegatronCLIPModel" \ - --tp_conversion_only + Given image's CLIP text probability: [('a dog', 0.0049710185), ('a boy', 0.002258187), ('a girl', 0.99277073)] diff --git a/docs/source/nlp/nemo_megatron/images/pp_comm_overlap.png b/docs/source/nlp/nemo_megatron/images/pp_comm_overlap.png new file mode 100644 index 000000000000..efaaf8f7274f Binary files /dev/null and b/docs/source/nlp/nemo_megatron/images/pp_comm_overlap.png differ diff --git a/docs/source/nlp/nemo_megatron/images/tp_comm_overlap.png b/docs/source/nlp/nemo_megatron/images/tp_comm_overlap.png new file mode 100644 index 000000000000..4b44b20a343d Binary files /dev/null and b/docs/source/nlp/nemo_megatron/images/tp_comm_overlap.png differ diff --git a/docs/source/starthere/fundamentals.rst b/docs/source/starthere/fundamentals.rst new file mode 100644 index 000000000000..6413cb9d376a --- /dev/null +++ b/docs/source/starthere/fundamentals.rst @@ -0,0 +1,242 @@ +NeMo Fundamentals +================= + +On this page, we’ll look into how NeMo works, providing you with a solid foundation to effectively use NeMo for you :ref:`specific use case `. 
+ +NeMo Models +----------- + +NVIDIA NeMo is a powerful framework for building and deploying neural network models, including those used in generative AI, speech recognition, and natural language processing. NeMo stands for “Neural Modules,” which are the building blocks of the models created using this platform. NeMo includes all of the following components wrapped into a singular, cohesive unit: + +* neural network architecture + +* dataset and data loaders + +* preprocessing of input data and postprocessing of model outputs + +* loss function, optimizer, and schedulers + +* any other supporting infrastructure, such as tokenizers, language model configuration, and data augmentation + +NeMo models are built on PyTorch, with many of their components being subclasses of ``torch.nn.Module``. Additionally, NeMo models utilize PyTorch Lightning (PTL) for training, which helps reduce the boilerplate code required. + +NeMo models are also designed to be easily configurable; often this is done with YAML files. Below we show simplified examples of a NeMo model defined in pseudocode and a config defined in YAML. We highlight the lines where the Python config parameter is read from the YAML file. + +.. list-table:: Simplified examples of a model and config. + :widths: 1 1 + :header-rows: 0 + + * - .. code-block:: python + :caption: NeMo model definition (Python pseudocode) + :linenos: + :emphasize-lines: 4, 7, 10, 13, 16, 20 + + class ExampleEncDecModel: + # cfg is passed so it only contains "model" section + def __init__(self, cfg, trainer): + self.tokenizer = init_from_cfg(cfg.tokenizer) + + + self.encoder = init_from_cfg(cfg.encoder) + + + self.decoder = init_from_cfg(cfg.decoder) + + + self.loss = init_from_cfg(cfg.loss) + + + # optimizer configured via parent class + + + def setup_training_data(self, cfg): + self.train_dl = init_dl_from_cfg(cfg.train_ds) + + def forward(self, batch): + # forward pass defined, + # as is standard for PyTorch models + ... + + def training_step(self, batch): + log_probs = self.forward(batch) + loss = self.loss(log_probs, labels) + return loss + + + - .. code-block:: yaml + :caption: Experiment config (YAML) + :linenos: + :emphasize-lines: 4, 7, 10, 13, 16, 20 + + # + # configuration of the NeMo model + model: + tokenizer: + ... + + encoder: + ... + + decoder: + ... + + loss: + ... + + optim: + ... + + + train_ds: + ... + + # configuration of the + # PyTorch Lightning trainer object + trainer: + ... + + +Configuring and Training NeMo Models +------------------------------------ + +During initialization of the model, the "model" section of the config is passed into the model's constructor (as the variable ``cfg``, see line 3 of the left panel above). The model class will read key parameters from the ``cfg`` variable to configure the model (see highlighted lines in the left panel above). + +The other object passed into the model's constructor is a PyTorch Lightning ``trainer`` object, which manages the training process. The trainer handles the standard training `boilerplate `__. For non-standard tasks, PyTorch Lightning (PTL) relies on specific methods defined in our NeMo model. For example, PTL mandates that every model must have a specified ``training_step`` method (left panel above, line 27). + +The trainer’s configuration is also specified in the config (right panel above, line 25 onwards). This includes parameters such as ``accelerator``, (number of) ``devices``, ``max_steps``, (numerical) ``precision`` and `more `__. 
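+
+For instance, a minimal ``trainer`` section could look like the following sketch (the values are illustrative only, not recommendations):
+
+.. code-block:: yaml
+
+   trainer:
+     accelerator: gpu   # hardware backend to run on
+     devices: 2         # number of devices (e.g. GPUs) to use
+     max_steps: 1000    # stop training after this many optimization steps
+     precision: bf16    # numerical precision used during training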
+ + +Example Training Script +----------------------- + +Below is an example training script for our ``ExampleEncDecModel`` model. We highlight the three most important lines that combine everything we discussed in the previous section: + +.. code-block:: python + :caption: run_example_training.py + :linenos: + :emphasize-lines: 10, 11, 12 + + import pytorch_lightning as pl + from nemo.collections.path_to_model_class import ExampleEncDecModel + from nemo.core.config import hydra_runner + + @hydra_runner( + config_path="config_file_dir_path", + config_name="config_file_name" + ) + def main(cfg): + trainer = pl.Trainer(**cfg.trainer) + model = ExampleEncDecModel(cfg.model, trainer) + trainer.fit(model) + + if __name__ == '__main__': + main(cfg) + + +Let's go through the code: + +* *Lines 1-3*: import statements (second one is made up for the example). +* *Lines 5-8*: the decorator will look for a config file at ``{config_path}/{config_name}.yaml`` and load its contents into the ``cfg`` object that is passed into the ``main`` function on line 9. This functionality is provided by `Hydra `__. Instead of a YAML file, we could also have specified the default config as a dataclass and passed that into the ``@hydra_runner`` decorator. +* *Line 10*: initialize a PTL trainer object using the parameters specified in the ``trainer`` section of the config. +* *Line 11*: initialize a NeMo model, passing in both the parameters in the ``model`` section of the config, and a PTL ``trainer`` object. +* *Line 12*: call ``trainer.fit`` on the model. This one unassuming line will carry out our entire training process. PTL will make sure we iterate over our data and call the ``training_step`` we define for each batch (as well as any other PTL `callbacks `__ that may have been defined). + + + +Overriding Configs +------------------ + +The ``cfg`` object in the script above is a dictionary-like object that contains our configuration parameters. Specifically, it is an `OmegaConf `__ ``DictConfig`` object. These objects have special features such as dot-notation `access `__, `variable interpolation `__, and the ability to set `mandatory values `__. + +You can run the script above by running the following: + +.. code-block:: bash + + python run_example_training.py + +The script will use the default config file specified inside the ``@hydra_runner`` decorator. + +To specify a different config file, you can call the script like this: + +.. code-block:: diff + + python run_example_training.py \ + + --config_path="different_config_file_dir_path" \ + + --config_name="different_config_file_name" + +You can also override, delete, or add elements to the config by calling a script like this: + + +.. code-block:: diff + + python run_example_training.py \ + --config_path="different_config_file_dir_path" \ + --config_name="different_config_file_name" \ + + model.optim.lr=0.001 \ # overwriting + + model.train_ds.manifest_filepath="your_train_data.json" \ # overwriting + + ~trainer.max_epochs \ # deleting + + +trainer.max_steps=1000 # adding + +Running NeMo Scripts +-------------------- + +NeMo scripts typically take on the form shown above, where the Python script relies on a config object which has some specified default values that you can choose to override. + +The NeMo `examples `__ directory provides numerous scripts for training and inference of various existing NeMo models. 
It’s important to note that these scripts include default configurations for model, optimizer, and training parameters, which have been fine-tuned by the NeMo team over extensive GPU-hours of experimentation. As a result, we recommend using these default configurations as a starting point for your own experiments.
+
+
+NeMo Inference Scripts
+######################
+
+The examples directory also contains many inference scripts such as `transcribe_speech.py `_. These inference scripts typically differ in structure from training scripts, as they include additional utilities for file I/O (reading and saving files). While inference scripts still use configurations (configs), they don’t require the ``trainer`` and ``model`` sections. Additionally, the default configs for inference scripts are usually specified as dataclasses rather than separate files. You can also modify elements via the command line.
+
+Specifying Training Data
+------------------------
+
+NeMo will handle creation of data loaders for you, as long as you put your data into the expected input format. You may also need to train a tokenizer before starting training. To learn more about data formats, see :doc:`LLM <../nlp/nemo_megatron/gpt/gpt_training>`, :doc:`Multimodal <../multimodal/mllm/datasets>`, :ref:`Speech AI `, and :doc:`Vision models <../vision/datasets>`.
+
+
+Model Checkpoints
+-----------------
+
+Throughout training, the model :doc:`checkpoints <../checkpoints/intro>` will be saved inside ``.nemo`` files. These are archive files containing all the necessary components to restore a usable model. For example:
+
+* model weights (``.ckpt`` files)
+* model configuration (``.yaml`` files)
+* tokenizer files
+
+The NeMo team also releases pretrained models which you can browse on `NGC `_ and `HuggingFace Hub `_.
+
+
+Fine-Tuning
+-----------
+
+NeMo allows you to fine-tune models as well as train them from scratch.
+
+You can achieve this by initializing a model with random weights, then replacing some or all of those weights with the pretrained model’s weights. Afterward, continue training as usual, possibly making minor adjustments like reducing the learning rate or freezing specific model parameters.
+
+
+.. _where_next:
+
+Where To Go Next?
+-----------------
+
+Here are some options:
+
+* Explore Examples or Tutorials: dive into NeMo by exploring our `examples `_ or :doc:`tutorials <./tutorials>`
+
+* Domain-Specific Documentation:
+
+  * For Large Language Models (LLMs), check out the :doc:`LLM <../nlp/nemo_megatron/intro>` documentation.
+  * For Multimodal tasks, refer to the :doc:`Multimodal <../multimodal/mllm/intro>` documentation.
+  * If you’re interested in Automatic Speech Recognition (ASR), explore the :doc:`ASR <../asr/intro>` documentation.
+  * For Text-to-Speech (TTS), find details in the :doc:`TTS <../tts/intro>` documentation.
+  * Lastly, for Vision Models, consult the :doc:`Vision Models <../vision/intro>` documentation.
+
+* `NeMo Primer `__: This tutorial provides a hands-on introduction to NeMo, PyTorch Lightning, and OmegaConf. It covers how to use, modify, save, and restore NeMo models.
+
+* `NeMo Models `__: In this tutorial, you'll learn the fundamentals of creating NeMo models.
+
+* NeMo Core Documentation: Explore the :doc:`NeMo Core <../core/core>` documentation for NeMo, which explains the inner workings of the framework.
+ diff --git a/docs/source/starthere/intro.rst b/docs/source/starthere/intro.rst index 8edb435bec62..6060726d5ba8 100644 --- a/docs/source/starthere/intro.rst +++ b/docs/source/starthere/intro.rst @@ -16,7 +16,7 @@ NeMo is built on top of NVIDIA's powerful Megatron-LM and Transformer Engine for `NVIDIA NeMo Framework `_ features separate collections for Large Language Models (LLMs), Multimodal Models (MMs), Computer Vision (CV), Automatic Speech Recognition (ASR), and Text-to-Speech (TTS) models. Each collection comprises prebuilt modules that include everything needed to train on your data. These modules can be easily customized, extended, and composed to create new generative AI model architectures. -(TODO: Still valid? LLM is not included here.) `Pre-trained NeMo models `_ are available in 14+ languages. +Pre-trained NeMo models are available to download on `NGC `__ and `HuggingFace Hub `__. Prerequisites ------------- diff --git a/docs/source/tts/data/ngc_models_codec.csv b/docs/source/tts/data/ngc_models_codec.csv index 6827c54ce7f4..852d65127d45 100644 --- a/docs/source/tts/data/ngc_models_codec.csv +++ b/docs/source/tts/data/ngc_models_codec.csv @@ -2,3 +2,5 @@ Model Name,Dataset,Sampling Rate,Model Class,Overview,Checkpoint audio_codec_16khz_small,Libri-Light,16000Hz,nemo.collections.tts.models.AudioCodecModel,`audio_codec_16khz_small `_,``https://api.ngc.nvidia.com/v2/models/nvidia/nemo/audio_codec_16khz_small/versions/v1/files/audio_codec_16khz_small.nemo`` mel_codec_22khz_medium,LibriVox and Common Voice,22050Hz,nemo.collections.tts.models.AudioCodecModel,`mel_codec_22khz_medium `_,``https://api.ngc.nvidia.com/v2/models/nvidia/nemo/mel_codec_22khz_medium/versions/v1/files/mel_codec_22khz_medium.nemo`` mel_codec_44khz_medium,LibriVox and Common Voice,44100Hz,nemo.collections.tts.models.AudioCodecModel,`mel_codec_44khz_medium `_,``https://api.ngc.nvidia.com/v2/models/nvidia/nemo/mel_codec_44khz_medium/versions/v1/files/mel_codec_44khz_medium.nemo`` +mel_codec_22khz_fullband_medium,LibriVox and Common Voice,22050Hz,nemo.collections.tts.models.AudioCodecModel,`mel_codec_22khz_fullband_medium `_,``https://api.ngc.nvidia.com/v2/models/nvidia/nemo/mel_codec_22khz_fullband_medium/versions/v1/files/mel_codec_22khz_fullband_medium.nemo`` +mel_codec_44khz_fullband_medium,LibriVox and Common Voice,44100Hz,nemo.collections.tts.models.AudioCodecModel,`mel_codec_44khz_fullband_medium `_,``https://api.ngc.nvidia.com/v2/models/nvidia/nemo/mel_codec_44khz_fullband_medium/versions/v1/files/mel_codec_44khz_fullband_medium.nemo`` diff --git a/examples/asr/conf/asr_finetune/speech_to_text_finetune.yaml b/examples/asr/conf/asr_finetune/speech_to_text_finetune.yaml index 415172b33bb9..6808f4941916 100644 --- a/examples/asr/conf/asr_finetune/speech_to_text_finetune.yaml +++ b/examples/asr/conf/asr_finetune/speech_to_text_finetune.yaml @@ -6,10 +6,6 @@ init_from_nemo_model: null # path to nemo model model: sample_rate: 16000 - compute_eval_loss: false # eval samples can be very long and exhaust memory. Disable computation of transducer loss during validation/testing with this flag. - log_prediction: true # enables logging sample predictions in the output during training - rnnt_reduction: 'mean_volume' - skip_nan_grad: false train_ds: manifest_filepath: ??? 
diff --git a/examples/asr/conf/asr_finetune/speech_to_text_hf_finetune.yaml b/examples/asr/conf/asr_finetune/speech_to_text_hf_finetune.yaml index b8d84d197012..172d09ccd60b 100644 --- a/examples/asr/conf/asr_finetune/speech_to_text_hf_finetune.yaml +++ b/examples/asr/conf/asr_finetune/speech_to_text_hf_finetune.yaml @@ -7,10 +7,6 @@ init_from_pretrained_model: null # name of pretrained NeMo model, e.g., `stt_en model: sample_rate: 16000 - compute_eval_loss: false # eval samples can be very long and exhaust memory. Disable computation of transducer loss during validation/testing with this flag. - log_prediction: true # enables logging sample predictions in the output during training - rnnt_reduction: 'mean_volume' - skip_nan_grad: false # configs for huggingface load_dataset function data_path: "librispeech_asr" diff --git a/examples/asr/speech_to_text_finetune.py b/examples/asr/speech_to_text_finetune.py index dbdefef34682..ee043c0bd131 100644 --- a/examples/asr/speech_to_text_finetune.py +++ b/examples/asr/speech_to_text_finetune.py @@ -19,7 +19,11 @@ 1) `init_from_nemo_model` or 2) `init_from_pretrained_model` in the configuration. -To update the model architecture in conjunction with other modifications, it is advisable to use the primary 'speech_to_text_rnnt/ctc_*.py' script. +**************************************************************************************** +This script is mainly intended for changing the dataset, optim, spec_augment, vocabulary/tokenizer of the model. +To update the model architecture in conjunction with other modifications, +it is advisable to use the primary 'speech_to_text_rnnt/ctc_*.py' script. +**************************************************************************************** Note: To create a single script for all model types, we currently only support two types of initializations: @@ -135,7 +139,7 @@ def check_vocabulary(asr_model, cfg): def update_tokenizer(asr_model, tokenizer_dir, tokenizer_type): """ - Updates the tokenizer of the model and also reinitializes the decoder if the vocabulary size + Updates the tokenizer of the model and also reinitializes the decoder if the vocabulary size of the new tokenizer differs from that of the loaded model. 
Args: asr_model: ASRModel instance diff --git a/examples/audio/audio_to_audio_eval.py b/examples/audio/audio_to_audio_eval.py index 4e60b2ec2b52..c7b9db6efb80 100644 --- a/examples/audio/audio_to_audio_eval.py +++ b/examples/audio/audio_to_audio_eval.py @@ -75,7 +75,7 @@ from nemo.collections.audio.data import audio_to_audio_dataset from nemo.collections.audio.data.audio_to_audio_lhotse import LhotseAudioToTargetDataset -from nemo.collections.audio.metrics.audio import AudioMetricWrapper +from nemo.collections.audio.metrics import AudioMetricWrapper, SquimMOSMetric, SquimObjectiveMetric from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config from nemo.collections.common.parts.preprocessing import manifest from nemo.core.config import hydra_runner @@ -128,7 +128,17 @@ def get_evaluation_dataloader(config): def get_metrics(cfg: AudioEvaluationConfig): """Prepare a dictionary with metrics.""" - available_metrics = ['sdr', 'sisdr', 'stoi', 'estoi', 'pesq'] + available_metrics = [ + 'sdr', + 'sisdr', + 'stoi', + 'estoi', + 'pesq', + 'squim_mos', + 'squim_stoi', + 'squim_pesq', + 'squim_si_sdr', + ] metrics = dict() for name in sorted(set(cfg.metrics)): @@ -143,6 +153,14 @@ def get_metrics(cfg: AudioEvaluationConfig): metric = AudioMetricWrapper(metric=ShortTimeObjectiveIntelligibility(fs=cfg.sample_rate, extended=True)) elif name == 'pesq': metric = AudioMetricWrapper(metric=PerceptualEvaluationSpeechQuality(fs=cfg.sample_rate, mode='wb')) + elif name == 'squim_mos': + metric = AudioMetricWrapper(metric=SquimMOSMetric(fs=cfg.sample_rate)) + elif name == 'squim_stoi': + metric = AudioMetricWrapper(metric=SquimObjectiveMetric(metric='stoi', fs=cfg.sample_rate)) + elif name == 'squim_pesq': + metric = AudioMetricWrapper(metric=SquimObjectiveMetric(metric='pesq', fs=cfg.sample_rate)) + elif name == 'squim_si_sdr': + metric = AudioMetricWrapper(metric=SquimObjectiveMetric(metric='si_sdr', fs=cfg.sample_rate)) else: raise ValueError(f'Unexpected metric: {name}. 
Currently available metrics: {available_metrics}') diff --git a/examples/audio/audio_to_audio_train.py b/examples/audio/audio_to_audio_train.py index 2dc91036234f..b197d2084144 100644 --- a/examples/audio/audio_to_audio_train.py +++ b/examples/audio/audio_to_audio_train.py @@ -35,6 +35,7 @@ from nemo.collections.audio.models.enhancement import ( EncMaskDecAudioToAudioModel, PredictiveAudioToAudioModel, + SchroedingerBridgeAudioToAudioModel, ScoreBasedGenerativeAudioToAudioModel, ) from nemo.core.config import hydra_runner @@ -48,6 +49,7 @@ class ModelType(str, Enum): MaskBased = 'mask_based' Predictive = 'predictive' ScoreBased = 'score_based' + SchroedingerBridge = 'schroedinger_bridge' def get_model_class(model_type: ModelType): @@ -58,6 +60,8 @@ def get_model_class(model_type: ModelType): return PredictiveAudioToAudioModel elif model_type == ModelType.ScoreBased: return ScoreBasedGenerativeAudioToAudioModel + elif model_type == ModelType.SchroedingerBridge: + return SchroedingerBridgeAudioToAudioModel else: raise ValueError(f'Unknown model type: {model_type}') diff --git a/examples/audio/conf/schroedinger_bridge.yaml b/examples/audio/conf/schroedinger_bridge.yaml new file mode 100644 index 000000000000..8751b91afaee --- /dev/null +++ b/examples/audio/conf/schroedinger_bridge.yaml @@ -0,0 +1,164 @@ +name: schroedinger_bridge + +model: + type: schroedinger_bridge + sample_rate: 16000 + skip_nan_grad: false + num_outputs: 1 + normalize_input: true + max_utts_evaluation_metrics: 50 # metric calculation needs full inference and is slow, so we limit to first few files + + train_ds: + manifest_filepath: ??? + input_key: noisy_filepath + target_key: clean_filepath + audio_duration: 2.04 # 256 frames + random_offset: true + normalize_input: ${model.normalize_input} + batch_size: 8 # batch size may be increased based on the available memory + shuffle: true + num_workers: 8 + pin_memory: true + + validation_ds: + manifest_filepath: ??? 
+ input_key: noisy_filepath + target_key: clean_filepath + normalize_input: false # load data as is for validation, the model will normalize it for inference + batch_size: 4 + shuffle: false + num_workers: 4 + pin_memory: true + + encoder: + _target_: nemo.collections.audio.modules.transforms.AudioToSpectrogram + fft_length: 510 + hop_length: 128 + magnitude_power: 0.5 + scale: 0.33 + + decoder: + _target_: nemo.collections.audio.modules.transforms.SpectrogramToAudio + fft_length: ${model.encoder.fft_length} + hop_length: ${model.encoder.hop_length} + magnitude_power: ${model.encoder.magnitude_power} + scale: ${model.encoder.scale} + + estimator: + _target_: nemo.collections.audio.parts.submodules.ncsnpp.SpectrogramNoiseConditionalScoreNetworkPlusPlus + in_channels: 2 # concatenation of single-channel perturbed and noisy + out_channels: 1 # single-channel estimate + conditioned_on_time: true + num_res_blocks: 3 # increased number of res blocks + pad_time_to: 64 # pad to 64 frames for the time dimension + pad_dimension_to: 0 # no padding in the frequency dimension + + estimator_output: data_prediction + + noise_schedule: + _target_: nemo.collections.audio.parts.submodules.schroedinger_bridge.SBNoiseScheduleVE + k: 2.6 + c: 0.4 + time_min: 1e-4 + time_max: 1.0 + num_steps: 1000 # num steps for the forward process + + sampler: + _target_: nemo.collections.audio.parts.submodules.schroedinger_bridge.SBSampler + time_min: 1e-4 + time_max: 1.0 + num_steps: 50 # num steps for the reverse process + + # Loss in the encoded domain + loss_encoded: + _target_: nemo.collections.audio.losses.MSELoss + ndim: 4 # loss is calculated on the score in the encoded domain (batch, channel, dimension, time) + + # Loss in the time domain + loss_time: + _target_: nemo.collections.audio.losses.MAELoss + loss_time_weight: 0.001 + + metrics: + val: + sisdr: # output SI-SDR + _target_: torchmetrics.audio.ScaleInvariantSignalDistortionRatio + estoi: # output ESTOI + _target_: torchmetrics.audio.ShortTimeObjectiveIntelligibility + fs: ${model.sample_rate} + extended: true + pesq: # output PESQ + _target_: torchmetrics.audio.PerceptualEvaluationSpeechQuality + fs: ${model.sample_rate} + mode: wb + + optim: + name: adam + lr: 1e-4 + # optimizer arguments + betas: [0.9, 0.999] + weight_decay: 0.0 + +trainer: + devices: -1 # number of GPUs, -1 would use all available GPUs + num_nodes: 1 + max_epochs: -1 + max_steps: -1 # computed at runtime if not set + val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations + accelerator: auto + strategy: ddp + accumulate_grad_batches: 1 + gradient_clip_val: null + precision: 32 # Should be set to 16 for O1 and O2 to enable the AMP. + log_every_n_steps: 25 # Interval of logging. + enable_progress_bar: true + num_sanity_val_steps: 0 # number of steps to perform validation steps for sanity check the validation process before starting the training, setting to 0 disables it + check_val_every_n_epoch: 5 # number of evaluations on validation every n epochs + sync_batchnorm: true + enable_checkpointing: false # Provided by exp_manager + logger: false # Provided by exp_manager + +exp_manager: + exp_dir: null + name: ${name} + + # use exponential moving average for model parameters + ema: + enable: true + decay: 0.999 # decay rate + cpu_offload: false # offload EMA parameters to CPU to save GPU memory + every_n_steps: 1 # how often to update EMA weights + validate_original_weights: false # use original weights for validation calculation? 
+ + # logging + create_tensorboard_logger: true + + # checkpointing + create_checkpoint_callback: true + checkpoint_callback_params: + # in case of multiple validation sets, first one is used + monitor: val_pesq + mode: max + save_top_k: 5 + always_save_nemo: true # saves the checkpoints as nemo files instead of PTL checkpoints + + # early stopping + create_early_stopping_callback: true + early_stopping_callback_params: + monitor: val_sisdr + mode: max + min_delta: 0.0 + patience: 20 # patience in terms of check_val_every_n_epoch + verbose: true + strict: false # Should be False to avoid a runtime error where EarlyStopping says monitor is unavailable, which sometimes happens with resumed training. + + resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. + # you need to set these two to true to continue the training + resume_if_exists: false + resume_ignore_no_checkpoint: false + + # You may use this section to create a W&B logger + create_wandb_logger: false + wandb_logger_kwargs: + name: null + project: null diff --git a/examples/audio/process_audio.py b/examples/audio/process_audio.py index 6cf7a8499122..e28fb4e69627 100644 --- a/examples/audio/process_audio.py +++ b/examples/audio/process_audio.py @@ -16,7 +16,7 @@ import glob import json import os -from dataclasses import dataclass, is_dataclass +from dataclasses import dataclass, field, is_dataclass from pathlib import Path from typing import List, Optional @@ -96,6 +96,10 @@ class ProcessConfig: # Override model config override_config_path: Optional[str] = None # path to a yaml config that will override the internal config file + # Override sampler config + # For example, to set number of steps, use `++sampler.num_samples=42` + sampler: dict = field(default_factory=dict) + # Set `cuda` to int to define CUDA device. If 'None', will look for CUDA # device anyway, and do inference on CPU only if CUDA device is not found. # If `cuda` is a negative number, inference will be on CPU only. @@ -155,6 +159,22 @@ def main(cfg: ProcessConfig) -> ProcessConfig: audio_to_audio_model.set_trainer(trainer) audio_to_audio_model = audio_to_audio_model.eval() + # override sampler + if cfg.sampler is not None: + logging.info('Overriding sampler with %s', cfg.sampler) + + if hasattr(audio_to_audio_model, 'sampler'): + for key, value in cfg.sampler.items(): + if not hasattr(audio_to_audio_model.sampler, key): + raise RuntimeError(f'Model sampler does not have attribute {key}') + logging.debug('Try to set model.sampler.%s to %s', key, value) + setattr(audio_to_audio_model.sampler, key, value) + if getattr(audio_to_audio_model.sampler, key) != value: + raise RuntimeError(f'Failed to set model sampler attribute {key} to {value}') + logging.info('model.sampler.%s was set to %s', key, value) + else: + raise RuntimeError('Model does not have a sampler') + if cfg.audio_dir is not None: filepaths = list(glob.glob(os.path.join(cfg.audio_dir, f"**/*.{cfg.audio_type}"), recursive=True)) else: diff --git a/examples/llm/megatron_gpt_pretraining.py b/examples/llm/megatron_gpt_pretraining.py new file mode 100644 index 000000000000..d3d049e4296e --- /dev/null +++ b/examples/llm/megatron_gpt_pretraining.py @@ -0,0 +1,109 @@ +## NOTE: This script is present for github-actions testing only. +## There are no guarantees that this script is up-to-date with latest NeMo. 
+ +import argparse + +from megatron.core.optimizer import OptimizerConfig +from pytorch_lightning.loggers import TensorBoardLogger + +from nemo import lightning as nl +from nemo.collections import llm +from nemo.collections.llm.api import train +from nemo.collections.llm.gpt.data import PreTrainingDataModule +from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer +from nemo.lightning import NeMoLogger +from nemo.lightning.pytorch.callbacks import ModelCheckpoint +from nemo.lightning.pytorch.optim.megatron import MegatronOptimizerModule + + +def get_args(): + parser = argparse.ArgumentParser(description='Train a small GPT model using NeMo 2.0') + parser.add_argument('--devices', type=int, help="Number of devices to use for training") + parser.add_argument('--max-steps', type=int, help="Number of steps to train for") + parser.add_argument('--experiment-dir', type=str, help="directory to write results and checkpoints to") + parser.add_argument('--data-path', type=str, help="Path to data file") + parser.add_argument('--vocab-path', type=str, help="Path to vocab file") + parser.add_argument('--merges-path', type=str, help="Path to merges file") + parser.add_argument('--index-mapping-dir', type=str, help="directory to write index mappings to") + + return parser.parse_args() + + +if __name__ == '__main__': + + args = get_args() + + seq_length = 2048 + + tokenizer = get_nmt_tokenizer( + "megatron", + "GPT2BPETokenizer", + vocab_file=args.vocab_path, + merges_file=args.merges_path, + ) + data = PreTrainingDataModule( + paths=args.data_path, + seq_length=2048, + global_batch_size=32, + seed=1234, + tokenizer=tokenizer, + ) + gpt_config = llm.GPTConfig( + num_layers=12, + hidden_size=768, + ffn_hidden_size=3072, + num_attention_heads=12, + seq_length=seq_length, + init_method_std=0.023, + hidden_dropout=0.1, + attention_dropout=0.1, + layernorm_epsilon=1e-5, + make_vocab_size_divisible_by=128, + ) + model = llm.GPTModel(gpt_config, tokenizer=data.tokenizer) + strategy = nl.MegatronStrategy() + checkpoint_callback = ModelCheckpoint( + every_n_train_steps=5000, + enable_nemo_ckpt_io=False, + ) + callbacks = [checkpoint_callback] + + loggers = [] + tensorboard_logger = TensorBoardLogger( + save_dir='dummy', ## NOTE: this gets overwritten by default + ) + loggers.append(tensorboard_logger) + + opt_config = OptimizerConfig( + optimizer='adam', + lr=6e-4, + min_lr=6e-5, + use_distributed_optimizer=False, + bf16=True, + ) + opt = MegatronOptimizerModule(config=opt_config) + + trainer = nl.Trainer( + devices=args.devices, + max_steps=args.max_steps, + accelerator="gpu", + strategy=strategy, + logger=loggers, + callbacks=callbacks, + log_every_n_steps=1, + limit_val_batches=2, + plugins=nl.MegatronMixedPrecision(precision="bf16-mixed", amp_O2=False), + ) + + nemo_logger = NeMoLogger( + dir=args.experiment_dir, + ) + + train( + model=model, + data=data, + trainer=trainer, + log=nemo_logger, + tokenizer='data', + optim=opt, + ) diff --git a/examples/multimodal/multimodal_llm/neva/conf/lita_config.yaml b/examples/multimodal/multimodal_llm/neva/conf/lita_config.yaml new file mode 100644 index 000000000000..2e20fe0be272 --- /dev/null +++ b/examples/multimodal/multimodal_llm/neva/conf/lita_config.yaml @@ -0,0 +1,242 @@ +name: nemo_video_lita_neva +restore_from_path: null # used when starting from a .nemo file + +trainer: + devices: 1 + num_nodes: 1 + accelerator: gpu + precision: bf16 + logger: False # logger provided by exp_manager + enable_checkpointing: False + use_distributed_sampler: 
False + max_epochs: -1 # PTL default. In practice, max_steps will be reached first. + max_steps: 10000 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches + log_every_n_steps: 10 + val_check_interval: 100 + check_val_every_n_epoch: null + limit_val_batches: 50 + limit_test_batches: 500 + accumulate_grad_batches: 1 # do not modify, grad acc is automatic for training megatron models + gradient_clip_val: 1.0 + benchmark: False + enable_model_summary: False # default PTL callback for this does not support model parallelism, instead we log manually + +exp_manager: + explicit_log_dir: null + exp_dir: null + name: nemo_video_neva_lita + create_wandb_logger: True + wandb_logger_kwargs: + project: null + name: null + resume_if_exists: True + resume_ignore_no_checkpoint: True + resume_from_checkpoint: ${model.resume_from_checkpoint} + create_checkpoint_callback: True + checkpoint_callback_params: + monitor: val_loss + save_top_k: 5 + mode: min + always_save_nemo: False # saves nemo file during validation, not implemented for model parallel + save_nemo_on_train_end: True # not recommended when training large models on clusters with short time limits + filename: 'megatron_clip--{val_loss:.2f}-{step}-{consumed_samples}' + model_parallel_size: ${multiply:${model.tensor_model_parallel_size}, ${model.pipeline_model_parallel_size}} + ema: + enable: False + decay: 0.9999 + validate_original_weights: False + every_n_steps: 1 + cpu_offload: False + +model: + precision: ${trainer.precision} + + # specify micro_batch_size, global_batch_size, and model parallelism + # gradient accumulation will be done automatically based on data_parallel_size + + # Batch size guideline for different types of dataset + micro_batch_size: 1 # limited by GPU memory + global_batch_size: 2 # will use more micro batches to reach global batch size + + tensor_model_parallel_size: 1 # intra-layer model parallelism + pipeline_model_parallel_size: 1 # inter-layer model parallelism + context_parallel_size: 1 # kqv model parallelism + virtual_pipeline_model_parallel_size: null # interleaved pipeline + + restore_from_path: null # used in fine-tuning + + # Multimodal configs + mm_cfg: + llm: + from_pretrained: null #path to nemo checkpoint + freeze: False + model_type: llama_2 # `v1`, `nvgpt`, `llama_2`, `llama_3` and `mistral` supported + vision_encoder: + from_pretrained: "Lin-Chen/ShareGPT4V-13B_Pretrained_vit-large336-l12" # huggingface path or name + from_hf: True + crop_size: [336, 336] + patch_dim: 14 + hidden_size: 1024 # could be found from model but tricky in code + vision_select_layer: -2 # default to the last layer + class_token_length: 1 + freeze: True + lita: + lita_video_arch: 'temporal_all_resolution' # ['temporal_spatial_pool', 'temporal_spatial', 'temporal_all_resolution'] 'temporal_spatial_pool' is used in lita1.0 + visual_token_format: 'im_vid_start_end' # ["v1", "im_vid_start_end"] v1 means do nothing, im_vid_start_end means add image and video start and end tokens around spatial and temporal tokens + sample_frames: 4 # for lita 1.5 sample_frames are used for spatial tokens, and spatial tokens will no longer do pooling and instead, it will use full tokens + use_lita: True + pretrain_mm_mlp_adapter: null # path to pretrained mm adapter + mm_mlp_adapter_type: mlp2x_gelu # ['linear', 'mlp2x_gelu', 'mlp_downsample'] + use_im_start_end: False + + # ========LORA configs start======= + #peft: + # peft_scheme: "lora" + # restore_from_path: null + # lora_tuning: + # adapter_dim: 128 + 
# alpha: 256 + # target_modules: ['all'] + # adapter_dropout: 0.0 + # column_init_method: 'xavier' # IGNORED if linear_adapter is used, options: xavier, zero or normal + # row_init_method: 'zero' # IGNORED if linear_adapter is used, options: xavier, zero or normal + # layer_selection: null # selects in which layers to add lora adapters. e.g. [1,12] will add lora to layer 1 (lowest) and 12. null will apply adapters to all layers + # weight_tying: False + # position_embedding_strategy: null # used only when weight_tying is True + # =======LORA configs end======= + + # LLM configs + # use GPTModel from megatron.core + mcore_gpt: True + + # model architecture + encoder_seq_length: 4096 + max_position_embeddings: ${.encoder_seq_length} + position_embedding_type: rope + num_layers: 32 + hidden_size: 4096 + ffn_hidden_size: 11008 # Transformer FFN hidden size. Usually 4 * hidden_size. + num_attention_heads: 32 + init_method_std: 0.014 # Standard deviation of the zero mean normal distribution used for weight initialization.') + use_scaled_init_method: True # use scaled residuals initialization + hidden_dropout: 0.0 # Dropout probability for hidden state transformer. + attention_dropout: 0.0 # Dropout probability for attention + ffn_dropout: 0.0 # Dropout probability in the feed-forward layer. + kv_channels: null # Projection weights dimension in multi-head attention. Set to hidden_size // num_attention_heads if null + apply_query_key_layer_scaling: True # scale Q * K^T by 1 / layer-number. + normalization: 'rmsnorm' # Normalization layer to use. Options are 'layernorm', 'rmsnorm' + layernorm_epsilon: 1e-5 + do_layer_norm_weight_decay: False # True means weight decay on all params + make_vocab_size_divisible_by: 16 # Pad the vocab size to be divisible by this value for computation efficiency. + pre_process: True # add embedding + post_process: True # add pooler + persist_layer_norm: True # Use of persistent fused layer norm kernel. + bias: False # Whether to use bias terms in all weight matrices. + activation: 'fast-swiglu' # Options ['gelu', 'geglu', 'swiglu', 'reglu', 'squared-relu', 'fast-geglu', 'fast-swiglu', 'fast-reglu'] + headscale: False # Whether to learn extra parameters that scale the output of the each self-attention head. + transformer_block_type: 'pre_ln' # Options ['pre_ln', 'post_ln', 'normformer'] + normalize_attention_scores: True # Whether to scale the output Q * K^T by 1 / sqrt(hidden_size_per_head). This arg is provided as a configuration option mostly for compatibility with models that have been weight-converted from HF. You almost always want to se this to True. + rotary_percentage: 1.0 # If using position_embedding_type=rope, then the per head dim is multiplied by this. + attention_type: 'multihead' # Attention type. Options ['multihead'] + share_embeddings_and_output_weights: False # Share embedding and output layer weights. + overlap_p2p_comm: False # Overlap p2p communication with computes. This argument is valid only when `virtual_pipeline_model_parallel_size` is larger than 1 + batch_p2p_comm: True # Batch consecutive inter-peer send/recv operations. This argument is valid only when `virtual_pipeline_model_parallel_size` is larger than 1 + seq_len_interpolation_factor: null # RoPE Interpolation factor for sequence length. This is used to build long-context models with RoPE ex: https://arxiv.org/abs/2306.15595. + num_query_groups: null # Number of query groups for group query attention. If None, normal attention is used. 
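# Worked example for the attention settings above: with hidden_size: 4096 and
# num_attention_heads: 32, kv_channels defaults to 4096 // 32 = 128 per head.
# Leaving num_query_groups at null keeps standard multi-head attention; setting it
# to 8, as the Mixtral and VITA configs in this PR do, would make each group of
# 32 / 8 = 4 query heads share one key/value head (grouped-query attention).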
+ + ## Activation Checkpointing + activations_checkpoint_granularity: null # 'selective' or 'full' + activations_checkpoint_method: null # 'uniform', 'block', not used with 'selective' + activations_checkpoint_num_layers: null # not used with 'selective' + num_micro_batches_with_partial_activation_checkpoints: null + activations_checkpoint_layers_per_pipeline: null + sequence_parallel: False + + # precision + native_amp_init_scale: 4294967296 # 2 ** 32 + native_amp_growth_interval: 1000 + hysteresis: 2 # Gradient scale hysteresis + fp32_residual_connection: False # Move residual connections to fp32 + fp16_lm_cross_entropy: False # Move the cross entropy unreduced loss calculation for lm head to fp16 + + # model fusions + masked_softmax_fusion: True # Use a kernel that fuses the attention softmax with it's mask. + bias_dropout_add_fusion: False # Use a kernel that fuses the bias addition, dropout and residual connection addition. + + use_cpu_initialization: False # Init weights on the CPU (slow for large models) + onnx_safe: False # Use work-arounds for known problems with Torch ONNX exporter. + gradient_accumulation_fusion: False # Fuse weight gradient accumulation to GEMMs. Only used with pipeline parallelism. + openai_gelu: False + bias_activation_fusion: False + megatron_legacy: False + + transformer_engine: True + fp8: False # enables fp8 in TransformerLayer forward + fp8_e4m3: False # sets fp8_format = recipe.Format.E4M3 + fp8_hybrid: False # sets fp8_format = recipe.Format.HYBRID + fp8_margin: 0 # scaling margin + fp8_interval: 1 # scaling update interval + fp8_amax_history_len: 1 # Number of steps for which amax history is recorded per tensor + fp8_amax_compute_algo: most_recent # 'most_recent' or 'max'. Algorithm for computing amax from history + use_emha: False # Use fused multi-head attention for large sequence-length. Note this is not yet supported. Please set to False. + + # Megatron O2-style half-precision + megatron_amp_O2: False # Enable O2-level automatic mixed precision using main parameters + async_grad_allreduce: False + grad_allreduce_chunk_size_mb: 125 + grad_div_ar_fusion: True # Fuse grad division into torch.distributed.all_reduce + + # miscellaneous + seed: 1234 + resume_from_checkpoint: null # manually set the checkpoint file to load from + apex_transformer_log_level: 30 # Python logging level displays logs with severity greater than or equal to this + gradient_as_bucket_view: True # PyTorch DDP argument. Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory) + + tokenizer: + library: 'sentencepiece' + type: null + model: /ws/converted_nemo_model/tokenizer_1_5.model + vocab_file: null + merge_file: null + delimiter: null # only used for tabular tokenizer + sentencepiece_legacy: False # Legacy=True allows you to add special tokens to sentencepiece tokenizers. + additional_special_tokens: null # ["", "", "", "", "", ""] + + data: + packed_sequence: False + num_workers: 8 + dataloader_type: cyclic + data_path: null + lazy_preprocess: True + is_multimodal: True + media_type: video # currently supported: image or video + splice_single_frame: null # 'first', 'middle', 'last' will represent video as first / middle / last frame only, all other frames discarded. 
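# The image_token_len setting a few lines below is normally (crop_size / patch_dim)^2
# for the chosen vision encoder: with the 336x336 ShareGPT4V ViT and patch_dim: 14
# that is (336 / 14)^2 = 24 * 24 = 576 tokens per frame, while the LITA 1.0 value of
# 256 noted in its comment corresponds to a 224 / 14 = 16 grid (16 * 16 = 256).
# Changing the vision encoder or crop size is expected to change this value as well.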
+ num_frames: 256 # selects the number of frames to use from the video + sep_token_between_frames: False # TODO: allow usage of separator tokens between frames + sep_image_conv_front: False + image_token_len: 576 #lita 1.0 uses 256 + conv_template: v1 # check `nemo/collections/multimodal/data/neva/conversation.py` + image_folder: null + video_folder: null + image_aspect_ratio: 'pad' # lita 1.0 uses 'square' + + # Nsys profiling options + nsys_profile: + enabled: False + start_step: 10 # Global batch to start profiling + end_step: 10 # Global batch to end profiling + ranks: [ 0 ] # Global rank IDs to profile + gen_shape: False # Generate model and kernel details including input shapes + + optim: + name: fused_adam + lr: 2e-5 + weight_decay: 0. + betas: + - 0.9 + - 0.95 + sched: + name: CosineAnnealing + warmup_steps: 140 + constant_steps: 0 + min_lr: 2e-7 diff --git a/examples/multimodal/multimodal_llm/neva/conf/llava_config.yaml b/examples/multimodal/multimodal_llm/neva/conf/llava_config.yaml index 3ec90b2d1b53..d8a31fa19ca9 100644 --- a/examples/multimodal/multimodal_llm/neva/conf/llava_config.yaml +++ b/examples/multimodal/multimodal_llm/neva/conf/llava_config.yaml @@ -71,10 +71,10 @@ model: freeze: False model_type: llama_2 # Only support nvgpt or llama_2 vision_encoder: - from_pretrained: "openai/clip-vit-large-patch14" # path or name + from_pretrained: "openai/clip-vit-large-patch14-336" # path or name from_hf: True patch_dim: 14 - crop_size: [224, 224] + crop_size: [336, 336] hidden_size: 1024 # could be found from model but tricky in code vision_select_layer: -2 # default to the last layer class_token_length: 1 diff --git a/examples/multimodal/multimodal_llm/neva/conf/neva_config.yaml b/examples/multimodal/multimodal_llm/neva/conf/neva_config.yaml index 9ec6e51bb004..9315b0fa3712 100644 --- a/examples/multimodal/multimodal_llm/neva/conf/neva_config.yaml +++ b/examples/multimodal/multimodal_llm/neva/conf/neva_config.yaml @@ -69,7 +69,7 @@ model: llm: from_pretrained: null # path to nemo checkpoint freeze: True - model_type: llama_2 # `nvgpt` or `llama_2` supported + model_type: llama_2 # `v1`, `nvgpt`, `llama_2`, `llama_3` and `mistral` supported vision_encoder: from_pretrained: "" # path or name from_hf: True diff --git a/examples/multimodal/multimodal_llm/neva/conf/neva_export.yaml b/examples/multimodal/multimodal_llm/neva/conf/neva_export.yaml index 5a163b250566..1ab9bdbd6398 100644 --- a/examples/multimodal/multimodal_llm/neva/conf/neva_export.yaml +++ b/examples/multimodal/multimodal_llm/neva/conf/neva_export.yaml @@ -6,10 +6,11 @@ infer: max_input_len: 4096 max_output_len: 256 max_multimodal_len: 3072 + vision_max_batch_size: 1 #256 for lita/vita when inference with video dataset model: - type: neva + type: neva #neva, video-neva, lita, vila, vita precision: bfloat16 visual_model_path: /path/to/visual.nemo llm_model_path: /path/to/llm.nemo - llm_model_type: llama + llm_model_type: llama diff --git a/examples/multimodal/multimodal_llm/neva/conf/neva_mixtral_config.yaml b/examples/multimodal/multimodal_llm/neva/conf/neva_mixtral_config.yaml new file mode 100644 index 000000000000..6e3fb19cdab6 --- /dev/null +++ b/examples/multimodal/multimodal_llm/neva/conf/neva_mixtral_config.yaml @@ -0,0 +1,220 @@ +trainer: + devices: 1 + num_nodes: 1 + accelerator: gpu + precision: bf16 + logger: False # logger provided by exp_manager + enable_checkpointing: False + use_distributed_sampler: False + max_epochs: -1 # PTL default. In practice, max_steps will be reached first. 
+ max_steps: 4650 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches + log_every_n_steps: 10 + val_check_interval: 100 + check_val_every_n_epoch: null + limit_val_batches: 50 + limit_test_batches: 500 + accumulate_grad_batches: 1 # do not modify, grad acc is automatic for training megatron models + gradient_clip_val: 1.0 + benchmark: False + enable_model_summary: False # default PTL callback for this does not support model parallelism, instead we log manually + +exp_manager: + explicit_log_dir: null + exp_dir: null + name: nemo_neva + create_wandb_logger: False + wandb_logger_kwargs: + project: null + name: null + resume_if_exists: True + resume_ignore_no_checkpoint: True + resume_from_checkpoint: ${model.resume_from_checkpoint} + create_checkpoint_callback: True + checkpoint_callback_params: + monitor: val_loss + save_top_k: 10 + mode: min + always_save_nemo: False # saves nemo file during validation, not implemented for model parallel + save_nemo_on_train_end: True # not recommended when training large models on clusters with short time limits + filename: 'megatron_clip--{val_loss:.2f}-{step}-{consumed_samples}' + model_parallel_size: ${multiply:${model.tensor_model_parallel_size}, ${model.pipeline_model_parallel_size}} + ema: + enable: False + decay: 0.9999 + validate_original_weights: False + every_n_steps: 1 + cpu_offload: False + +model: + precision: ${trainer.precision} + + # specify micro_batch_size, global_batch_size, and model parallelism + # gradient accumulation will be done automatically based on data_parallel_size + + # Batch size guideline for different types of dataset + micro_batch_size: 1 # limited by GPU memory + global_batch_size: 1 # will use more micro batches to reach global batch size + + tensor_model_parallel_size: 1 # intra-layer model parallelism + pipeline_model_parallel_size: 1 # inter-layer model parallelism + expert_model_parallel_size: 1 + context_parallel_size: 1 + virtual_pipeline_model_parallel_size: null # interleaved pipeline + + restore_from_path: null # used in fine-tuning + + # Multimodal configs + mm_cfg: + llm: + from_pretrained: null + freeze: True + model_type: mistral # `v1`, `nvgpt`, `llama_2`, `llama_3` and `mistral` supported + vision_encoder: + from_pretrained: 'google/siglip-so400m-patch14-384' # path or name + from_hf: True + patch_dim: 14 + crop_size: [384, 384] + hidden_size: 1152 # could be found from model but tricky in code + vision_select_layer: -2 # default to the last layer + class_token_length: 0 + freeze: True + pretrain_mm_mlp_adapter: null # path to pretrained mm adapter + mm_mlp_adapter_type: mlp_downsample + use_im_start_end: False + + + # LLM configs + # use GPTModel from megatron.core + mcore_gpt: True + + moe_grouped_gemm: False + moe_token_dispatcher_type: alltoall + moe_aux_loss_coeff: 0.01 + # model architecture + encoder_seq_length: 4096 + max_position_embeddings: 32768 + position_embedding_type: rope + num_layers: 32 + hidden_size: 4096 + ffn_hidden_size: 14336 # Transformer FFN hidden size. Usually 4 * hidden_size. + num_attention_heads: 32 + init_method_std: 0.02 # Standard deviation of the zero mean normal distribution used for weight initialization.') + use_scaled_init_method: True # use scaled residuals initialization + hidden_dropout: 0.0 # Dropout probability for hidden state transformer. + attention_dropout: 0.0 # Dropout probability for attention + ffn_dropout: 0.0 # Dropout probability in the feed-forward layer. 
+ kv_channels: null # Projection weights dimension in multi-head attention. Set to hidden_size // num_attention_heads if null + apply_query_key_layer_scaling: False # scale Q * K^T by 1 / layer-number. + normalization: rmsnorm # Type of normalization layers + layernorm_epsilon: 1.0e-05 + do_layer_norm_weight_decay: False # True means weight decay on all params + make_vocab_size_divisible_by: 128 # Pad the vocab size to be divisible by this value for computation efficiency. + pre_process: True # add embedding + post_process: True # add pooler + persist_layer_norm: True # Use of persistent fused layer norm kernel. + bias: False # Whether to use bias terms in all weight matrices. + activation: 'fast-swiglu' # Options ['gelu', 'geglu', 'swiglu', 'reglu', 'squared-relu', 'fast-geglu', 'fast-swiglu', 'fast-reglu'] + headscale: False # Whether to learn extra parameters that scale the output of the each self-attention head. + transformer_block_type: 'pre_ln' # Options ['pre_ln', 'post_ln', 'normformer'] + normalize_attention_scores: True # Whether to scale the output Q * K^T by 1 / sqrt(hidden_size_per_head). This arg is provided as a configuration option mostly for compatibility with models that have been weight-converted from HF. You almost always want to se this to True. + rotary_percentage: 1.0 # If using position_embedding_type=rope, then the per head dim is multiplied by this. + rotary_base: 1000000.0 + moe_router_topk: 2 + num_moe_experts: 8 + attention_type: 'multihead' # Attention type. Options ['multihead'] + share_embeddings_and_output_weights: False # Share embedding and output layer weights. + overlap_p2p_comm: False # Overlap p2p communication with computes. This argument is valid only when `virtual_pipeline_model_parallel_size` is larger than 1 + batch_p2p_comm: True # Batch consecutive inter-peer send/recv operations. This argument is valid only when `virtual_pipeline_model_parallel_size` is larger than 1 + seq_len_interpolation_factor: null # RoPE Interpolation factor for sequence length. This is used to build long-context models with RoPE ex: https://arxiv.org/abs/2306.15595. + num_query_groups: 8 # Number of query groups for group query attention. If None, normal attention is used. + use_flash_attention: True + + ## Activation Checkpointing + activations_checkpoint_granularity: null # 'selective' or 'full' + activations_checkpoint_method: null # 'uniform', 'block', not used with 'selective' + activations_checkpoint_num_layers: null # not used with 'selective' + num_micro_batches_with_partial_activation_checkpoints: null + activations_checkpoint_layers_per_pipeline: null + sequence_parallel: True + + # precision + native_amp_init_scale: 4294967296 # 2 ** 32 + native_amp_growth_interval: 1000 + hysteresis: 2 # Gradient scale hysteresis + fp32_residual_connection: False # Move residual connections to fp32 + fp16_lm_cross_entropy: False # Move the cross entropy unreduced loss calculation for lm head to fp16 + + # model fusions + masked_softmax_fusion: True # Use a kernel that fuses the attention softmax with it's mask. + bias_dropout_add_fusion: False # Use a kernel that fuses the bias addition, dropout and residual connection addition. + + use_cpu_initialization: False # Init weights on the CPU (slow for large models) + onnx_safe: False # Use work-arounds for known problems with Torch ONNX exporter. + gradient_accumulation_fusion: False # Fuse weight gradient accumulation to GEMMs. Only used with pipeline parallelism. 
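# Rough back-of-the-envelope numbers for the MoE settings in this config (assuming the
# fast-swiglu FFN uses three weight matrices per expert): each expert holds about
# 3 * 4096 * 14336 ~= 176M FFN parameters, so num_moe_experts: 8 adds roughly 1.4B
# parameters per layer (about 45B across 32 layers), while moe_router_topk: 2 means
# only two experts, about 352M FFN parameters per layer, are active for any given token.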
+ openai_gelu: False + bias_activation_fusion: False + megatron_legacy: False + + transformer_engine: True + fp8: False # enables fp8 in TransformerLayer forward + fp8_e4m3: False # sets fp8_format = recipe.Format.E4M3 + fp8_hybrid: False # sets fp8_format = recipe.Format.HYBRID + fp8_margin: 0 # scaling margin + fp8_interval: 1 # scaling update interval + fp8_amax_history_len: 1 # Number of steps for which amax history is recorded per tensor + fp8_amax_compute_algo: most_recent # 'most_recent' or 'max'. Algorithm for computing amax from history + use_emha: False # Use fused multi-head attention for large sequence-length. Note this is not yet supported. Please set to False. + + # Megatron O2-style half-precision + megatron_amp_O2: True # Enable O2-level automatic mixed precision using main parameters + async_grad_allreduce: False + grad_allreduce_chunk_size_mb: 125 + grad_div_ar_fusion: True # Fuse grad division into torch.distributed.all_reduce + + # miscellaneous + seed: 1234 + resume_from_checkpoint: null # manually set the checkpoint file to load from + apex_transformer_log_level: 30 # Python logging level displays logs with severity greater than or equal to this + gradient_as_bucket_view: True # PyTorch DDP argument. Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory) + + tokenizer: + library: 'sentencepiece' + type: null + model: null + vocab_file: null + merge_file: null + delimiter: null # only used for tabular tokenizer + sentencepiece_legacy: False # Legacy=True allows you to add special tokens to sentencepiece tokenizers. + + data: + num_workers: 1 + dataloader_type: cyclic + data_path: null + lazy_preprocess: True + is_multimodal: True + media_type: image + sep_image_conv_front: False + conv_template: mistral + image_folder: null + image_aspect_ratio: 'square' + + # Nsys profiling options + nsys_profile: + enabled: False + start_step: 10 # Global batch to start profiling + end_step: 10 # Global batch to end profiling + ranks: [ 0 ] # Global rank IDs to profile + gen_shape: False # Generate model and kernel details including input shapes + + optim: + name: fused_adam + lr: 1e-3 + weight_decay: 0. + betas: + - 0.9 + - 0.95 + sched: + name: CosineAnnealing + warmup_steps: 70 + constant_steps: 0 + min_lr: 2e-5 diff --git a/examples/multimodal/multimodal_llm/neva/conf/video_neva_config.yaml b/examples/multimodal/multimodal_llm/neva/conf/video_neva_config.yaml index 8341ff857202..10d9230fc78e 100644 --- a/examples/multimodal/multimodal_llm/neva/conf/video_neva_config.yaml +++ b/examples/multimodal/multimodal_llm/neva/conf/video_neva_config.yaml @@ -70,7 +70,7 @@ model: llm: from_pretrained: #path to nemo checkpoint freeze: True - model_type: llama_2 # `nvgpt` or `llama_2` supported + model_type: llama_2 # `v1`, `nvgpt`, `llama_2`, `llama_3` and `mistral` supported vision_encoder: from_pretrained: "" # path or name from_hf: True diff --git a/examples/multimodal/multimodal_llm/neva/conf/vita_config.yaml b/examples/multimodal/multimodal_llm/neva/conf/vita_config.yaml new file mode 100644 index 000000000000..7be99308a280 --- /dev/null +++ b/examples/multimodal/multimodal_llm/neva/conf/vita_config.yaml @@ -0,0 +1,231 @@ +name: nemo_video_lita_neva +restore_from_path: null # used when starting from a .nemo file + +trainer: + devices: 8 + num_nodes: 1 + accelerator: gpu + precision: bf16 + logger: False # logger provided by exp_manager + enable_checkpointing: False + use_distributed_sampler: False + max_epochs: -1 # PTL default. 
In practice, max_steps will be reached first. + max_steps: 10000 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches + log_every_n_steps: 10 + val_check_interval: 100 + check_val_every_n_epoch: null + limit_val_batches: 50 + limit_test_batches: 500 + accumulate_grad_batches: 1 # do not modify, grad acc is automatic for training megatron models + gradient_clip_val: 1.0 + benchmark: False + enable_model_summary: False # default PTL callback for this does not support model parallelism, instead we log manually + +exp_manager: + explicit_log_dir: null + exp_dir: null + name: nemo_video_neva_lita + create_wandb_logger: True + wandb_logger_kwargs: + project: null + name: null + resume_if_exists: True + resume_ignore_no_checkpoint: True + resume_from_checkpoint: ${model.resume_from_checkpoint} + create_checkpoint_callback: True + checkpoint_callback_params: + monitor: val_loss + save_top_k: 5 + mode: min + always_save_nemo: False # saves nemo file during validation, not implemented for model parallel + save_nemo_on_train_end: True # not recommended when training large models on clusters with short time limits + filename: 'megatron_clip--{val_loss:.2f}-{step}-{consumed_samples}' + model_parallel_size: ${multiply:${model.tensor_model_parallel_size}, ${model.pipeline_model_parallel_size}} + ema: + enable: False + decay: 0.9999 + validate_original_weights: False + every_n_steps: 1 + cpu_offload: False + +model: + precision: ${trainer.precision} + + # specify micro_batch_size, global_batch_size, and model parallelism + # gradient accumulation will be done automatically based on data_parallel_size + + # Batch size guideline for different types of dataset + micro_batch_size: 1 # limited by GPU memory + global_batch_size: 128 # will use more micro batches to reach global batch size + + tensor_model_parallel_size: 1 # intra-layer model parallelism + pipeline_model_parallel_size: 1 # inter-layer model parallelism + context_parallel_size: 1 # kqv model parallelism + virtual_pipeline_model_parallel_size: null # interleaved pipeline + + restore_from_path: null # used in fine-tuning + + # Multimodal configs + mm_cfg: + llm: + from_pretrained: null #path to nemo checkpoint + freeze: False + model_type: vita + vision_encoder: + from_pretrained: null # path or name + model_type: null + from_hf: True + crop_size: [384, 384] + patch_dim: 14 + hidden_size: 1152 # could be found from model but tricky in code + vision_select_layer: -2 # default to the last layer + vision_select_feature: 'cls_patch' # default is patch + class_token_length: 1 + freeze: True + lita: + lita_video_arch: 'temporal_all_resolution' # ['temporal_spatial_pool', 'temporal_spatial', 'temporal_all_resolution'] + visual_token_format: 'im_vid_start_end' # ["v1", "im_vid_start_end"] v1 means do nothing, im_vid_start_end means add image and video start and end tokens around spatial and temporal tokens + sample_frames: 4 # for lita 1.5 sample_frames are used for spatial tokens, and spatial tokens will no longer do pooling and instead, it will use full tokens + use_lita: True + pretrain_mm_mlp_adapter: null # path to pretrained mm adapter + mm_mlp_adapter_type: mlp_downsample # ['linear', 'mlp2x_gelu', 'mlp_downsample'] + + use_im_start_end: False + + + # LLM configs + # use GPTModel from megatron.core + mcore_gpt: True + + # model architecture + encoder_seq_length: 8192 + max_position_embeddings: ${.encoder_seq_length} + position_embedding_type: rope + num_layers: 32 + hidden_size: 4096 + ffn_hidden_size: 
14336 # Transformer FFN hidden size. Usually 4 * hidden_size. + num_attention_heads: 32 + init_method_std: 0.014 # Standard deviation of the zero mean normal distribution used for weight initialization.') + use_scaled_init_method: True # use scaled residuals initialization + hidden_dropout: 0.0 # Dropout probability for hidden state transformer. + attention_dropout: 0.0 # Dropout probability for attention + ffn_dropout: 0.0 # Dropout probability in the feed-forward layer. + kv_channels: null # Projection weights dimension in multi-head attention. Set to hidden_size // num_attention_heads if null + apply_query_key_layer_scaling: True # scale Q * K^T by 1 / layer-number. + normalization: 'rmsnorm' # Normalization layer to use. Options are 'layernorm', 'rmsnorm' + layernorm_epsilon: 1e-5 + do_layer_norm_weight_decay: False # True means weight decay on all params + make_vocab_size_divisible_by: 16 # Pad the vocab size to be divisible by this value for computation efficiency. + pre_process: True # add embedding + post_process: True # add pooler + persist_layer_norm: True # Use of persistent fused layer norm kernel. + bias: False # Whether to use bias terms in all weight matrices. + activation: 'fast-swiglu' # Options ['gelu', 'geglu', 'swiglu', 'reglu', 'squared-relu', 'fast-geglu', 'fast-swiglu', 'fast-reglu'] + headscale: False # Whether to learn extra parameters that scale the output of the each self-attention head. + transformer_block_type: 'pre_ln' # Options ['pre_ln', 'post_ln', 'normformer'] + normalize_attention_scores: True # Whether to scale the output Q * K^T by 1 / sqrt(hidden_size_per_head). This arg is provided as a configuration option mostly for compatibility with models that have been weight-converted from HF. You almost always want to se this to True. + rotary_percentage: 1.0 # If using position_embedding_type=rope, then the per head dim is multiplied by this. + rotary_base: 500000.0 # default is 10000 + attention_type: 'multihead' # Attention type. Options ['multihead'] + share_embeddings_and_output_weights: False # Share embedding and output layer weights. + overlap_p2p_comm: False # Overlap p2p communication with computes. This argument is valid only when `virtual_pipeline_model_parallel_size` is larger than 1 + batch_p2p_comm: True # Batch consecutive inter-peer send/recv operations. This argument is valid only when `virtual_pipeline_model_parallel_size` is larger than 1 + seq_len_interpolation_factor: null # RoPE Interpolation factor for sequence length. This is used to build long-context models with RoPE ex: https://arxiv.org/abs/2306.15595. + num_query_groups: 8 # Number of query groups for group query attention. If None, normal attention is used. + + ## Activation Checkpointing + activations_checkpoint_granularity: null # 'selective' or 'full' + activations_checkpoint_method: null # 'uniform', 'block', not used with 'selective' + activations_checkpoint_num_layers: null # not used with 'selective' + num_micro_batches_with_partial_activation_checkpoints: null + activations_checkpoint_layers_per_pipeline: null + sequence_parallel: False + + # precision + native_amp_init_scale: 4294967296 # 2 ** 32 + native_amp_growth_interval: 1000 + hysteresis: 2 # Gradient scale hysteresis + fp32_residual_connection: False # Move residual connections to fp32 + fp16_lm_cross_entropy: False # Move the cross entropy unreduced loss calculation for lm head to fp16 + + # model fusions + masked_softmax_fusion: True # Use a kernel that fuses the attention softmax with it's mask. 
+ bias_dropout_add_fusion: False # Use a kernel that fuses the bias addition, dropout and residual connection addition. + + use_cpu_initialization: False # Init weights on the CPU (slow for large models) + onnx_safe: False # Use work-arounds for known problems with Torch ONNX exporter. + gradient_accumulation_fusion: False # Fuse weight gradient accumulation to GEMMs. Only used with pipeline parallelism. + openai_gelu: False + bias_activation_fusion: False + megatron_legacy: False + + transformer_engine: True + fp8: False # enables fp8 in TransformerLayer forward + fp8_e4m3: False # sets fp8_format = recipe.Format.E4M3 + fp8_hybrid: False # sets fp8_format = recipe.Format.HYBRID + fp8_margin: 0 # scaling margin + fp8_interval: 1 # scaling update interval + fp8_amax_history_len: 1 # Number of steps for which amax history is recorded per tensor + fp8_amax_compute_algo: most_recent # 'most_recent' or 'max'. Algorithm for computing amax from history + use_emha: False # Use fused multi-head attention for large sequence-length. Note this is not yet supported. Please set to False. + + # Megatron O2-style half-precision + megatron_amp_O2: False # Enable O2-level automatic mixed precision using main parameters + async_grad_allreduce: False + grad_allreduce_chunk_size_mb: 125 + grad_div_ar_fusion: True # Fuse grad division into torch.distributed.all_reduce + + # miscellaneous + seed: 1234 + resume_from_checkpoint: null # manually set the checkpoint file to load from + apex_transformer_log_level: 30 # Python logging level displays logs with severity greater than or equal to this + gradient_as_bucket_view: True # PyTorch DDP argument. Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory) + + tokenizer: + library: 'huggingface' + type: /ws/converted_models/tokenizer # set huggingface tokenizer here; And check `LITA Tutorial.ipynb` for how to add time tokens to tokenizer + model: null # set sentencepiece model path here if tokenizer is sentencepiece + vocab_file: null + merge_file: null + delimiter: null # only used for tabular tokenizer + sentencepiece_legacy: False # Legacy=True allows you to add special tokens to sentencepiece tokenizers. + additional_special_tokens: null # ["", "", "", "", "", ""] + + data: + packed_sequence: False + num_workers: 8 + dataloader_type: cyclic + data_path: null + lazy_preprocess: True + is_multimodal: True + media_type: video # currently supported: image or video + splice_single_frame: null # 'first', 'middle', 'last' will represent video as first / middle / last frame only, all other frames discarded. + num_frames: 256 # selects the number of frames to use from the video + sep_token_between_frames: False # TODO: allow usage of separator tokens between frames + sep_image_conv_front: False + image_token_len: 784 # 28x28 + conv_template: llama_3 # check `nemo/collections/multimodal/data/neva/conversation.py` + image_folder: null + video_folder: null + image_aspect_ratio: 'pad' # in vila, it's `resize` + + # Nsys profiling options + nsys_profile: + enabled: False + start_step: 10 # Global batch to start profiling + end_step: 10 # Global batch to end profiling + ranks: [ 0 ] # Global rank IDs to profile + gen_shape: False # Generate model and kernel details including input shapes + + optim: + name: fused_adam + lr: 2e-5 + weight_decay: 0. 
+ betas: + - 0.9 + - 0.95 + sched: + name: CosineAnnealing + warmup_steps: 140 + constant_steps: 0 + min_lr: 2e-7 diff --git a/examples/multimodal/multimodal_llm/neva/convert_hf_llava_to_neva.py b/examples/multimodal/multimodal_llm/neva/convert_llava_to_neva.py similarity index 73% rename from examples/multimodal/multimodal_llm/neva/convert_hf_llava_to_neva.py rename to examples/multimodal/multimodal_llm/neva/convert_llava_to_neva.py index 2cbb4c2b3b82..d02b737c750a 100644 --- a/examples/multimodal/multimodal_llm/neva/convert_hf_llava_to_neva.py +++ b/examples/multimodal/multimodal_llm/neva/convert_llava_to_neva.py @@ -13,15 +13,22 @@ # limitations under the License. r""" -Script to convert HuggingFace LLaVA checkpoints into .nemo file. - Example to run this conversion script: - python convert_hf_llava_to_neva.py \ - --in-file \ - --out-file \ - --tokenizer-model \ - --conv-template llama_2 # nvgpt, llama_2, v1 (vicuna) +Script to convert LLaVA checkpoints into .nemo file. +This script depend on llava github project: +https://github.com/haotian-liu/LLaVA/tree/main + +If you want to convert huggingface LLaVA checkpoint such as llava-hf/llava-1.5-7b-hf, +you should check `NeMo/scripts/checkpoint_converters/convert_llava_hf_to_nemo.py` + +Example to run this conversion script: + python convert_hf_llava_to_neva.py \ + --in-file \ + --out-file \ + --tokenizer-model \ + --conv-template llama_2 # nvgpt, llama_2, v1, llama_3 (vicuna) """ +import json import os from argparse import ArgumentParser from collections import OrderedDict @@ -31,6 +38,7 @@ from omegaconf import OmegaConf from pytorch_lightning.core.saving import _load_state as ptl_load_state from pytorch_lightning.trainer.trainer import Trainer +from safetensors import safe_open from transformers import LlamaTokenizer from nemo.collections.multimodal.models.multimodal_llm.neva.neva_model import MegatronNevaModel @@ -47,7 +55,11 @@ def get_args(): parser = ArgumentParser() parser.add_argument( - "--in-file", type=str, default=None, required=True, help="Path to Huggingface LLaMA checkpoints", + "--in-file", + type=str, + default=None, + required=True, + help="Path to LLaVA checkpoints", ) parser.add_argument("--out-file", type=str, default=None, required=True, help="Path to output .nemo file.") parser.add_argument( @@ -61,6 +73,16 @@ def get_args(): "--tokenizer-model", type=str, default=None, required=False, help="Path to sentencepiece tokenizer model." 
) parser.add_argument("--precision", type=str, default="32", help="Model precision") + parser.add_argument("--config-file", type=str, default="llava_config.yaml") + parser.add_argument( + "--mm-projector-ckpt-dir", + type=str, + default=None, + help="Path to multimodal projector checkpoint directory \ + This will overlap the projector weights in in-file hf checkpoint", + ) + parser.add_argument("--mm-vision-tower", type=str, default=None) + parser.add_argument("--model-type", type=str, default=None) args = parser.parse_args() return args @@ -110,13 +132,32 @@ def load_model(cls, checkpoint, strict, **kwargs): def load_config(args, llava_config): - nemo_config = OmegaConf.load(os.path.join(os.path.dirname(__file__), 'conf/llava_config.yaml')).model + nemo_config = OmegaConf.load(os.path.join(os.path.dirname(__file__), 'conf', args.config_file)).model nemo_config.mm_cfg.mm_mlp_adapter_type = llava_config.get('mm_projector_type', 'linear') - nemo_config.mm_cfg.vision_encoder.from_pretrained = llava_config.get( - 'mm_vision_tower', 'openai/clip-vit-large-patch14' - ) - if '336' in nemo_config.mm_cfg.vision_encoder.from_pretrained: - nemo_config.data.image_token_len = 576 + + mm_vision_tower = llava_config.get('mm_vision_tower', 'openai/clip-vit-large-patch14') + + if args.mm_vision_tower is not None: + mm_vision_tower = args.mm_vision_tower + + nemo_config.mm_cfg.vision_encoder.from_pretrained = mm_vision_tower + if args.mm_vision_tower is not None: + config_file = os.path.join(args.mm_vision_tower, "config.json") + if os.path.exists(config_file): + with open(config_file, "r") as f: + vision_model_config = json.load(f) + nemo_config.mm_cfg.vision_encoder["model_type"] = vision_model_config.get("model_type", 'clip') + crop_size = vision_model_config.get("image_size", 224) + nemo_config.mm_cfg.vision_encoder.crop_size = [crop_size, crop_size] + else: + if '336' in mm_vision_tower: + nemo_config.data.image_token_len = 576 + nemo_config.mm_cfg.vision_encoder.crop_size = [336, 336] + else: + nemo_config.data.image_token_len = 256 + nemo_config.mm_cfg.vision_encoder.crop_size = [224, 224] + nemo_config.mm_cfg.vision_encoder.patch_dim = 14 + nemo_config.encoder_seq_length = llava_config['max_position_embeddings'] nemo_config.num_layers = int(llava_config['num_hidden_layers']) nemo_config.hidden_size = llava_config['hidden_size'] @@ -130,16 +171,34 @@ def load_config(args, llava_config): nemo_config.use_cpu_initialization = True nemo_config.activation = 'fast-swiglu' nemo_config.data.conv_template = args.conv_template - nemo_config.mm_cfg.model_type = args.conv_template + nemo_config.data.image_aspect_ratio = llava_config.get('image_aspect_ratio', 'square') + if args.model_type is None: + nemo_config.mm_cfg.model_type = args.conv_template + else: + nemo_config.mm_cfg.model_type = args.model_type if args.tokenizer_model is None: - nemo_config.tokenizer.model = llava_config['tokenizer_model'] + if 'tokenizer_model' in llava_config: + nemo_config.tokenizer.library = 'sentencepiece' + nemo_config.tokenizer.model = llava_config['tokenizer_model'] + else: + # Llama3 uses converted TikToken Tokenizer + tokenizer_dict = {'library': 'huggingface', 'type': args.in_file, 'use_fast': True, 'model': None} + nemo_config.tokenizer.update(tokenizer_dict) else: - nemo_config.tokenizer.model = args.tokenizer_model + # if tokenizer_model is directory + if os.path.isdir(args.tokenizer_model): + tokenizer_dict = {'library': 'huggingface', 'type': args.tokenizer_model, 'use_fast': True, 'model': None} + 
nemo_config.tokenizer.update(tokenizer_dict) + else: + nemo_config.tokenizer.library = 'sentencepiece' + nemo_config.tokenizer.model = args.tokenizer_model if llava_config['rope_scaling'] is not None: if llava_config['rope_scaling']['type'] == 'linear': nemo_config['seq_len_interpolation_factor'] = llava_config['rope_scaling']['factor'] else: raise ValueError("Only linear rope scaling type is supported now") + if llava_config.get('rope_theta', None): + nemo_config['rotary_base'] = llava_config['rope_theta'] base = 128 while llava_config['vocab_size'] % base != 0: @@ -152,16 +211,15 @@ def load_config(args, llava_config): def convert(args): logging.info(f"loading checkpoint {args.in_file}") model = LlavaLlamaForCausalLM.from_pretrained(args.in_file) - tokenizer = LlamaTokenizer.from_pretrained(args.in_file) hf_config = vars(model.config) - hf_config['tokenizer_model'] = str(tokenizer.vocab_file) - print(f"hf_config: {hf_config}") - print("named parameters:") + if os.path.exists(f'{args.in_file}/tokenizer.model'): + tokenizer = LlamaTokenizer.from_pretrained(args.in_file) + hf_config['tokenizer_model'] = str(tokenizer.vocab_file) + for name, param in model.named_parameters(): print(f"- {name}") nemo_config = load_config(args, hf_config) - print(nemo_config) if args.precision in ["32", "16"]: precision = int(float(args.precision)) @@ -179,7 +237,7 @@ def convert(args): scaler = None if precision in [16, '16', '16-mixed']: scaler = GradScaler( - init_scale=nemo_config.get('native_amp_init_scale', 2 ** 32), + init_scale=nemo_config.get('native_amp_init_scale', 2**32), growth_interval=nemo_config.get('native_amp_growth_interval', 1000), hysteresis=nemo_config.get('hysteresis', 2), ) @@ -235,10 +293,42 @@ def convert(args): for key in model.state_dict(): if 'mm_projector' in key: mm_projection_layer_suffix = key.split('mm_projector')[1] - checkpoint['state_dict'][ - f'{mm_projection_layer_base_name}{mm_projection_layer_suffix}' - ] = param_to_weights(model.state_dict()[key]) + checkpoint['state_dict'][f'{mm_projection_layer_base_name}{mm_projection_layer_suffix}'] = ( + param_to_weights(model.state_dict()[key]) + ) + # Replace or add the projection weights + proj_ckpt = None + if args.mm_projector_ckpt_dir is not None: + if os.path.exists(args.mm_projector_ckpt_dir): + ckpt_path = os.path.join(args.mm_projector_ckpt_dir, "mm_projector.bin") + if os.path.exists(ckpt_path): + proj_ckpt = torch.load(ckpt_path) + else: + ckpt_path = os.path.join(args.mm_projector_ckpt_dir, "model.safetensors") + proj_ckpt = {} + with safe_open(ckpt_path, framework="pt", device="cuda") as f: + for key in f.keys(): + new_key = key.replace("layers.", "mm_projector.") + proj_ckpt[new_key] = f.get_tensor(key) + else: + raise FileNotFoundError(f"mm_projector_ckpt_dir {args.mm_projector_ckpt_dir} does not exist.") + for key in proj_ckpt.keys(): + if 'mm_projector' in key: + mm_projection_layer_suffix = key.split('mm_projector')[1] + checkpoint['state_dict'][f'{mm_projection_layer_base_name}{mm_projection_layer_suffix}'] = ( + param_to_weights(proj_ckpt[key]) + ) + + proj_conf_file = open(os.path.join(args.mm_projector_ckpt_dir, "config.json")) + + proj_conf = json.load(proj_conf_file) + if proj_conf['mm_projector_type'] != nemo_config.mm_cfg.mm_mlp_adapter_type: + logging.warning( + f"Overriding mm_projector_type from {nemo_config.mm_cfg.mm_mlp_adapter_type} to {proj_conf['mm_projector_type']}" + ) + nemo_config.mm_cfg.mm_mlp_adapter_type = proj_conf['mm_projector_type'] + proj_conf_file.close() embed_weight = 
model.state_dict()[f'model.embed_tokens.weight'] if mcore_gpt: embed_weights_base_name = f'model.embedding.word_embeddings.weight' diff --git a/examples/multimodal/multimodal_llm/neva/eval/eval_video_rtl.py b/examples/multimodal/multimodal_llm/neva/eval/eval_video_rtl.py new file mode 100644 index 000000000000..3567cf431d87 --- /dev/null +++ b/examples/multimodal/multimodal_llm/neva/eval/eval_video_rtl.py @@ -0,0 +1,196 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +This script is used for evaluating RTL (Reasoning Temporal Localization) task. +It accepts one JSON file. The JSON file should have the following structure: +[ + { + "video": "rY7eLyJF31M_6.mp4", + "question_id": "rY7eLyJF31M_6_0", + "question": "When is \"Apply mascara , false lashes on the lashes \" depicted in the video? Convey your answer using start and end timestamps exclusively.", + "ref_answer": "<0> <53> Apply mascara , false lashes on the lashes ", + "duration": 102.002002002002, + "pred_answer": "<1> <53> Apply mascara , false lashes on the lashes ", + }, + { + "video": "rY7eLyJF31M_6.mp4", + "question_id": "rY7eLyJF31M_6_1", + "question": "When is \"Apply foundation on the face with a brush\" depicted in the video? Provide a response using only start and end timestamps.", + "ref_answer": "<56> <97> Apply foundation on the face with a brush", + "duration": 102.002002002002, + "pred_answer": "<50> <97> Apply foundation on the face with a brush", + }, +] + +The `xxx_answer` field should contain the start and end timestamps such as `<56>` and `<97>` of the event along with the sentence. +If not, the [0, duration] will be used as the predicted timestamps. + +USAGE: +python eval_rtl.py --input_file \ + --output_dir \ + --save_mid_result +""" +import argparse +import json +import os +import re +from collections import defaultdict + + +def iou(seg1, seg2): + """Compute the intersection over union (IoU) between two segments. + + Args: + seg1 (list): [start, end] + seg2 (list): [start, end] + + Returns: + float: IoU value + """ + assert seg1[1] >= seg1[0] and seg2[1] >= seg2[0] + + x1 = max(seg1[0], seg2[0]) + x2 = min(seg1[1], seg2[1]) + inter = max(x2 - x1, 0) + + len1 = max(seg1[1] - seg1[0], 0) + len2 = max(seg2[1] - seg2[0], 0) + + union = len1 + len2 - inter + + if union == 0: + return 0.0 + else: + return inter / union + + +def precision_func(thres): + """calculate the precision based on the threshold. + If the IoU value is greater than or equal to the threshold, \ + the precision is 1.0, otherwise 0.0. + + Args: + thres (float): threshold value [0.0, 1.0] + """ + + def precision(seg1, seg2): + return float(iou(seg1, seg2) >= thres) + + return precision + + +def parse_start_end_timestamps(outputs, duration, strict=False): + timestamp_pattern = '\<(?: (?: \d* \.? \d+ ) | (?: \d+ \.? 
) )\>' + rx = re.compile(timestamp_pattern, re.VERBOSE) + matches = list(rx.finditer(outputs)) + if strict: + assert len(list(matches)) >= 2, "cannot find timestamps" + elif len(list(matches)) < 2: + return outputs, [0, duration] + + prev_end = 0 + sentence = "" + timestamps = [] + for i in range(2): + m = matches[i] + start = m.start(0) + end = m.end(0) + timestamp = float(m.group(0)[1:-1]) + timestamp = min(max(timestamp, 0), duration) + timestamps.append(timestamp) + sentence += outputs[prev_end:start] + prev_end = end + sentence += outputs[prev_end:] + sentence = sentence.strip() + + return sentence, [min(timestamps), max(timestamps)] + + +def eval(pred_file, output_dir, save_mid_result=True): + """Evaluate the predictions against the ground truth. + + Args: + pred_file (str): path to the predictions JSON file + output_dir (str): path to the output directory, + where the `answers.json` and `metrics.json` result will be saved. + """ + metric_func = {'iou': iou, 'precision@0.5': precision_func(0.5)} + metrics = {} + for metric in metric_func: + metrics[metric] = defaultdict(list) + + with open(pred_file, 'r') as f: + pred_data = json.load(f) + + out_list = [] + for pred in pred_data: + assert "pred_answer" in pred, "pred_answer field is missing" + assert "ref_answer" in pred, "answer field is missing" + duration = pred['duration'] + pred_answer, pred_timestamps = parse_start_end_timestamps(pred['pred_answer'], duration, strict=False) + ref_answer, ref_timestamps = parse_start_end_timestamps(pred['ref_answer'], duration, strict=False) + + for metric in metric_func: + metrics[metric][pred['video']].append(metric_func[metric](pred_timestamps, ref_timestamps)) + + out_list.append( + { + 'video': pred['video'], + 'question_id': pred['question_id'], + 'question': pred['question'], + 'pred_answer': pred_answer, + 'ref_answer': ref_answer, + 'pred_timestamps': pred_timestamps, + 'ref_timestamps': ref_timestamps, + } + ) + # save result + os.makedirs(output_dir, exist_ok=True) + if save_mid_result: + output_file = os.path.join(output_dir, 'answers.json') + print(f"Saving intermediate result to {output_file}") + with open(output_file, 'w') as f: + json.dump(out_list, f, indent=2) + + final_result = {} + for metric in metrics: + values = [] + for vid in metrics[metric]: + # get single video metric value + cur_metric_values = metrics[metric][vid] + values.append(sum(cur_metric_values) / len(cur_metric_values)) + # get global average video metric value + values = sum(values) / len(values) + final_result[metric] = values + + print(final_result) + output_file = os.path.join(output_dir, 'metrics.json') + with open(output_file, 'w') as f: + json.dump(final_result, f, indent=2) + + +def main(): + parser = argparse.ArgumentParser(description="Evaluate the predictions against the ground truth") + parser.add_argument("--input_file", help="Path to the input JSON file", required=True) + parser.add_argument("--output_dir", help="Path to the output directory", required=True) + parser.add_argument("--save_mid_result", action="store_true", help="Save intermediate result") + args = parser.parse_args() + + eval(args.input_file, args.output_dir, args.save_mid_result) + + +if __name__ == "__main__": + main() diff --git a/examples/multimodal/multimodal_llm/neva/eval/eval_vqa.py b/examples/multimodal/multimodal_llm/neva/eval/eval_vqa.py new file mode 100644 index 000000000000..8929648a3f97 --- /dev/null +++ b/examples/multimodal/multimodal_llm/neva/eval/eval_vqa.py @@ -0,0 +1,207 @@ +# Copyright (c) 2024, NVIDIA 
CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +This script is used for evaluating Video Question Answering task by leveraging LLM API as a judge. +It accepts one JSON file. The JSON file should have the following structure: +[ + { + "video": "YRvBOLRgZNc_2".mp4", + "question_id": "v_yVgL8sJQxYo_2_5", + "question": "What tools are used to apply foundation on the skin between <5s> and <60s>?", + "ref_answer": "A brush and blender.", + "duration": 102.002002002002, + "pred_answer": "A brush", + }, + { + "video": "yVgL8sJQxYo_2.mp4", # not a must-to-have field + "question": "How long does the action of applying foundation take?", + "question_id": "v_yVgL8sJQxYo_2_5" + "ref_answer": "The action takes around 55 seconds (<60s> - <5s>)." + "duration": 102.002002002002, # not a must-to-have field + "pred_answer": "This action takes around 50 seconds.", + } + + ... +] + +`video` and `duration` are two optional fields. If not provided, the script will ignore them. + +Notice that the time token here is represented as '<%ss>'.format(time_in_seconds). + +For the external LLM API, we use `meta/llama3-70b-instruct"` as an example. +You can go to: https://build.nvidia.com/explore/discover to choose the one that fits your needs. +Notice the API might be a little bit different. + +You also need an `API_TOKEN` from here: https://build.nvidia.com/explore/discover#llama3-70b +Click the `Get API Key` and save your key in the environment variable `API_TOKEN`. + +USAGE: +API_TOKEN= python eval_qa.py --input_file --output_dir --save_mid_result +""" + +import argparse +import ast +import json +import os +import re + +import requests + + +def parse_args(): + parser = argparse.ArgumentParser(description="Evaluate Video Question Answering task.") + parser.add_argument("--input_file", type=str, required=True, help="Path to the prediction file. 
json list file") + parser.add_argument("--output_dir", type=str, required=True, help="Path to the output directory.") + parser.add_argument("--save_mid_result", action="store_true", help="Whether to save the intermediate results.") + return parser.parse_args() + + +INVOKE_URL = "https://integrate.api.nvidia.com/v1/chat/completions" +# MODEL="mistralai/mixtral-8x22b-instruct-v0.1" # no `system` role +MODEL = "meta/llama3-70b-instruct" + + +def request_nvidia_api(messages): + API_TOKEN = os.getenv("API_TOKEN", "") # ADD NGC API TOKEN HERE + if not API_TOKEN: + raise ValueError("Please provide the API_TOKEN in the environment variable.") + headers = { + "Authorization": f"Bearer {API_TOKEN}", + "accept": "text/event-stream", + "content-type": "application/json", + } + payload = { + "model": MODEL, + "messages": messages, + "temperature": 0.5, + "top_p": 1.0, + "max_tokens": 2048, + "seed": 42, + "stream": True, + } + invoke_url = INVOKE_URL + response = requests.post(invoke_url, headers=headers, json=payload, stream=True) + output = "" + for line in response.iter_lines(): + if line == b'data: [DONE]': + break + if line: + res = json.loads(line.decode("utf-8").split("data: ")[1]) + if 'content' in res['choices'][0]['delta']: + output += res['choices'][0]['delta']['content'] + return output.lstrip().strip() + + +def convert_time_token(text): + # use regular expression to convert <12> <56> to <12s> <56s> + return re.sub(r'<(\d+)>', r'<\1s>', text) + + +def get_result(question, answer, pred, key, output_dir, save_mid_result=False): + messages = [ + { + "role": "system", + "content": "You are an intelligent chatbot designed for evaluating the correctness of generative outputs for question-answer pairs. " + "Your task is to compare the predicted answer with the correct answer and determine if they match meaningfully. Here's how you can accomplish the task:" + "------" + "##INSTRUCTIONS: " + "- Focus on the meaningful match between the predicted answer and the correct answer.\n" + "- Consider synonyms or paraphrases as valid matches.\n" + "- Evaluate the correctness of the prediction compared to the answer.", + }, + { + "role": "user", + "content": "Please evaluate the following video-based question-answer pair:\n\n" + f"Question: {question}\n" + f"Correct Answer: {answer}\n" + f"Predicted Answer: {pred}\n\n" + "Provide your evaluation only as a yes/no and score where the score is an integer value between 0 and 5, with 5 indicating the highest meaningful match. " + "Please generate the response in the form of a Python dictionary string with keys 'pred' and 'score', where value of 'pred' is a string of 'yes' or 'no' and value of 'score' is in INTEGER, not STRING." + "DO NOT PROVIDE ANY OTHER OUTPUT TEXT OR EXPLANATION. Only provide the Python dictionary string. 
" + "For example, your response should look like this: {'pred': 'yes', 'score': 4.8}.", + }, + ] + try: + response_message = request_nvidia_api(messages) + response_dict = ast.literal_eval(response_message) + except Exception as e: + print(f"Error processing file {key}: {e}") + return [] + qa_set = {"question": question, "ref_answer": answer, "pred_answer": pred} + result_qa_pair = [response_dict, qa_set] + if save_mid_result: + with open(f"{output_dir}/{key}.json", "w") as f: + json.dump(result_qa_pair, f) + return result_qa_pair + + +def main(): + args = parse_args() + input_file = args.input_file + output_dir = args.output_dir + save_mid_result = args.save_mid_result + with open(input_file, "r") as f: + data = json.load(f) + + tasks = [] + key = 0 + for item in data: + question = item["question"] + item["ref_answer"] = convert_time_token(item["ref_answer"]) + tasks.append((question, item["ref_answer"], item["pred_answer"], key, output_dir, save_mid_result)) + key += 1 + + # TODO: parallelize the requests + results = [] + while len(tasks) > 0: + task = tasks.pop() + key = task[3] + cur_result = get_result(*task) + if cur_result == []: + tasks.append(task) + continue + results.append((key, cur_result)) + + score_sum = count = yes_count = no_count = 0 + for key, result in results: + try: + count += 1 + score_sum += int(result[0]["score"]) + + if "yes" in result[0]["pred"].lower(): + yes_count += 1 + elif "no" in result[0]["pred"].lower(): + no_count += 1 + except Exception as e: + print(f"Error processing file {key}") + + average_score = score_sum / count + accuracy = yes_count / (yes_count + no_count) + result_file = os.path.join(output_dir, "metrics.json") + metrics = { + "average_score": average_score, + "accuracy": accuracy, + "no_count": no_count, + "yes_count": yes_count, + "model": MODEL, + } + print("Metrics: ", metrics) + with open(result_file, "w") as f: + json.dump(metrics, f, indent=2) + + +if __name__ == "__main__": + main() diff --git a/examples/multimodal/multimodal_llm/neva/neva_evaluation.py b/examples/multimodal/multimodal_llm/neva/neva_evaluation.py index dcc79029463c..75d8a907b796 100644 --- a/examples/multimodal/multimodal_llm/neva/neva_evaluation.py +++ b/examples/multimodal/multimodal_llm/neva/neva_evaluation.py @@ -15,7 +15,7 @@ import json import os import torch -from torch.utils.data import Dataset +from torch.utils.data import DataLoader, Dataset from nemo.collections.multimodal.parts.utils import create_neva_model_and_processor from nemo.collections.nlp.modules.common.transformer.text_generation import LengthParam, SamplingParam @@ -36,24 +36,109 @@ raise EnvironmentError("GPU is needed for the inference") -class RequestDataSet(Dataset): - def __init__(self, sentences): - super().__init__() - self.sentences = sentences - - def __len__( +class TemporalNevaDataset(Dataset): + def __init__( self, + prompt_dicts, + media_base_path, + media_token, + insert_media_token=None, + image_processor=None, + video_processor=None, + add_media_sep=False, ): - return len(self.sentences) + self.prompt_dicts = prompt_dicts + self.media_token = media_token + self.insert_media_token = insert_media_token + self.media_base_path = media_base_path + self.image_processor = image_processor + self.video_processor = video_processor + self.add_media_sep = add_media_sep + # [(media_name, [prompt_dict, prompt_dict, ...]), ...} + self.media_prompt_list = [] + self.group_by_media(media_token) + + def group_by_media(self, media_token): + """ + This function groups the prompt dicts by the 
media/video/image file name + """ + media_dict = {} + media = media_token.lstrip('<').rstrip('>') + for prompt_dict in self.prompt_dicts: + media_name = prompt_dict[media] # video or image file name + if media_name not in media_dict: + media_dict[media_name] = [] + media_dict[media_name].append(prompt_dict) + self.media_prompt_list = list(media_dict.items()) + + def __len__(self) -> int: + return len(self.media_prompt_list) + + def __getitem__(self, idx) -> dict: + """ + Return a list of prompt dicts for the idx-th media + For a single media file, only one media feature is returned + This would help improve performance as well as save GPU memory + """ + prompt_dict_list = self.media_prompt_list[idx][1] + cur_item = [] + cur_media_feature = None + for prompt_dict in prompt_dict_list: + if 'prompt' not in prompt_dict: + prompt_dict['prompt'] = prompt_dict['text'] if 'text' in prompt_dict else prompt_dict['question'] + if self.insert_media_token == 'left': + if self.add_media_sep: + prompt_dict['prompt'] = self.media_token + " \n" + prompt_dict['prompt'] + else: + prompt_dict['prompt'] = self.media_token + prompt_dict['prompt'] + elif self.insert_media_token == 'right': + if self.add_media_sep: + prompt_dict['prompt'] = prompt_dict['prompt'] + self.media_token + " \n" + else: + prompt_dict['prompt'] = prompt_dict['prompt'] + self.media_token + if 'image' in prompt_dict: + prompt_dict['image_path'] = prompt_dict['image'] + image_path = os.path.join(self.media_base_path, prompt_dict['image']) + if cur_media_feature is None: + cur_media_feature = ("image", self.image_processor(image_path)) + if 'video' in prompt_dict: + prompt_dict['video_path'] = prompt_dict['video'] + video_path = os.path.join(self.media_base_path, prompt_dict['video']) + if cur_media_feature is None: + cur_media_feature = ("video", self.video_processor(video_path)) + cur_item.append(prompt_dict) + return cur_media_feature, cur_item + - def __getitem__(self, idx): - return self.sentences[idx] +def collate_function(batch): + # do nothing + return batch + + +def do_inference(dataloader, model, length_params, sampling_params, cfg): + responses = [] + all_prompts = [] + for idx, batch_media_prompts in enumerate(dataloader): + if idx % 10 == 0: + print(f"Processed {idx} batch media") + for media_media_feature, prompts in batch_media_prompts: + media, media_feature = media_media_feature + all_prompts.extend(prompts.copy()) + for prompt in prompts: + prompt[media] = media_feature + cur_batch_responses = model.generate( + input_prompts=prompts, + length_params=length_params, + sampling_params=sampling_params, + inference_config=cfg, + ) + responses.extend(cur_batch_responses) + return responses, all_prompts @hydra_runner(config_path="conf", config_name="neva_inference") def main(cfg) -> None: model, image_processor, video_processor = create_neva_model_and_processor(cfg) - length_params: LengthParam = { "max_length": cfg.inference.tokens_to_generate, "min_length": cfg.inference.min_tokens_to_generate, @@ -71,35 +156,43 @@ def main(cfg) -> None: "end_strings": cfg.inference.end_strings, } - with open(cfg.prompt_file, 'r') as f: - lines = f.readlines() + prompt_dicts = [] + if cfg.prompt_file.endswith('.json'): + with open(cfg.prompt_file, 'r') as f: + prompt_dicts = json.load(f) + elif cfg.prompt_file.endswith('.jsonl'): + with open(cfg.prompt_file, 'r') as f: + lines = f.readlines() + for line in lines: + prompt_dicts.append(json.loads(line)) + else: + raise ValueError(f"Unsupported prompt file format: {cfg.prompt_file}") 
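For clarity, the grouping that `TemporalNevaDataset.group_by_media` performs can be sketched in isolation: prompts that reference the same media file are collected together so each video or image is preprocessed only once before its whole group of prompts is sent to `model.generate`. The helper below is an illustrative standalone sketch and is not part of the change:

from collections import defaultdict

def group_prompts_by_media(prompt_dicts, media_key="video"):
    # mirror of group_by_media: one entry per media file, holding all of its prompts
    grouped = defaultdict(list)
    for prompt_dict in prompt_dicts:
        grouped[prompt_dict[media_key]].append(prompt_dict)
    return list(grouped.items())

prompts = [
    {"video": "a.mp4", "question": "What happens first?"},
    {"video": "b.mp4", "question": "Who appears?"},
    {"video": "a.mp4", "question": "When does the action end?"},
]
for media_name, media_prompts in group_prompts_by_media(prompts):
    # the real dataset would call video_processor(...) once per media_name here
    print(media_name, len(media_prompts))  # a.mp4 -> 2 prompts, b.mp4 -> 1 prompt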
media_type_token = cfg.inference.get("media_type", "image") media_token = f"<{media_type_token}>" insert_media_token = cfg.inference.get("insert_media_token", None) - final_prompts = [] - for line in lines: - prompt_dict = json.loads(line) - assert 'prompt' in prompt_dict or 'text' in prompt_dict - if 'prompt' not in prompt_dict: - prompt_dict['prompt'] = prompt_dict['text'] - if insert_media_token == 'left': - prompt_dict['prompt'] = media_token + prompt_dict['prompt'] - elif insert_media_token == 'right': - prompt_dict['prompt'] = prompt_dict['prompt'] + media_token - if 'image' in prompt_dict: - prompt_dict['image_path'] = prompt_dict['image'] - prompt_dict['image'] = image_processor(os.path.join(cfg.inference.media_base_path, prompt_dict['image'])) - if 'video' in prompt_dict: - prompt_dict['video_path'] = prompt_dict['video'] - prompt_dict['video'] = video_processor(os.path.join(cfg.inference.media_base_path, prompt_dict['video'])) - final_prompts.append(prompt_dict) - - responses = model.generate( - input_prompts=final_prompts, length_params=length_params, sampling_params=sampling_params, inference_config=cfg + dataset = TemporalNevaDataset( + prompt_dicts, + cfg.inference.media_base_path, + media_token, + insert_media_token, + image_processor, + video_processor, + cfg.get("add_media_sep", False), ) + num_workers = 2 + dataloader = DataLoader( + dataset, + batch_size=cfg.inference.get("batch_size", 1), + shuffle=False, + collate_fn=collate_function, + num_workers=num_workers, + persistent_workers=True, + ) + responses, final_prompts = do_inference(dataloader, model, length_params, sampling_params, cfg) + # =================== Start Quantization ==================== if HAVE_MODELOPT and cfg.quantization.enable == True: print(f"Using quantization algorithm: {cfg.quantization.algorithm}") @@ -113,21 +206,33 @@ def main(cfg) -> None: raise ValueError(f"Unsupported quantization algorithm: {cfg.quantization.algorithm}") def forward_loop(): - model.generate( - input_prompts=final_prompts, - length_params=length_params, - sampling_params=sampling_params, - inference_config=cfg, + num_samples = cfg.quantization.get("num_samples", 100) + if num_samples == -1: + cur_prompt_dicts = prompt_dicts + else: + cur_prompt_dicts = prompt_dicts[:num_samples] + cur_dataset = TemporalNevaDataset( + cur_prompt_dicts, + cfg.inference.media_base_path, + media_token, + insert_media_token, + image_processor, + video_processor, + cfg.get("add_media_sep", False), ) + cur_dataloader = DataLoader( + cur_dataset, + batch_size=cfg.inference.get("batch_size", 1), + shuffle=False, + collate_fn=collate_function, + num_workers=num_workers, + ) + _, _ = do_inference(cur_dataloader, model, length_params, sampling_params, cfg) mtq.quantize(model, mtq_config, forward_loop) - responses = model.generate( - input_prompts=final_prompts, - length_params=length_params, - sampling_params=sampling_params, - inference_config=cfg, - ) + responses, final_prompts = do_inference(dataloader, model, length_params, sampling_params, cfg) + # ============== Quantization End ========================= # PP middle stages do not yield any responses @@ -138,7 +243,7 @@ def forward_loop(): results = [] for response, prompt in zip(responses, final_prompts): prompt['full_text'] = response["clean_text"] - prompt['text'] = response["clean_response"] + prompt['pred_answer'] = response["clean_response"] prompt['model_id'] = cfg.neva_model_file if 'image_path' in prompt: prompt['image'] = prompt.pop('image_path') @@ -151,8 +256,11 @@ def forward_loop(): 
results.append(prompt) with open(cfg.output_file, 'w') as f: - for result in results: - f.write(json.dumps(result) + '\n') + if cfg.output_file.endswith('.json'): + json.dump(results, f, indent=2) + else: + for result in results: + f.write(json.dumps(result) + '\n') if __name__ == '__main__': diff --git a/examples/multimodal/multimodal_llm/neva/neva_export.py b/examples/multimodal/multimodal_llm/neva/neva_export.py index 2c081d00a003..6cf44084a564 100644 --- a/examples/multimodal/multimodal_llm/neva/neva_export.py +++ b/examples/multimodal/multimodal_llm/neva/neva_export.py @@ -27,6 +27,7 @@ def main(cfg): tensor_parallel_size=cfg.infer.tensor_parallelism, max_input_len=cfg.infer.max_input_len, max_output_len=cfg.infer.max_output_len, + vision_max_batch_size=cfg.infer.vision_max_batch_size, max_batch_size=cfg.infer.max_batch_size, max_multimodal_len=cfg.infer.max_multimodal_len, dtype=cfg.model.precision, diff --git a/examples/multimodal/text_to_image/stable_diffusion/conf/sd2_train.yaml b/examples/multimodal/text_to_image/stable_diffusion/conf/sd2_train.yaml index b725b15f1ab2..d7b5ed717fc4 100644 --- a/examples/multimodal/text_to_image/stable_diffusion/conf/sd2_train.yaml +++ b/examples/multimodal/text_to_image/stable_diffusion/conf/sd2_train.yaml @@ -153,6 +153,7 @@ model: resume_from_checkpoint: null # manually set the checkpoint file to load from apex_transformer_log_level: 30 # Python logging level displays logs with severity greater than or equal to this gradient_as_bucket_view: True # PyTorch DDP argument. Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory) + ddp_overlap: False # True for using PyTorch DDP overlap. optim: name: fused_adam diff --git a/examples/multimodal/text_to_image/stable_diffusion/conf/sd_train.yaml b/examples/multimodal/text_to_image/stable_diffusion/conf/sd_train.yaml index dff963590864..da03a1de96cf 100644 --- a/examples/multimodal/text_to_image/stable_diffusion/conf/sd_train.yaml +++ b/examples/multimodal/text_to_image/stable_diffusion/conf/sd_train.yaml @@ -17,7 +17,6 @@ trainer: enable_model_summary: True limit_val_batches: 0 - exp_manager: exp_dir: null name: ${name} diff --git a/examples/multimodal/text_to_image/stable_diffusion/conf/sd_xl_base.yaml b/examples/multimodal/text_to_image/stable_diffusion/conf/sd_xl_base.yaml index c536bae15926..7e83093eb780 100644 --- a/examples/multimodal/text_to_image/stable_diffusion/conf/sd_xl_base.yaml +++ b/examples/multimodal/text_to_image/stable_diffusion/conf/sd_xl_base.yaml @@ -58,8 +58,6 @@ model: lossconfig: target: torch.nn.Identity - - conditioner_config: _target_: nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.GeneralConditioner emb_models: diff --git a/examples/multimodal/text_to_image/stable_diffusion/conf/sd_xl_base_train.yaml b/examples/multimodal/text_to_image/stable_diffusion/conf/sd_xl_base_train.yaml index 7aa765db2e5f..aa1d2782d15b 100644 --- a/examples/multimodal/text_to_image/stable_diffusion/conf/sd_xl_base_train.yaml +++ b/examples/multimodal/text_to_image/stable_diffusion/conf/sd_xl_base_train.yaml @@ -125,7 +125,6 @@ model: target: torch.nn.Identity - conditioner_config: _target_: nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.GeneralConditioner emb_models: diff --git a/examples/multimodal/text_to_image/stable_diffusion/conf/sd_xl_infer.yaml b/examples/multimodal/text_to_image/stable_diffusion/conf/sd_xl_infer.yaml index eb1f6d7ccb8e..632f1634af50 100644 --- 
a/examples/multimodal/text_to_image/stable_diffusion/conf/sd_xl_infer.yaml +++ b/examples/multimodal/text_to_image/stable_diffusion/conf/sd_xl_infer.yaml @@ -31,9 +31,9 @@ infer: sampling: base: sampler: EulerEDMSampler - width: 256 - height: 256 - steps: 40 + width: 512 + height: 512 + steps: 50 discretization: "LegacyDDPMDiscretization" guider: "VanillaCFG" thresholder: "None" @@ -48,8 +48,8 @@ sampling: s_noise: 1.0 eta: 1.0 order: 4 - orig_width: 1024 - orig_height: 1024 + orig_width: 512 + orig_height: 512 crop_coords_top: 0 crop_coords_left: 0 aesthetic_score: 5.0 diff --git a/examples/multimodal/text_to_image/stable_diffusion/conf/sd_xl_infer_v2.yaml b/examples/multimodal/text_to_image/stable_diffusion/conf/sd_xl_infer_v2.yaml new file mode 100644 index 000000000000..9dc838dcc5c5 --- /dev/null +++ b/examples/multimodal/text_to_image/stable_diffusion/conf/sd_xl_infer_v2.yaml @@ -0,0 +1,189 @@ +trainer: + devices: 1 + num_nodes: 1 + accelerator: gpu + precision: 32 + logger: False # logger provided by exp_manager + enable_checkpointing: False + use_distributed_sampler: False + max_epochs: -1 # PTL default. In practice, max_steps will be reached first. + max_steps: -1 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches + log_every_n_steps: 10 + accumulate_grad_batches: 1 # do not modify, grad acc is automatic for training megatron models + gradient_clip_val: 1.0 + benchmark: False + enable_model_summary: True + limit_val_batches: 0 + + +infer: + num_samples_per_batch: 1 + num_samples: 4 + prompt: + - "A professional photograph of an astronaut riding a pig" + - 'A photo of a Shiba Inu dog with a backpack riding a bike. It is wearing sunglasses and a beach hat.' + - 'A cute corgi lives in a house made out of sushi.' + - 'A high contrast portrait of a very happy fuzzy panda dressed as a chef in a high end kitchen making dough. There is a painting of flowers on the wall behind him.' + - 'A brain riding a rocketship heading towards the moon.' 
+ negative_prompt: "" + seed: 123 + + +sampling: + base: + sampler: EulerEDMSampler + width: 512 + height: 512 + steps: 50 + discretization: "LegacyDDPMDiscretization" + guider: "VanillaCFG" + thresholder: "None" + scale: 5.0 + img2img_strength: 1.0 + sigma_min: 0.0292 + sigma_max: 14.6146 + rho: 3.0 + s_churn: 0.0 + s_tmin: 0.0 + s_tmax: 999.0 + s_noise: 1.0 + eta: 1.0 + order: 4 + orig_width: 512 + orig_height: 512 + crop_coords_top: 0 + crop_coords_left: 0 + aesthetic_score: 5.0 + negative_aesthetic_score: 5.0 + +# model: +# is_legacy: False + +use_refiner: False +use_fp16: False # use fp16 model weights +out_path: ./output + +base_model_config: /opt/NeMo/examples/multimodal/generative/stable_diffusion/conf/sd_xl_base.yaml +refiner_config: /opt/NeMo/examples/multimodal/generative/stable_diffusion/conf/sd_xl_refiner.yaml + +model: + scale_factor: 0.13025 + disable_first_stage_autocast: True + is_legacy: False + restore_from_path: "" + + fsdp: False + fsdp_set_buffer_dtype: null + fsdp_sharding_strategy: 'full' + use_cpu_initialization: True + # hidden_size: 4 + # pipeline_model_parallel_size: 4 + + optim: + name: fused_adam + lr: 1e-4 + weight_decay: 0.0 + betas: + - 0.9 + - 0.999 + sched: + name: WarmupHoldPolicy + warmup_steps: 10 + hold_steps: 10000000000000 # Incredibly large value to hold the lr as constant + + denoiser_config: + _target_: nemo.collections.multimodal.modules.stable_diffusion.diffusionmodules.denoiser.DiscreteDenoiser + num_idx: 1000 + + weighting_config: + _target_: nemo.collections.multimodal.modules.stable_diffusion.diffusionmodules.denoiser_weighting.EpsWeighting + scaling_config: + _target_: nemo.collections.multimodal.modules.stable_diffusion.diffusionmodules.denoiser_scaling.EpsScaling + discretization_config: + _target_: nemo.collections.multimodal.modules.stable_diffusion.diffusionmodules.discretizer.LegacyDDPMDiscretization + + unet_config: + _target_: nemo.collections.multimodal.modules.stable_diffusion.diffusionmodules.openaimodel.UNetModel + from_pretrained: /opt/nemo-aligner/checkpoints/sdxl/unet_nemo.ckpt + from_NeMo: True + adm_in_channels: 2816 + num_classes: sequential + use_checkpoint: False + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [ 4, 2 ] + num_res_blocks: 2 + channel_mult: [ 1, 2, 4 ] + num_head_channels: 64 + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: [ 1, 2, 10 ] # note: the first is unused (due to attn_res starting at 2) 32, 16, 8 --> 64, 32, 16 + context_dim: 2048 + image_size: 64 # unused +# spatial_transformer_attn_type: softmax #note: only default softmax is supported now + legacy: False + use_flash_attention: False + + first_stage_config: + # _target_: nemo.collections.multimodal.models.stable_diffusion.ldm.autoencoder.AutoencoderKLInferenceWrapper + _target_: nemo.collections.multimodal.models.text_to_image.stable_diffusion.ldm.autoencoder.AutoencoderKLInferenceWrapper + from_pretrained: /opt/nemo-aligner/checkpoints/sdxl/vae_nemo.ckpt + from_NeMo: True + embed_dim: 4 + monitor: val/rec_loss + ddconfig: + attn_type: vanilla + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: [ 1, 2, 4, 4 ] + num_res_blocks: 2 + attn_resolutions: [ ] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity + + conditioner_config: + _target_: nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.GeneralConditioner + emb_models: + # crossattn cond + - is_trainable: False + input_key: txt + emb_model: + 
_target_: nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.FrozenCLIPEmbedder + layer: hidden + layer_idx: 11 + # crossattn and vector cond + - is_trainable: False + input_key: txt + emb_model: + _target_: nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.FrozenOpenCLIPEmbedder2 + arch: ViT-bigG-14 + version: laion2b_s39b_b160k + freeze: True + layer: penultimate + always_return_pooled: True + legacy: False + # vector cond + - is_trainable: False + input_key: original_size_as_tuple + emb_model: + _target_: nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.ConcatTimestepEmbedderND + outdim: 256 # multiplied by two + # vector cond + - is_trainable: False + input_key: crop_coords_top_left + emb_model: + _target_: nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.ConcatTimestepEmbedderND + outdim: 256 # multiplied by two + # vector cond + - is_trainable: False + input_key: target_size_as_tuple + emb_model: + _target_: nemo.collections.multimodal.modules.stable_diffusion.encoders.modules.ConcatTimestepEmbedderND + outdim: 256 # multiplied by two + diff --git a/examples/multimodal/text_to_image/stable_diffusion/sd_train.py b/examples/multimodal/text_to_image/stable_diffusion/sd_train.py index 968d9bec2884..7e151699b38c 100644 --- a/examples/multimodal/text_to_image/stable_diffusion/sd_train.py +++ b/examples/multimodal/text_to_image/stable_diffusion/sd_train.py @@ -74,7 +74,11 @@ def main(cfg) -> None: n, c, h = cfg.model.micro_batch_size, cfg.model.channels, cfg.model.image_size x = torch.randn((n, c, h, h), dtype=torch.float32, device="cuda") t = torch.randint(77, (n,), device="cuda") - cc = torch.randn((n, 77, cfg.model.unet_config.context_dim), dtype=torch.float32, device="cuda",) + cc = torch.randn( + (n, 77, cfg.model.unet_config.context_dim), + dtype=torch.float32, + device="cuda", + ) if cfg.model.precision in [16, '16']: x = x.type(torch.float16) cc = cc.type(torch.float16) @@ -93,9 +97,7 @@ def main(cfg) -> None: model.zero_grad() if cfg.model.get('peft', None): - peft_cfg_cls = PEFT_CONFIG_MAP[cfg.model.peft.peft_scheme] - if cfg.model.peft.restore_from_path is not None: # initialize peft weights from a checkpoint instead of randomly # This is not the same as resume training because optimizer states are not restored. 
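To make the new `infer.num_samples` and `infer.num_samples_per_batch` fields concrete, here is a minimal sketch, assuming the config above is saved locally as `sd_xl_infer_v2.yaml`, of how they translate into batched sampling calls with reproducible per-batch seeds; the 100/200 seed offsets mirror the updated `sd_xl_infer.py` shown below:

from omegaconf import OmegaConf

cfg = OmegaConf.load("sd_xl_infer_v2.yaml")
num_per_batch = cfg.infer.get("num_samples_per_batch", cfg.infer.num_samples)
num_batches = cfg.infer.num_samples // num_per_batch

for i, prompt in enumerate(cfg.infer.prompt):
    for batch_id in range(num_batches):
        # each (prompt, batch) pair gets a distinct but reproducible seed
        seed = int(cfg.infer.seed + i * 100 + batch_id * 200)
        # base.text_to_image(params=cfg.sampling.base, prompt=[prompt],
        #                    samples=num_per_batch, seed=seed, ...) would run here
        print(i, batch_id, seed)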
diff --git a/examples/multimodal/text_to_image/stable_diffusion/sd_xl_infer.py b/examples/multimodal/text_to_image/stable_diffusion/sd_xl_infer.py index 8d18be517c69..de66db1725c4 100644 --- a/examples/multimodal/text_to_image/stable_diffusion/sd_xl_infer.py +++ b/examples/multimodal/text_to_image/stable_diffusion/sd_xl_infer.py @@ -26,32 +26,44 @@ def model_cfg_modifier(model_cfg): model_cfg.precision = cfg.trainer.precision model_cfg.ckpt_path = None model_cfg.inductor = False - model_cfg.unet_config.from_pretrained = None - model_cfg.first_stage_config.from_pretrained = None + # model_cfg.unet_config.from_pretrained = "/opt/nemo-aligner/checkpoints/sdxl/unet_nemo.ckpt" + # model_cfg.unet_config.from_NeMo = True + # model_cfg.first_stage_config.from_pretrained = "/opt/nemo-aligner/checkpoints/sdxl/vae_nemo.ckpt" + # model_cfg.first_stage_config.from_NeMo = True model_cfg.first_stage_config._target_ = 'nemo.collections.multimodal.models.text_to_image.stable_diffusion.ldm.autoencoder.AutoencoderKLInferenceWrapper' - model_cfg.fsdp = False + # model_cfg.fsdp = True torch.backends.cuda.matmul.allow_tf32 = True trainer, megatron_diffusion_model = setup_trainer_and_model_for_inference( model_provider=MegatronDiffusionEngine, cfg=cfg, model_cfg_modifier=model_cfg_modifier ) + ### Manually configure sharded model + # model = megatron_diffusion_model + # model = trainer.strategy._setup_model(model) + # model = model.cuda(torch.cuda.current_device()) + # get the diffusion part only model = megatron_diffusion_model.model model.cuda().eval() - base = SamplingPipeline(model, use_fp16=cfg.use_fp16, is_legacy=cfg.model.is_legacy) - use_refiner = cfg.get('use_refiner', False) - for i, prompt in enumerate(cfg.infer.prompt): - samples = base.text_to_image( - params=cfg.sampling.base, - prompt=[prompt], - negative_prompt=cfg.infer.negative_prompt, - samples=cfg.infer.num_samples, - return_latents=True if use_refiner else False, - seed=int(cfg.infer.seed + i * 100), - ) - - perform_save_locally(cfg.out_path, samples) + with torch.no_grad(): + base = SamplingPipeline(model, use_fp16=cfg.use_fp16, is_legacy=cfg.model.is_legacy) + use_refiner = cfg.get('use_refiner', False) + num_samples_per_batch = cfg.infer.get('num_samples_per_batch', cfg.infer.num_samples) + num_batches = cfg.infer.num_samples // num_samples_per_batch + + for i, prompt in enumerate(cfg.infer.prompt): + for batchid in range(num_batches): + samples = base.text_to_image( + params=cfg.sampling.base, + prompt=[prompt], + negative_prompt=cfg.infer.negative_prompt, + samples=num_samples_per_batch, + return_latents=True if use_refiner else False, + seed=int(cfg.infer.seed + i * 100 + batchid * 200), + ) + # samples=cfg.infer.num_samples, + perform_save_locally(cfg.out_path, samples) if __name__ == "__main__": diff --git a/examples/multimodal/text_to_image/stable_diffusion/sd_xl_train.py b/examples/multimodal/text_to_image/stable_diffusion/sd_xl_train.py index a91beca93761..44412aee0d14 100644 --- a/examples/multimodal/text_to_image/stable_diffusion/sd_xl_train.py +++ b/examples/multimodal/text_to_image/stable_diffusion/sd_xl_train.py @@ -41,7 +41,10 @@ def _training_strategy(self) -> NLPDDPStrategy: _IS_INTERACTIVE = hasattr(sys, "ps1") or bool(sys.flags.interactive) if _IS_INTERACTIVE and self.cfg.trainer.devices == 1: logging.info("Detected interactive environment, using NLPDDPStrategyNotebook") - return NLPDDPStrategyNotebook(no_ddp_communication_hook=True, find_unused_parameters=False,) + return NLPDDPStrategyNotebook( + 
no_ddp_communication_hook=True, + find_unused_parameters=False, + ) if self.cfg.model.get('fsdp', False): assert ( @@ -81,9 +84,7 @@ def main(cfg) -> None: model = MegatronDiffusionEngine(cfg.model, trainer) if cfg.model.get('peft', None): - peft_cfg_cls = PEFT_CONFIG_MAP[cfg.model.peft.peft_scheme] - if cfg.model.peft.restore_from_path is not None: # initialize peft weights from a checkpoint instead of randomly # This is not the same as resume training because optimizer states are not restored. diff --git a/examples/multimodal/vision_language_foundation/clip/conf/megatron_siglip_config.yaml b/examples/multimodal/vision_language_foundation/clip/conf/megatron_siglip_config.yaml new file mode 100644 index 000000000000..59f21813ce01 --- /dev/null +++ b/examples/multimodal/vision_language_foundation/clip/conf/megatron_siglip_config.yaml @@ -0,0 +1,253 @@ +name: megatron_siglip +restore_from_path: null # used when starting from a .nemo file + +trainer: + devices: 1 + num_nodes: 1 + accelerator: gpu + precision: bf16 + logger: False # logger provided by exp_manager + enable_checkpointing: False + use_distributed_sampler: False + max_epochs: -1 # PTL default. In practice, max_steps will be reached first. + max_steps: 375000 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches + log_every_n_steps: 10 + val_check_interval: 100 + check_val_every_n_epoch: null + limit_val_batches: 50 + limit_test_batches: 500 + accumulate_grad_batches: 1 # do not modify, grad acc is automatic for training megatron models + gradient_clip_val: 1.0 + benchmark: False + enable_model_summary: False # default PTL callback for this does not support model parallelism, instead we log manually + +exp_manager: + explicit_log_dir: null + exp_dir: null + name: megatron_siglip + create_wandb_logger: False + wandb_logger_kwargs: + project: null + name: null + resume_if_exists: True + resume_ignore_no_checkpoint: True + resume_from_checkpoint: ${model.resume_from_checkpoint} + create_checkpoint_callback: True + checkpoint_callback_params: + monitor: val_loss + save_top_k: 10 + mode: min + always_save_nemo: False # saves nemo file during validation, not implemented for model parallel + save_nemo_on_train_end: False # not recommended when training large models on clusters with short time limits + filename: 'megatron_siglip--{val_loss:.2f}-{step}-{consumed_samples}' + model_parallel_size: ${multiply:${model.tensor_model_parallel_size}, ${model.pipeline_model_parallel_size}} + ema: + enable: False + decay: 0.9999 + validate_original_weights: False + every_n_steps: 1 + cpu_offload: False + +model: + precision: ${trainer.precision} + # specify micro_batch_size, global_batch_size, and model parallelism + # gradient accumulation will be done automatically based on data_parallel_size + micro_batch_size: 32 # limited by GPU memory + global_batch_size: 32 # will use more micro batches to reach global batch size + tensor_model_parallel_size: 1 # intra-layer model parallelism + pipeline_model_parallel_size: 1 # inter-layer model parallelism + virtual_pipeline_model_parallel_size: null # interleaved pipeline + + restore_from_path: null # used in fine-tuning + # multimodal configs + output_dim: 1152 + # As the number of devices used to train increases, so does the space complexity of + # the logit matrix. Using a naïve all-gather scheme, space complexity will be + # `O(n^2)`. Instead, complexity may become effectively linear if the flags + # `--gather-with-grad` and `--local-loss` are used. 
This alteration results in one-to-one + # numerical results as the naïve method. + + use_siglip: True + mcore_gpt: True + transformer_engine: True + + vision: + precision: ${trainer.precision} + # vision configs + patch_dim: 14 + img_h: 378 + img_w: 378 + image_mean: null + image_std: null + num_channels: 3 + drop_patch_rate: 0.0 + drop_path_rate: 0.0 + global_average_pool: False + output_dim: ${model.output_dim} + class_token_length: 0 + preprocess_layernorm: True # apply layer norm to embedded tokens + + # model architecture + encoder_seq_length: 196 + max_position_embeddings: ${.encoder_seq_length} + position_embedding_type: learned_absolute + num_layers: 27 + hidden_size: 1152 + ffn_hidden_size: 4304 # Transformer FFN hidden size. Usually 4 * hidden_size. + num_attention_heads: 16 + init_method_std: 0.02 # Standard deviation of the zero mean normal distribution used for weight initialization.') + use_scaled_init_method: True # use scaled residuals initialization + hidden_dropout: 0. # Dropout probability for hidden state transformer. + attention_dropout: 0. + kv_channels: null # Projection weights dimension in multi-head attention. Set to hidden_size // num_attention_heads if null + apply_query_key_layer_scaling: True # scale Q * K^T by 1 / layer-number. + normalization: layernorm # Type of normalization layers + layernorm_epsilon: 1e-5 + do_layer_norm_weight_decay: False # True means weight decay on all params + pre_process: True # add embedding + post_process: True # add pooler + persist_layer_norm: True # Use of persistent fused layer norm kernel. + + ## Activation Checkpointing + activations_checkpoint_granularity: null # 'selective' or 'full' + activations_checkpoint_method: null # 'uniform', 'block', not used with 'selective' + activations_checkpoint_num_layers: null # not used with 'selective' + sequence_parallel: False + + # precision + native_amp_init_scale: 4294967296 # 2 ** 32 + native_amp_growth_interval: 1000 + hysteresis: 2 # Gradient scale hysteresis + fp32_residual_connection: False # Move residual connections to fp32 + fp16_lm_cross_entropy: False # Move the cross entropy unreduced loss calculation for lm head to fp16 + + # model fusions + masked_softmax_fusion: True # Use a kernel that fuses the attention softmax with it's mask. + bias_dropout_add_fusion: True # Use a kernel that fuses the bias addition, dropout and residual connection addition. + + use_cpu_initialization: False # Init weights on the CPU (slow for large models) + onnx_safe: False # Use work-arounds for known problems with Torch ONNX exporter. + gradient_accumulation_fusion: False # Fuse weight gradient accumulation to GEMMs. Only used with pipeline parallelism. + openai_gelu: True + bias_activation_fusion: False + activation: approx-gelu + megatron_legacy: False + + + text: + precision: ${trainer.precision} + # text configs + output_dim: ${model.output_dim} + + # model architecture + encoder_seq_length: 64 + max_position_embeddings: ${.encoder_seq_length} + position_embedding_type: learned_absolute + num_layers: 27 + hidden_size: 1152 + ffn_hidden_size: 4304 # Transformer FFN hidden size. Usually 4 * hidden_size. + num_attention_heads: 16 + init_method_std: 0.02 # Standard deviation of the zero mean normal distribution used for weight initialization.') + use_scaled_init_method: True # use scaled residuals initialization + hidden_dropout: 0. # Dropout probability for hidden state transformer. + attention_dropout: 0. + kv_channels: null # Projection weights dimension in multi-head attention. 
Set to hidden_size // num_attention_heads if null + apply_query_key_layer_scaling: True # scale Q * K^T by 1 / layer-number. + normalization: layernorm # Type of normalization layers + layernorm_epsilon: 1e-5 + do_layer_norm_weight_decay: False # True means weight decay on all params + pre_process: True # add embedding + post_process: True # add pooler + persist_layer_norm: True # Use of persistent fused layer norm kernel. + + ## Activation Checkpointing + activations_checkpoint_granularity: null # 'selective' or 'full' + activations_checkpoint_method: null # 'uniform', 'block', not used with 'selective' + activations_checkpoint_num_layers: null # not used with 'selective' + num_micro_batches_with_partial_activation_checkpoints: null + activations_checkpoint_layers_per_pipeline: null + sequence_parallel: False + + # precision + native_amp_init_scale: 4294967296 # 2 ** 32 + native_amp_growth_interval: 1000 + hysteresis: 2 # Gradient scale hysteresis + fp32_residual_connection: False # Move residual connections to fp32 + fp16_lm_cross_entropy: False # Move the cross entropy unreduced loss calculation for lm head to fp16 + + # model fusions + masked_softmax_fusion: True # Use a kernel that fuses the attention softmax with it's mask. + bias_dropout_add_fusion: True # Use a kernel that fuses the bias addition, dropout and residual connection addition. + + use_cpu_initialization: False # Init weights on the CPU (slow for large models) + onnx_safe: False # Use work-arounds for known problems with Torch ONNX exporter. + gradient_accumulation_fusion: False # Fuse weight gradient accumulation to GEMMs. Only used with pipeline parallelism. + openai_gelu: True + bias_activation_fusion: False + megatron_legacy: False + + fp8: False # enables fp8 in TransformerLayer forward + fp8_e4m3: False # sets fp8_format = recipe.Format.E4M3 + fp8_hybrid: False # sets fp8_format = recipe.Format.HYBRID + fp8_margin: 0 # scaling margin + fp8_interval: 1 # scaling update interval + fp8_amax_history_len: 1 # Number of steps for which amax history is recorded per tensor + fp8_amax_compute_algo: most_recent # 'most_recent' or 'max'. Algorithm for computing amax from history + use_emha: False # Use fused multi-head attention for large sequence-length. Note this is not yet supported. Please set to False. + activation: approx-gelu + + # Megatron O2-style half-precision + megatron_amp_O2: True # Enable O2-level automatic mixed precision using main parameters + grad_allreduce_chunk_size_mb: 125 + grad_div_ar_fusion: True # Fuse grad division into torch.distributed.all_reduce + + # miscellaneous + seed: 1234 + resume_from_checkpoint: null # manually set the checkpoint file to load from + apex_transformer_log_level: 30 # Python logging level displays logs with severity greater than or equal to this + gradient_as_bucket_view: True # PyTorch DDP argument. Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory) + + tokenizer: + library: 'huggingface' + type: 'google/siglip-so400m-patch14-384' + model: null + vocab_file: null + merge_file: null + delimiter: null # only used for tabular tokenizer + sentencepiece_legacy: False # Legacy=True allows you to add special tokens to sentencepiece tokenizers. + make_vocab_size_divisible_by: 128 # Pad the vocab size to be divisible by this value for computation efficiency. 
+ + data: + num_workers: 8 + train: + dataset_path: # List of paths to pkl files or tar files + - /datasets/coyo/test.pkl + validation: # List of paths to pkl files or tar files + dataset_path: + - /datasets/coyo/test.pkl + webdataset: + infinite_sampler: False + local_root_path: /datasets/coyo + + imagenet_val: null # Path to imagenet val set for conducting zero shot evaluation. + + # Nsys profiling options + nsys_profile: + enabled: False + start_step: 10 # Global batch to start profiling + end_step: 10 # Global batch to end profiling + ranks: [ 0 ] # Global rank IDs to profile + gen_shape: False # Generate model and kernel details including input shapes + + optim: + name: fused_adam + lr: 1e-3 + weight_decay: 0.2 + betas: + - 0.9 + - 0.98 + sched: + name: CosineAnnealing + warmup_steps: 2000 + constant_steps: 0 + min_lr: 1e-5 \ No newline at end of file diff --git a/examples/multimodal/vision_language_foundation/clip/convert_external_clip_to_nemo.py b/examples/multimodal/vision_language_foundation/clip/convert_external_clip_to_nemo.py index 9af25181d07e..178140aac828 100644 --- a/examples/multimodal/vision_language_foundation/clip/convert_external_clip_to_nemo.py +++ b/examples/multimodal/vision_language_foundation/clip/convert_external_clip_to_nemo.py @@ -283,7 +283,10 @@ def convert(local_rank, rank, world_size, args): if __name__ == '__main__': - logging.warning("This script is going to be deprecated soon. Please use ") + logging.warning( + "This script is going to be deprecated soon. Please use " + "`scripts/checkpoint_converters/convert_clip_hf_to_nemo.py`" + ) args = get_args() local_rank, rank, world_size = initialize_distributed(args) convert(local_rank, rank, world_size, args) diff --git a/examples/nlp/information_retrieval/conf/megatron_bert_embedding_config.yaml b/examples/nlp/information_retrieval/conf/megatron_bert_embedding_config.yaml index 0b57313fb0a0..7e4ecf09f5a0 100644 --- a/examples/nlp/information_retrieval/conf/megatron_bert_embedding_config.yaml +++ b/examples/nlp/information_retrieval/conf/megatron_bert_embedding_config.yaml @@ -77,6 +77,11 @@ model: vocab_file: null merge_file: null + # embedding-specific arguemnts + softmax_temp: 0.02 # softmax temp for contrastive loss + global_inbatch_negatives: True # whether to use in-batch negatives from other ranks during training + backprop_type: 'global' # whether to use `global` or `local` backpropagation during training. Refer to Flava paper for details. + # precision native_amp_init_scale: 4294967296 # 2 ** 32 native_amp_growth_interval: 1000 @@ -93,7 +98,7 @@ model: use_cpu_initialization: False # Init weights on the CPU (slow for large models) onnx_safe: False # Use work-arounds for known problems with Torch ONNX exporter. gradient_as_bucket_view: True # PyTorch DDP argument. Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory) - + ## Activation Checkpointing # NeMo Megatron supports 'selective' activation checkpointing where only the memory intensive part of attention is checkpointed. # These memory intensive activations are also less compute intensive which makes activation checkpointing more efficient for LLMs (20B+). @@ -127,7 +132,7 @@ model: # Path to data must be specified by the user. 
data_train: null data_validation: null - hard_negatives_to_train: 4 + hard_negatives_to_train: 4 # number of hard negatives to use per example for training index_mapping_dir: null # path to save index mapping .npy files, by default will save in the same location as data_prefix data_impl: mmap splits_string: 900,50,50 diff --git a/examples/nlp/information_retrieval/conf/megatron_gpt_embedder_generate_config.yaml b/examples/nlp/information_retrieval/conf/megatron_gpt_embedder_generate_config.yaml index 1a81d21dd9a8..e407aec167e9 100644 --- a/examples/nlp/information_retrieval/conf/megatron_gpt_embedder_generate_config.yaml +++ b/examples/nlp/information_retrieval/conf/megatron_gpt_embedder_generate_config.yaml @@ -120,7 +120,6 @@ model: tunable_base_param_names: ["self_attention", "word_embeddings"] # TODO: regex support @adithyre data: - return_output_tensors: True test_ds: query_file_names: ??? # Path to a list of JSONL files corresponding to the query data. Data format is identical to validation_ds. doc_file_names: ??? # Path to a list of JSONL files corresponding to the doc data. Data format is identical to validation_ds. diff --git a/examples/nlp/information_retrieval/conf/megatron_gpt_embedder_tuning_config.yaml b/examples/nlp/information_retrieval/conf/megatron_gpt_embedder_tuning_config.yaml index 6677dc2ed46c..1c2db1a862f4 100644 --- a/examples/nlp/information_retrieval/conf/megatron_gpt_embedder_tuning_config.yaml +++ b/examples/nlp/information_retrieval/conf/megatron_gpt_embedder_tuning_config.yaml @@ -84,6 +84,7 @@ model: use_flash_attention: True precision: bf16 apply_rope_fusion: False + reward_model_loss: False # Set this to true to perform RLHF style reward model loss -log(sigmoid(accept_logit - reject_logit)) peft: peft_scheme: "lora" # can be either adapter,ia3, or ptuning @@ -126,7 +127,6 @@ model: tunable_base_param_names: ["self_attention", "word_embeddings"] # TODO: regex support @adithyre data: - return_output_tensors: True train_ds: # Example of how to specify paths to multiple datasets # file_names: diff --git a/examples/nlp/information_retrieval/conf/megatron_gpt_reranker_tuning_config.yaml b/examples/nlp/information_retrieval/conf/megatron_gpt_reranker_tuning_config.yaml new file mode 100644 index 000000000000..863b5fb475a0 --- /dev/null +++ b/examples/nlp/information_retrieval/conf/megatron_gpt_reranker_tuning_config.yaml @@ -0,0 +1,222 @@ +name: megatron_gpt_peft_reranker_tuning + +trainer: + devices: 1 + accelerator: gpu + num_nodes: 1 + precision: bf16 + logger: False # logger provided by exp_manager + enable_checkpointing: False + use_distributed_sampler: False + max_epochs: null + max_steps: 20000 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches + log_every_n_steps: 10 # frequency with which training steps are logged + val_check_interval: ${trainer.max_steps} # If is an int n > 1, will run val every n training steps, if a float 0.0 - 1.0 will run val every epoch fraction, e.g. 
0.25 will run val every quarter epoch + gradient_clip_val: null + num_sanity_val_steps: 0 + +exp_manager: + explicit_log_dir: null + exp_dir: null + name: ${name} + create_wandb_logger: False + wandb_logger_kwargs: + project: null + name: null + resume_if_exists: True + resume_ignore_no_checkpoint: True + create_checkpoint_callback: True + checkpoint_callback_params: + monitor: validation_${model.data.validation_ds.metric.name} + save_top_k: 1 + mode: min + save_nemo_on_train_end: True + filename: '${name}--{${exp_manager.checkpoint_callback_params.monitor}:.3f}-{step}-{consumed_samples}' + model_parallel_size: ${model.tensor_model_parallel_size} + always_save_nemo: False + save_best_model: True + create_early_stopping_callback: False + early_stopping_callback_params: + monitor: "val_loss" + mode: "min" + min_delta: 0.001 + patience: 10 + verbose: True + strict: False # Should be False to avoid a runtime error where EarlyStopping says monitor is unavailable, which sometimes happens with resumed training. + +model: + seed: 1234 + tensor_model_parallel_size: 1 # intra-layer model parallelism + pipeline_model_parallel_size: 1 # inter-layer model parallelism + + global_batch_size: 128 + micro_batch_size: 4 + restore_from_path: ??? # Path to an existing .nemo model you wish to add new tasks to or run inference with + resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. + save_nemo_on_validation_end: False # Saves an inference ready .nemo file every time a checkpoint is saved during training. + sync_batch_comm: False + megatron_amp_O2: True + + ## Sequence Parallelism + # Makes tensor parallelism more memory efficient for LLMs (20B+) by parallelizing layer norms and dropout sequentially + # See Reducing Activation Recomputation in Large Transformer Models: https://arxiv.org/abs/2205.05198 for more details. + sequence_parallel: False + + ## Activation Checkpoint + activations_checkpoint_granularity: selective # 'selective' or 'full' + activations_checkpoint_method: uniform # 'uniform', 'block', not used with 'selective' + # 'uniform' divides the total number of transformer layers and checkpoints the input activation + # of each chunk at the specified granularity + # 'block' checkpoints the specified number of layers per pipeline stage at the specified granularity + activations_checkpoint_num_layers: null # not used with 'selective' + activations_checkpoint_layers_per_pipeline: null + gradient_as_bucket_view: False + + hidden_dropout: 0.0 + attention_dropout: 0.0 + ffn_dropout: 0.0 + temperature: 0.02 + num_soft_negatives: 0 # Number of soft negatives to use for contrastive loss,it should be max(batch_size - 1), 0 means use hard negatives only + use_all_possible_negatives: False # If True, use all possible negatives for contrastive loss, otherwise use num_soft_negatives, if num_soft_negatives is 0, use hard negatives only + post_process: False # should be False. 
+ apply_rope_fusion: False + transformer_engine: True # required to be True for newer versions of Megatron-LM based models + mcore_gpt: True # required to be True for newer versions of Megatron-LM based models + use_flash_attention: True + precision: bf16 + + peft: + peft_scheme: "mlp_head,lora" # can be either adapter,ia3, or ptuning + restore_from_path: null + + # Used for adapter peft training + adapter_tuning: + type: 'parallel_adapter' # this should be either 'parallel_adapter' or 'linear_adapter' + adapter_dim: 32 + adapter_dropout: 0.0 + norm_position: 'pre' # This can be set to 'pre', 'post' or null, 'pre' is normally what is used. + column_init_method: 'xavier' # IGNORED if linear_adapter is used, options: xavier, zero or normal + row_init_method: 'zero' # IGNORED if linear_adapter is used, options: xavier, zero or normal + norm_type: 'mixedfusedlayernorm' # IGNORED if layer_adapter is used, options are ['layernorm', 'mixedfusedlayernorm'] + layer_selection: null # selects in which layers to add adapters, e.g. [1,12] will add adapters to layer 1 (lowest) and 12. null will apply adapters to all layers + weight_tying: False + position_embedding_strategy: null # used only when weight_tying is True + + lora_tuning: + target_modules: ['attention_qkv', 'attention_dense', 'mlp_fc1', 'mlp_fc2'] # + adapter_dim: 32 + adapter_dropout: 0.0 + column_init_method: 'xavier' # IGNORED if linear_adapter is used, options: xavier, zero or normal + row_init_method: 'zero' # IGNORED if linear_adapter is used, options: xavier, zero or normal + layer_selection: null # selects in which layers to add lora adapters. e.g. [1,12] will add lora to layer 1 (lowest) and 12. null will apply adapters to all layers + weight_tying: False + position_embedding_strategy: null # used only when weight_tying is True + + # Used for p-tuning peft training + p_tuning: + virtual_tokens: 10 # The number of virtual tokens the prompt encoder should add at the start of the sequence + bottleneck_dim: 1024 # the size of the prompt encoder mlp bottleneck + embedding_dim: 1024 # the size of the prompt encoder embeddings + init_std: 0.023 + + # Instead of using the GPT LM Head, we can use a custom head for the reranking task + mlp_head_tuning: + out_features: 1 + + ia3_tuning: + layer_selection: null # selects in which layers to add ia3 adapters. e.g. [1,12] will add lora to layer 1 (lowest) and 12. null will apply adapters to all layers + + selective_tuning: + tunable_base_param_names: ["self_attention", "word_embeddings"] # TODO: regex support @adithyre + + data: + train_ds: + # Example of how to specify paths to multiple datasets + # file_names: + # - /path/to/squad.jsonl + # - /path/to/mnli.jsonl + # - /path/to/boolq.jsonl + # Example of how each dataset is formatted + # {'input': 'John von Neumann\nVon Neumann made fundamental contributions .... Q: What did the math of artificial viscosity do?', 'output': 'smoothed the shock transition without sacrificing basic physics'} + file_names: ??? # Path to a list of JSONL files corresponding to the source data. + global_batch_size: ${model.global_batch_size} + micro_batch_size: ${model.micro_batch_size} + shuffle: True + num_workers: 0 + memmap_workers: 2 + pin_memory: True + max_seq_length: 512 # Even if the base model can handle longer sequences, 512 is generally a good choice for training efficiency. 
+ min_seq_length: 1 + drop_last: True + # Example of how to specify concat_sampling_probabilities + # concat_sampling_probabilities: + # - 0.5 + # - 0.25 + # - 0.25 + concat_sampling_probabilities: + - 1.0 + label_key: 'output' + add_eos: True + add_bos: False + index_mapping_dir: null # Path to a directory to write index mapping files. + truncation_method: 'right' # Truncation from which position, Options: ['left', 'right'] + validation_ds: + file_names: ??? # Path to a list of JSONL files corresponding to the source data. Data format is identical to train_ds. + names: ["validation"] # Names of the corresponding datasets used to log metrics. + global_batch_size: ${model.global_batch_size} + micro_batch_size: ${model.micro_batch_size} + shuffle: False + num_workers: 0 + memmap_workers: ${model.data.train_ds.memmap_workers} + pin_memory: True + max_seq_length: ${model.data.train_ds.max_seq_length} + min_seq_length: 1 + drop_last: False + label_key: ${model.data.train_ds.label_key} + add_eos: ${model.data.train_ds.add_eos} + add_bos: ${model.data.train_ds.add_bos} + write_embeddings_to_file: False + output_file_path_prefix: "validation_rankings" # Prefix of the file to write predictions to. + index_mapping_dir: null # Path to a directory to write index mapping files. + truncation_method: 'right' # Truncation from which position, Options: ['left', 'right'] + metric: + name: "loss" # Name of the evaluation metric to use. Options: ['exact_string_match', 'loss'] + average: null # Average the metric over the dataset. Options: ['macro', 'micro']. Works only for 'F1', 'accuracy' etc. Refer to torchmetrics for metrics where this is supported. + num_classes: null + test_ds: + file_names: null # Path to a list of JSONL files corresponding to the source data. Data format is identical to train_ds. + names: null # Names of the corresponding datasets used to log metrics. + global_batch_size: ${model.global_batch_size} + micro_batch_size: ${model.micro_batch_size} + shuffle: False + num_workers: 0 + memmap_workers: ${model.data.train_ds.memmap_workers} + pin_memory: True + max_seq_length: ${model.data.train_ds.max_seq_length} + min_seq_length: 1 + drop_last: False + add_eos: ${model.data.train_ds.add_eos} + add_bos: ${model.data.train_ds.add_bos} + write_predictions_to_file: True + output_file_path_prefix: "test_embeddings" # Prefix of the file to write predictions to. + index_mapping_dir: null # Path to a directory to write index mapping files. + truncation_method: 'right' # Truncation from which position, Options: ['left', 'right'] + metric: + name: "loss" # Name of the evaluation metric to use. Options: ['exact_string_match', 'loss'] + average: null # Average the metric over the dataset. Options: ['macro', 'micro']. Works only for 'F1', 'accuracy' etc. Refer to torchmetrics for metrics where this is supported. 
+ num_classes: null + + optim: + name: fused_adam + lr: 1e-4 + weight_decay: 0.01 + betas: + - 0.9 + - 0.98 + sched: + name: CosineAnnealing + warmup_steps: 50 + min_lr: 0.0 # min_lr must be 0.0 for prompt learning when pipeline parallel > 1 + constant_steps: 0 # Constant steps should also be 0 when min_lr=0 + monitor: val_loss + reduce_on_plateau: false \ No newline at end of file diff --git a/examples/nlp/information_retrieval/megatron_bert_embedding_finetuning.py b/examples/nlp/information_retrieval/megatron_bert_embedding_finetuning.py index 04d12fed9eca..7486b470425a 100644 --- a/examples/nlp/information_retrieval/megatron_bert_embedding_finetuning.py +++ b/examples/nlp/information_retrieval/megatron_bert_embedding_finetuning.py @@ -37,7 +37,7 @@ def main(cfg) -> None: model_cfg = MegatronBertEmbeddingModel.merge_cfg_with(cfg.restore_from_path, cfg) assert ( - model_cfg.micro_batch_size * cfg.trainer.devices == model_cfg.global_batch_size + model_cfg.micro_batch_size * cfg.trainer.devices * cfg.trainer.num_nodes == model_cfg.global_batch_size ), "Gradiant accumulation is not supported for contrastive learning yet" OmegaConf.set_struct(model_cfg, True) diff --git a/examples/nlp/information_retrieval/megatron_bert_embedding_generate.py b/examples/nlp/information_retrieval/megatron_bert_embedding_generate.py new file mode 100644 index 000000000000..9814129b837d --- /dev/null +++ b/examples/nlp/information_retrieval/megatron_bert_embedding_generate.py @@ -0,0 +1,56 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
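The batch-size assertion in `megatron_bert_embedding_finetuning.py` above exists because the contrastive objective relies on in-batch negatives: the whole global batch has to be visible in a single forward/backward pass, so gradient accumulation cannot be used. A rough standalone sketch of that objective, using the `softmax_temp` option introduced in `megatron_bert_embedding_config.yaml`, is shown below; it is illustrative only and not the actual NeMo implementation:

import torch
import torch.nn.functional as F

def inbatch_contrastive_loss(query_emb, passage_emb, softmax_temp=0.02):
    # query_emb, passage_emb: [batch, hidden]; row i of passage_emb is the positive
    # for query i, and every other row acts as an in-batch negative.
    query_emb = F.normalize(query_emb, dim=-1)
    passage_emb = F.normalize(passage_emb, dim=-1)
    logits = query_emb @ passage_emb.T / softmax_temp  # [batch, batch] similarity matrix
    labels = torch.arange(query_emb.size(0))           # positives sit on the diagonal
    return F.cross_entropy(logits, labels)

loss = inbatch_contrastive_loss(torch.randn(8, 128), torch.randn(8, 128))

When `global_inbatch_negatives` is True, the real model additionally gathers passages from the other data-parallel ranks before computing the similarity matrix, which is why the global batch must match micro_batch_size times the data-parallel world size exactly.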
+ +import torch.multiprocessing as mp +from omegaconf.omegaconf import OmegaConf, open_dict + +from nemo.collections.nlp.models.information_retrieval.megatron_bert_embedding_model import MegatronBertEmbeddingModel +from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronBertTrainerBuilder +from nemo.collections.nlp.parts.nlp_overrides import NLPSaveRestoreConnector +from nemo.core.config import hydra_runner +from nemo.utils import logging +from nemo.utils.exp_manager import exp_manager + + +@hydra_runner(config_path="conf", config_name="megatron_bert_embedding_config") +def main(cfg) -> None: + if cfg.model.data.dataloader_type != "LDDL": + mp.set_start_method("spawn", force=True) + + logging.info("\n\n************** Experiment configuration ***********") + logging.info(f'\n{OmegaConf.to_yaml(cfg)}') + + trainer = MegatronBertTrainerBuilder(cfg).create_trainer() + exp_manager(trainer, cfg.exp_manager) + + model_cfg = MegatronBertEmbeddingModel.merge_cfg_with(cfg.restore_from_path, cfg) + + OmegaConf.set_struct(model_cfg, True) + with open_dict(model_cfg): + model_cfg.precision = trainer.precision + + logging.info(f"Loading model from {cfg.restore_from_path}") + model = MegatronBertEmbeddingModel.restore_from( + restore_path=cfg.restore_from_path, + trainer=trainer, + save_restore_connector=NLPSaveRestoreConnector(), + override_config_path=model_cfg, + strict=True, + ) + + trainer.test(model) + + +if __name__ == '__main__': + main() diff --git a/examples/nlp/information_retrieval/megatron_gpt_embedding_generate.py b/examples/nlp/information_retrieval/megatron_gpt_embedding_generate.py index 8cddcebbab62..d66ddb339773 100644 --- a/examples/nlp/information_retrieval/megatron_gpt_embedding_generate.py +++ b/examples/nlp/information_retrieval/megatron_gpt_embedding_generate.py @@ -68,7 +68,9 @@ def use_inference_server(cfg, model, trainer): web_ui = get_demo loop = asyncio.new_event_loop() thread = threading.Thread( - target=web_ui, daemon=True, args=(cfg.share, cfg.username, cfg.password, cfg.port, cfg.web_port, loop), + target=web_ui, + daemon=True, + args=(cfg.share, cfg.username, cfg.password, cfg.port, cfg.web_port, loop), ) thread.start() server = MegatronServer(model.cuda()) @@ -93,7 +95,6 @@ def main(cfg) -> None: model_cfg = MegatronGPTEmbeddingModel.merge_inference_cfg(cfg.model.restore_from_path, cfg) with open_dict(model_cfg): - model_cfg.data.return_output_tensors = True model_cfg.post_process = False model = MegatronGPTEmbeddingModel.restore_from(cfg.model.restore_from_path, model_cfg, trainer=trainer) diff --git a/examples/nlp/information_retrieval/megatron_gpt_reranker_finetuning.py b/examples/nlp/information_retrieval/megatron_gpt_reranker_finetuning.py new file mode 100644 index 000000000000..cf65840bb843 --- /dev/null +++ b/examples/nlp/information_retrieval/megatron_gpt_reranker_finetuning.py @@ -0,0 +1,76 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from collections.abc import MutableMapping + +import torch.multiprocessing as mp +from omegaconf.omegaconf import OmegaConf +from pytorch_lightning.loggers import WandbLogger + +from nemo.collections.nlp.models.information_retrieval.megatron_gpt_reranker_model import MegatronGPTRerankerModel +from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronLMPPTrainerBuilder +from nemo.collections.nlp.parts.peft_config import PEFT_CONFIG_MAP +from nemo.core.config import hydra_runner +from nemo.utils import logging +from nemo.utils.exp_manager import exp_manager + +mp.set_start_method("spawn", force=True) + + +def flatten_dict(d: MutableMapping, parent_key: str = '', sep: str = '.') -> MutableMapping: + items = [] + for k, v in d.items(): + new_key = parent_key + sep + k if parent_key else k + if isinstance(v, MutableMapping): + items.extend(flatten_dict(v, new_key, sep=sep).items()) + else: + items.append((new_key, v)) + return dict(items) + + +@hydra_runner(config_path="conf", config_name="megatron_gpt_reranker_tuning_config") +def main(cfg) -> None: + logging.info("\n\n************** Experiment configuration ***********") + logging.info(f'\n{OmegaConf.to_yaml(cfg)}') + + trainer = MegatronLMPPTrainerBuilder(cfg).create_trainer() + exp_manager(trainer, cfg.exp_manager) + + model_cfg = MegatronGPTRerankerModel.merge_cfg_with(cfg.model.restore_from_path, cfg) + if trainer.global_rank == 0: + for logger in trainer.loggers: + if isinstance(logger, WandbLogger): + fd = flatten_dict(dict(model_cfg), sep="/") + logger.experiment.config.update(fd) + model = MegatronGPTRerankerModel.restore_from(cfg.model.restore_from_path, model_cfg, trainer=trainer) + peft_cfg_cls_lst = [PEFT_CONFIG_MAP[s] for s in cfg.model.peft.peft_scheme.split(",")] + peft_cfg_cls = [_peft_cfg(model_cfg) for _peft_cfg in peft_cfg_cls_lst] + + if cfg.model.peft.restore_from_path is not None: + # initialize peft weights from a checkpoint instead of randomly + # This is not the same as resume training because optimizer states are not restored. + logging.info("PEFT Weights will be loaded from", cfg.model.peft.restore_from_path) + model.load_adapters(cfg.model.peft.restore_from_path, peft_cfg_cls) + elif peft_cfg_cls is not None: + logging.info("Adding adapter weights to the model for PEFT") + # model.add_adapter(peft_cfg_cls(model_cfg)) + model.add_adapter(peft_cfg_cls) + else: + logging.info(f"Running full finetuning since no peft scheme is given.\n{model.summarize()}") + + trainer.fit(model) + + +if __name__ == '__main__': + main() diff --git a/examples/nlp/information_retrieval/megatron_gpt_reranker_generate.py b/examples/nlp/information_retrieval/megatron_gpt_reranker_generate.py new file mode 100644 index 000000000000..a91449c3deda --- /dev/null +++ b/examples/nlp/information_retrieval/megatron_gpt_reranker_generate.py @@ -0,0 +1,138 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
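For context, the flatten_dict helper in megatron_gpt_reranker_finetuning.py above exists so a nested model config can be logged into WandbLogger's flat experiment config. A self-contained usage sketch that mirrors the helper (the config values are made up):

```python
# Mirrors the flatten_dict helper above and shows the flat keys it produces
# for WandB logging (sep="/" as in the training script).
from collections.abc import MutableMapping


def flatten_dict(d: MutableMapping, parent_key: str = '', sep: str = '.') -> dict:
    items = []
    for k, v in d.items():
        new_key = parent_key + sep + k if parent_key else k
        if isinstance(v, MutableMapping):
            items.extend(flatten_dict(v, new_key, sep=sep).items())
        else:
            items.append((new_key, v))
    return dict(items)


cfg = {"optim": {"name": "fused_adam", "lr": 1e-4}, "peft": {"peft_scheme": "lora"}}
print(flatten_dict(cfg, sep="/"))
# {'optim/name': 'fused_adam', 'optim/lr': 0.0001, 'peft/peft_scheme': 'lora'}
```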
+ + +import asyncio +import os +import threading +from functools import partial + +import torch +import torch.multiprocessing as mp +from omegaconf.omegaconf import OmegaConf, open_dict + +from nemo.collections.nlp.models.information_retrieval.megatron_gpt_reranker_model import MegatronGPTRerankerModel +from nemo.collections.nlp.modules.common.text_generation_server import MegatronServer +from nemo.collections.nlp.modules.common.text_generation_utils import generate +from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronLMPPTrainerBuilder +from nemo.collections.nlp.parts.peft_config import PEFT_CONFIG_MAP +from nemo.core.config import hydra_runner +from nemo.utils import logging +from nemo.utils.model_utils import inject_model_parallel_rank + +try: + from megatron.core import parallel_state + + HAVE_MEGATRON_CORE = True +except (ImportError, ModuleNotFoundError): + + HAVE_MEGATRON_CORE = False + +mp.set_start_method("spawn", force=True) + + +def use_inference_server(cfg, model, trainer): + if not HAVE_MEGATRON_CORE: + raise ValueError('Megatron-core needs to be installed to use this feature!') + + from nemo.collections.nlp.modules.common.megatron_web_server import get_chatbot_demo, get_demo + + trainer.test(model, dataloaders=None) + + if parallel_state.is_pipeline_first_stage() and parallel_state.get_tensor_model_parallel_rank() == 0: + if cfg.web_server: + if cfg.chat: + defaults = { + 'user': cfg.chatbot_config.user, + 'assistant': cfg.chatbot_config.assistant, + 'system': cfg.chatbot_config.system, + } + web_ui = partial( + get_chatbot_demo, + defaults=defaults, + value=cfg.chatbot_config.value, + attributes=cfg.chatbot_config.attributes, + ) + else: + web_ui = get_demo + loop = asyncio.new_event_loop() + thread = threading.Thread( + target=web_ui, + daemon=True, + args=(cfg.share, cfg.username, cfg.password, cfg.port, cfg.web_port, loop), + ) + thread.start() + server = MegatronServer(model.cuda()) + server.run("0.0.0.0", port=cfg.port) + + while True: + choice = torch.cuda.LongTensor(1) + torch.distributed.broadcast(choice, 0) + if choice[0].item() == 0: + generate(model.cuda()) + + +@hydra_runner(config_path="conf", config_name="megatron_gpt_reranker_generate_config") +def main(cfg) -> None: + logging.info("\n\n************** Experiment configuration ***********") + logging.info(f"\n{OmegaConf.to_yaml(cfg)}") + trainer = MegatronLMPPTrainerBuilder(cfg).create_trainer() + + if cfg.model.peft.restore_from_path: + model_cfg = MegatronGPTRerankerModel.merge_inference_cfg(cfg.model.peft.restore_from_path, cfg) + else: + model_cfg = MegatronGPTRerankerModel.merge_inference_cfg(cfg.model.restore_from_path, cfg) + + with open_dict(model_cfg): + model_cfg.post_process = False + + model = MegatronGPTRerankerModel.restore_from(cfg.model.restore_from_path, model_cfg, trainer=trainer) + + if cfg.model.peft.restore_from_path: + model.load_adapters(cfg.model.peft.restore_from_path) + elif cfg.model.peft.restore_from_ckpt.checkpoint_dir and cfg.model.peft.restore_from_ckpt.checkpoint_name: + peft_cfg_cls_lst = [PEFT_CONFIG_MAP[s] for s in cfg.model.peft.peft_scheme.split(",")] + peft_cfg_cls = [_peft_cfg(model_cfg) for _peft_cfg in peft_cfg_cls_lst] + + checkpoint_path = os.path.join( + cfg.model.peft.restore_from_ckpt.checkpoint_dir, cfg.model.peft.restore_from_ckpt.checkpoint_name + ) + # checkpoint_path is a dir in case of distributed checkpointing + if not os.path.isdir(checkpoint_path): + # legacy checkpoint needs model parallel rank injection + checkpoint_path = 
inject_model_parallel_rank( + os.path.join( + cfg.model.peft.restore_from_ckpt.checkpoint_dir, cfg.model.peft.restore_from_ckpt.checkpoint_name + ) + ) + model.load_adapters(checkpoint_path, peft_cfgs=peft_cfg_cls) + else: + raise NotImplementedError("distributed checkpointing of PEFT weights is not supported") + + model.freeze() + logging.info(f"Freezing parameters for PEFT eval:\n{model.summarize()}") + + if not cfg.model.get('use_flash_attention', False): + cfg.inference.compute_attention_mask = True + config = OmegaConf.to_container(cfg.inference, resolve=True) + model.set_inference_config(config) + + if not cfg.server: + trainer.test(model) + else: + use_inference_server(cfg, model, trainer) + + +if __name__ == "__main__": + main() diff --git a/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml b/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml index ac1f4a37b232..85609c2dd9b0 100755 --- a/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml +++ b/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml @@ -176,11 +176,13 @@ model: # Distributed checkpoint setup dist_ckpt_format: 'zarr' # Set to 'torch_dist' to use PyTorch distributed checkpoint format. dist_ckpt_load_on_device: True # whether to load checkpoint weights directly on GPU or to CPU - dist_ckpt_parallel_save: False # if true, each worker will write its own part of the dist checkpoint + dist_ckpt_parallel_save: True # if true, each worker will write its own part of the dist checkpoint + dist_ckpt_parallel_save_within_dp: False # if true, save will be parallelized only within a DP group (whole world otherwise), which might slightly reduce the save overhead dist_ckpt_parallel_load: False # if true, each worker will load part of the dist checkpoint and exchange with NCCL. Might use some extra GPU memory dist_ckpt_torch_dist_multiproc: 2 # number of extra processes per rank used during ckpt save with PyTorch distributed format dist_ckpt_assume_constant_structure: False # set to True only if the state dict structure doesn't change within a single job. Allows caching some computation across checkpoint saves. dist_ckpt_parallel_dist_opt: True # parallel save/load of a DistributedOptimizer. 'True' allows performant save and reshardable checkpoints. Set to 'False' only in order to minimize the number of checkpoint files. + dist_ckpt_load_strictness: null # defines checkpoint keys mismatch behavior (only during dist-ckpt load). Choices: assume_ok_unexpected (default - try loading without any check), log_all (log mismatches), raise_all (raise mismatches) ## Activation Checkpointing # NeMo Megatron supports 'selective' activation checkpointing where only the memory intensive part of attention is checkpointed. @@ -274,6 +276,7 @@ model: seq_length: ${model.encoder_seq_length} skip_warmup: True num_workers: 2 + num_dataset_builder_threads: 1 dataloader_type: single # cyclic reset_position_ids: False # Reset position ids after end-of-document token reset_attention_mask: False # Reset attention mask after end-of-document token @@ -282,7 +285,8 @@ model: no_seqlen_plus_one_input_tokens: False # Set to True to disable fetching (sequence length + 1) input tokens, instead get (sequence length) input tokens and mask the last token pad_samples_to_global_batch_size: False # Set to True if you want to pad the last partial batch with -1's to equal global batch size shuffle_documents: True # Set to False to disable documents shuffling. 
Sample index will still be shuffled
-    exchange_indices_distributed: False # Set to True to exchange indices via torch.distributed instead of filesystem
+    exchange_indices_distributed: False # Set to True to exchange indices via torch.distributed instead of filesystem
+    data_cache_generation_only: False # Set to True to generate only the data cache and stop the training script
 
   # Nsys profiling options
   nsys_profile:
diff --git a/examples/nlp/language_modeling/conf/megatron_gpt_ptq.yaml b/examples/nlp/language_modeling/conf/megatron_gpt_ptq.yaml
index c70719f51210..f603ebb58eb7 100644
--- a/examples/nlp/language_modeling/conf/megatron_gpt_ptq.yaml
+++ b/examples/nlp/language_modeling/conf/megatron_gpt_ptq.yaml
@@ -44,4 +44,5 @@ export:
   inference_pipeline_parallel: 1 # Default using 1 PP for inference
   dtype: ${trainer.precision} # Default precision data type
   save_path: llama2-7b-${quantization.algorithm}.qnemo # Path where the quantized model will be saved
-  compress: false # Wheter save_path should be a tarball or a directory
+  compress: false # Whether save_path should be a tarball or a directory
+  sample_output: true # Whether to run a sample prompt before saving
diff --git a/examples/nlp/language_modeling/conf/megatron_mamba_config.yaml b/examples/nlp/language_modeling/conf/megatron_mamba_config.yaml
index f4f37d7c4ce0..8b70263d5553 100644
--- a/examples/nlp/language_modeling/conf/megatron_mamba_config.yaml
+++ b/examples/nlp/language_modeling/conf/megatron_mamba_config.yaml
@@ -71,9 +71,6 @@ model:
   apply_query_key_layer_scaling: True # scale Q * K^T by 1 / layer-number.
   normalization: RMSNorm
   layernorm_epsilon: 1e-5
-  num_moe_experts: 16
-  moe_router_topk: 2
-  moe_aux_loss_coeff: 0.001
   make_vocab_size_divisible_by: 128 # Pad the vocab size to be divisible by this value for computation efficiency.
   pre_process: True # add embedding
   post_process: True # add pooler
diff --git a/examples/nlp/language_modeling/conf/megatron_mamba_inference.yaml b/examples/nlp/language_modeling/conf/megatron_mamba_inference.yaml
new file mode 100644
index 000000000000..c52b61715403
--- /dev/null
+++ b/examples/nlp/language_modeling/conf/megatron_mamba_inference.yaml
@@ -0,0 +1,96 @@
+inference:
+  greedy: False # Whether or not to use sampling; use greedy decoding otherwise
+  top_k: 0 # The number of highest probability vocabulary tokens to keep for top-k-filtering.
+  top_p: 0.9 # If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation.
+  temperature: 1.0 # sampling temperature
+  add_BOS: True # add the bos token at the beginning of the prompt
+  tokens_to_generate: 30 # The maximum number of tokens to generate.
+  all_probs: False # whether to return the log prob for all the tokens in vocab
+  repetition_penalty: 1.2 # The parameter for repetition penalty. 1.0 means no penalty.
+  min_tokens_to_generate: 0 # The minimum length of the sequence to be generated.
+ compute_logprob: False # a flag used to compute logprob of all the input text, a very special case of running inference, default False + end_strings: ["<|endoftext|>"] # generation will stop when one of these tokens is generated + +trainer: + devices: 1 + num_nodes: 1 + accelerator: gpu + logger: False # logger provided by exp_manager + precision: bf16 # 16, 32, or bf16 + use_distributed_sampler: False + + +tensor_model_parallel_size: 1 +pipeline_model_parallel_size: 1 +pipeline_model_parallel_split_rank: 0 # used for encoder and decoder model (0 for others) +megatron_amp_O2: False # Enable O2-level automatic mixed precision to save memory +mamba_model_file: null # Mamba nemo file path +checkpoint_dir: null # checkpoint file dir. This is used to load the PTL checkpoint generated during the Mamba training +checkpoint_name: null # PTL checkpoint file name, only used for PTL checkpoint loading +hparams_file: null # model configuration file, only used for PTL checkpoint loading +prompts: # prompts for Mamba inference + - "Q: How are you?" + - "Q: How big is the universe?" +prompts_jsonl: null +server: False # whether launch the API server +port: 5555 # the port number for the inference server +web_server: False # whether launch the web inference server +share: False # whether create a public URL +username: test # user name for web client +password: test2 # password for web client +web_port: 9889 # the port number of the web server +chat: False # use the chat interface +chatbot_config: + value: False # whether to inject the value attributes + attributes: + - name: Quality + min: 0 + max: 4 + key: quality + type: int + default: 4 + - name: Toxicity + min: 0 + max: 4 + key: toxcity + type: int + default: 0 + - name: Humor + min: 0 + max: 4 + key: humor + type: int + default: 0 + - name: Creativity + min: 0 + max: 4 + key: creativity + type: int + default: 0 + - name: Violence + min: 0 + max: 4 + key: violence + type: int + default: 0 + - name: Helpfulness + min: 0 + max: 4 + key: helpfulness + type: int + default: 4 + - name: Not_Appropriate + min: 0 + max: 4 + key: not_appropriate + type: int + default: 0 + - name: Language + choices: ['ar', 'bg', 'bn', 'ca', 'cs', 'da', 'de', 'el', 'en', 'eo', 'es', 'eu', 'fa', 'fi', 'fr', 'gl', 'he', 'hu', 'id', 'it', 'ja', 'ko', 'nb', 'nl', 'pl', 'pt', 'ro', 'ru', 'sk', 'sv', 'th', 'tr', 'uk', 'vi', 'zh'] + key: lang + type: list + default: en + + user: User + assistant: Assistant + system: "A chat between a curious human and an artificial intelligence assistant. 
The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n" diff --git a/examples/nlp/language_modeling/mamba_change_num_partition.py b/examples/nlp/language_modeling/mamba_change_num_partition.py index bc76b3215a74..ced2b43cd312 100644 --- a/examples/nlp/language_modeling/mamba_change_num_partition.py +++ b/examples/nlp/language_modeling/mamba_change_num_partition.py @@ -49,12 +49,13 @@ --d-model=4096 \ --mamba-version=2 \ --mamba2-n-groups=8 \ - --mamba2-head-dim=64 + --mamba2-head-dim=64 \ + --tokenizer_path= """ tp_split_dim = { 'word_embeddings.weight': 0, - 'norm.weight': -1, + 'in_proj.layer_norm_weight': -1, 'final_norm.weight': -1, 'output_layer.weight': 0, # mamba1/2 @@ -81,12 +82,6 @@ def get_split_dim(tensor_name): - # norm.weight will match tensor_name of mixer.norm.weight and norm.weight, need to distinguish - if 'norm.weight' in tensor_name: - if 'mixer.norm.weight' in tensor_name: - return tp_split_dim['mixer.norm.weight'] - else: - return tp_split_dim['norm.weight'] for key in tp_split_dim.keys(): if key in tensor_name: @@ -167,6 +162,90 @@ def split_tensor_for_tp(params, key, dim, tensor): return tensor_sliced +def combine_tp_tensors(params, key, dim, tensors): + tp_size = len(tensors) + + if 'mixer.in_proj.weight' in key and params.mamba_version == 1: + xs = [] + zs = [] + for tensor in tensors: + x, z = torch.split(tensor, [params.mamba_d_inner // tp_size, params.mamba_d_inner // tp_size], dim=dim) + xs.append(x) + zs.append(z) + return torch.cat([torch.cat(xs, dim=dim), torch.cat(zs, dim=dim)], dim=dim) + + elif 'mixer.in_proj.weight' in key and params.mamba_version == 2: + xs = [] + zs = [] + Bs = [] + Cs = [] + dts = [] + for tensor in tensors: + x, z, B, C, dt = torch.split( + tensor, + [ + params.mamba_d_inner // tp_size, + params.mamba_d_inner // tp_size, + (params.mamba2_n_groups // tp_size) * params.mamba_d_state, + (params.mamba2_n_groups // tp_size) * params.mamba_d_state, + params.mamba2_n_heads // tp_size, + ], + dim=dim, + ) + xs.append(x) + zs.append(z) + Bs.append(B) + Cs.append(C) + dts.append(dt) + + for ii in range(len(Bs)): + Bs[ii] = torch.reshape(Bs[ii], (-1, params.mamba_d_state, Bs[ii].shape[-1])) + Cs[ii] = torch.reshape(Cs[ii], (-1, params.mamba_d_state, Cs[ii].shape[-1])) + B = torch.cat(Bs, dim=dim) + C = torch.cat(Cs, dim=dim) + x = torch.cat(xs, dim=dim) + z = torch.cat(zs, dim=dim) + dt = torch.cat(dts, dim=dim) + + return torch.cat([x, z, B.flatten(0, 1), C.flatten(0, 1), dt], dim=dim) + + elif 'mixer.conv1d' in key and params.mamba_version == 2: + xs = [] + Bs = [] + Cs = [] + for tensor in tensors: + x, B, C = torch.split( + tensor, + [ + params.mamba_d_inner // tp_size, + (params.mamba2_n_groups // tp_size) * params.mamba_d_state, + (params.mamba2_n_groups // tp_size) * params.mamba_d_state, + ], + dim=dim, + ) + xs.append(x) + Bs.append(B) + Cs.append(C) + + for ii in range(len(Bs)): + if 'weight' in key: + Bs[ii] = torch.reshape(Bs[ii], (-1, params.mamba_d_state, Bs[ii].shape[-2], Bs[ii].shape[-1])) + Cs[ii] = torch.reshape(Cs[ii], (-1, params.mamba_d_state, Cs[ii].shape[-2], Cs[ii].shape[-1])) + elif 'bias' in key: + Bs[ii] = torch.reshape(Bs[ii], (-1, params.mamba_d_state)) + Cs[ii] = torch.reshape(Cs[ii], (-1, params.mamba_d_state)) + else: + raise Exception("Unknown key") + B = torch.cat(Bs, dim=dim) + C = torch.cat(Cs, dim=dim) + x = torch.cat(xs, dim=dim) + + return torch.cat([x, B.flatten(0, 1), C.flatten(0, 1)], dim=dim) + + else: + return torch.cat(tensors, dim=dim) + + ################# ### 
Utilities ### ################# @@ -296,6 +375,58 @@ def split_tp_partition_only(args, model, original_model, tp_size, write_path=Non tar.extractall(path=os.path.dirname(write_path)) +def merge_partition(args, model, partitions, write_path: str = None): + # Extract the pp_rank and number of modules per tp rank in each pp rank + + input_tp_rank = len(partitions) + + # During merge - model is TP 1 PP 1 model with all parameters present in correct order. + # Merge the parameters of the various PP X TP Y models into the TP 1 PP 1 model. + from collections import OrderedDict + + full_model = OrderedDict() + combined_tp_model = OrderedDict() + for _, (key, original_tensor) in enumerate(partitions[0].items()): + if "_extra_state" in key: + combined_tp_model[key] = original_tensor + continue + + import copy + + split_dim = get_split_dim(key) + original_shape = list(original_tensor.shape) + combined_shape = copy.deepcopy(original_shape) + combined_shape[split_dim] *= input_tp_rank + + if split_dim != -1: + # slice together model + + combined_tensor = combine_tp_tensors( + args, key, split_dim, [partitions[jj][key].cpu() for jj in range(input_tp_rank)] + ) + combined_tp_model[key] = combined_tensor + else: + # copy model + combined_tp_model[key] = original_tensor + + for _, (local_key, local_original_tensor) in enumerate(combined_tp_model.items()): + try: + layer_num = int(re.findall(r'\d+', local_key)[0]) + new_key = local_key.replace(str(layer_num), str(layer_num), 1) + except: + new_key = local_key + full_model[new_key] = local_original_tensor + + # Update the model parameter with the merged tensor + + model.load_state_dict(full_model, strict=True) + + # Save the file iff the original file was PP 1 TP 1 + if write_path is not None: + model.save_to(write_path) + return model + + def main(): parser = ArgumentParser() parser.add_argument("--model_file", type=str, default=None, required=False, help="Path to source .nemo file") @@ -351,7 +482,7 @@ def main(): '--tp_conversion_only', default=True, action='store_true', help='Only convert TP model to TP model' ) parser.add_argument('--model_extracted_dir', type=str, default=None, help='Path to pre-extracted model directory') - + parser.add_argument('--tokenizer_path', type=str, default=None, required=True) parser.add_argument('--d-model', type=int, default=4096) parser.add_argument('--mamba-version', type=int, default=2) parser.add_argument('--mamba-d-state', type=int, default=128) @@ -394,25 +525,6 @@ def main(): pp_size = args.pipeline_model_parallel_size tgt_pp_size = args.target_pipeline_model_parallel_size pipeline_model_parallel_split_rank = args.target_pipeline_model_parallel_split_rank - vp_size = args.virtual_pipeline_model_parallel_size - if vp_size is None: - vp_size = 1 - - convert_vp = vp_size > 1 - if convert_vp: - from megatron.core import parallel_state - - parallel_state.set_virtual_pipeline_model_parallel_world_size(vp_size) - - hparams_filepath = args.hparams_file - if hparams_filepath is None: - logging.warning( - '\n\n\n!!!!!!!!!\n' - 'You are converting a model with virtual pipeline parallelism enabled, \n' - 'but have not passed `hparams_file` argument. 
\n' - 'This will cause each ckpt file to be temporarily laoded onto GPU memory!\n\n' - 'It is highly recommended to pass `hparams_file` argument to avoid this.\n' - ) # Import the class of the model @@ -478,16 +590,11 @@ def main(): tgt_pp_size = 1 pipeline_model_parallel_split_rank = 0 - if vp_size is None or vp_size < 0: - vp_size = 1 - app_state = AppState() app_state.data_parallel_rank = 0 app_state.pipeline_model_parallel_size = pp_size app_state.tensor_model_parallel_size = tp_size - if vp_size > 1: - app_state.virtual_pipeline_model_parallel_size = vp_size app_state.model_parallel_size = app_state.pipeline_model_parallel_size * app_state.tensor_model_parallel_size world_size = pp_size * tp_size # pseudo world size for simulating load of a specific rank on a single gpu @@ -520,57 +627,203 @@ def main(): f"`--tokenizer_model_path`.\n\n" ) - # If input model has TP > 1 or PP > 1 - # Reconstruct the model to have TP = 1 and PP = 1 - # Note that this is a forward loop that will process PP [0..N] TP [0..M] in sequential order. + # If input model has TP > 1 + # Reconstruct the model to have TP = 1 + if tp_size > 1 or pp_size > 1: + partitions = [] + model = None + + for pp_rank in range(pp_size): + app_state.pipeline_model_parallel_rank = pp_rank + + for tp_rank in range(tp_size): + app_state.tensor_model_parallel_rank = tp_rank + + logging.info(f"Loading ------------ PP Rank: {pp_rank} TP Rank: {tp_rank}") + + # Override flag that forces Model to use AppState instead of Trainer + # to determine the world size, global and local rank + # Used for simulating load of a specific rank on a single gpu + os.environ[NEMO_MEGATRON_MODEL_PARALLEL_APPSTATE_OVERRIDE] = "true" + + # Compute the global rank to load the correct subset of parameters + global_rank = pp_rank * tp_size + tp_rank + + # Update AppState + app_state.world_size = world_size + app_state.global_rank = global_rank + app_state.local_rank = global_rank % num_gpu_per_node + app_state.pipeline_model_parallel_size = pp_size + app_state.tensor_model_parallel_size = tp_size + app_state.pipeline_model_parallel_split_rank = pipeline_model_parallel_split_rank + app_state.model_parallel_size = ( + app_state.pipeline_model_parallel_size * app_state.tensor_model_parallel_size + ) + + save_restore_connector = NLPSaveRestoreConnector() + + if args.model_extracted_dir is not None: + logging.info(f"Using extracted model directory: {args.model_extracted_dir}") + save_restore_connector.model_extracted_dir = args.model_extracted_dir + + if args.model_file is not None: + model_filepath = args.model_file + else: + model_filepath = args.model_extracted_dir + + # Get model config + tmp_cfg = MegatronMambaModel.restore_from( + restore_path=model_filepath, + trainer=trainer, + map_location=torch.device("cpu"), + save_restore_connector=save_restore_connector, + return_config=True, + ) + + # Force model onto CPU + tmp_cfg, restore_dict = force_cpu_model(tmp_cfg) - # If input model has TP = 1 and PP = 1 - app_state.model_parallel_size = 1 + # Restore model + model = MegatronMambaModel.restore_from( + restore_path=model_filepath, + trainer=trainer, + map_location=torch.device("cpu"), + save_restore_connector=save_restore_connector, + override_config_path=tmp_cfg, + ) + model.freeze() + + # Restore model config + restore_model_config(model.cfg, restore_dict) + + model.to(dtype=dtype) + + # Reset env flag + os.environ.pop(NEMO_MEGATRON_MODEL_PARALLEL_APPSTATE_OVERRIDE, None) + + logging.info(f"<<<<<<<< LOADED MODEL TP={tp_rank + 1} | " f"GLOBAL RANK = 
{global_rank} >>>>>>>>>") + + # Save the parameters + + partitions.append(model.state_dict()) + + # app_state is being updated incorrectly during restore + app_state.data_parallel_rank = 0 + app_state.pipeline_model_parallel_rank = pp_rank + app_state.tensor_model_parallel_rank = tp_rank + app_state.pipeline_model_parallel_size = pp_size + app_state.tensor_model_parallel_size = tp_size + app_state.model_parallel_size = ( + app_state.pipeline_model_parallel_size * app_state.tensor_model_parallel_size + ) + + # Build a unified model with PP 1 TP 1 + with open_dict(model.cfg): + model.cfg.tensor_model_parallel_size = 1 + model.cfg.pipeline_model_parallel_size = 1 + model.cfg.virtual_pipeline_model_parallel_size = None + + app_state.global_rank = 0 + app_state.local_rank = 0 + app_state.data_parallel_rank = 0 + app_state.pipeline_model_parallel_rank = 0 + app_state.tensor_model_parallel_rank = 0 + app_state.pipeline_model_parallel_size = 1 + app_state.tensor_model_parallel_size = 1 + app_state.model_parallel_size = 1 + + trainer = Trainer(plugins=plugins, devices=1, strategy=NLPDDPStrategy(), accelerator="cpu") + + with open_dict(model.cfg): + if args.tokenizer_model_path is not None: + model.cfg.tokenizer.model = args.tokenizer_model_path + if args.tokenizer_vocab_file is not None: + model.cfg.tokenizer.vocab_file = args.tokenizer_vocab_file + + model.cfg, restore_dict = force_cpu_model(model.cfg) + + # Remove Virtual Parallelism + model.cfg.virtual_pipeline_model_parallel_size = None + + logging.info(f"<<<<<<<< Building TP 1 PP 1 base model >>>>>>>>>") + + gbs = model.cfg.global_batch_size + mbs = model.cfg.micro_batch_size + + model.cfg.global_batch_size = None + model.cfg.micro_batch_size = None + + model.cfg.tokenizer.model = args.tokenizer_path + model.cfg.tokenizer.library = 'megatron' + model.cfg.tokenizer.type = 'GPTSentencePieceTokenizer' + + model = MegatronMambaModel(model.cfg, trainer) # type: nn.Module + model.freeze() + model = model.to('cpu') + model._save_restore_connector = NLPSaveRestoreConnector() + + restore_model_config(model.cfg, restore_dict) + + if tgt_tp_size > 1: + original_model = merge_partition(args, model, partitions) + else: + # Write out the PP 1 TP 1 model to disk + original_model = merge_partition(args, model, partitions, args.target_file) - save_restore_connector = NLPSaveRestoreConnector() + # Empty cache memory of all parameters from all PP TP partitions + partitions.clear() - if args.model_extracted_dir is not None: - logging.info(f"Using extracted model directory: {args.model_extracted_dir}") - save_restore_connector.model_extracted_dir = args.model_extracted_dir + model.cfg.global_batch_size = gbs + model.cfg.micro_batch_size = mbs - if args.model_file is not None: - model_filepath = args.model_file + # If input model has TP = 1 else: - model_filepath = args.model_extracted_dir + app_state.model_parallel_size = 1 - tmp_cfg = MegatronMambaModel.restore_from( - restore_path=model_filepath, - trainer=trainer, - map_location=torch.device("cpu"), - save_restore_connector=save_restore_connector, - return_config=True, - ) + save_restore_connector = NLPSaveRestoreConnector() - tmp_cfg, restore_dict = force_cpu_model(tmp_cfg) + if args.model_extracted_dir is not None: + logging.info(f"Using extracted model directory: {args.model_extracted_dir}") + save_restore_connector.model_extracted_dir = args.model_extracted_dir - model = MegatronMambaModel.restore_from( - restore_path=model_filepath, - trainer=trainer, - map_location=torch.device("cpu"), - 
save_restore_connector=save_restore_connector, - override_config_path=tmp_cfg, - ) + if args.model_file is not None: + model_filepath = args.model_file + else: + model_filepath = args.model_extracted_dir - original_model = MegatronMambaModel.restore_from( - restore_path=model_filepath, - trainer=trainer, - map_location=torch.device("cpu"), - save_restore_connector=save_restore_connector, - override_config_path=tmp_cfg, - ) - original_model = original_model.to('cpu') - original_model._save_restore_connector = NLPSaveRestoreConnector() - original_model.freeze() - original_model.to(dtype=dtype) + tmp_cfg = MegatronMambaModel.restore_from( + restore_path=model_filepath, + trainer=trainer, + map_location=torch.device("cpu"), + save_restore_connector=save_restore_connector, + return_config=True, + ) + + tmp_cfg, restore_dict = force_cpu_model(tmp_cfg) + + model = MegatronMambaModel.restore_from( + restore_path=model_filepath, + trainer=trainer, + map_location=torch.device("cpu"), + save_restore_connector=save_restore_connector, + override_config_path=tmp_cfg, + ) + + original_model = MegatronMambaModel.restore_from( + restore_path=model_filepath, + trainer=trainer, + map_location=torch.device("cpu"), + save_restore_connector=save_restore_connector, + override_config_path=tmp_cfg, + ) + original_model = original_model.to('cpu') + original_model._save_restore_connector = NLPSaveRestoreConnector() + original_model.freeze() + original_model.to(dtype=dtype) - model.to(dtype=dtype) + model.to(dtype=dtype) - restore_model_config(model.cfg, restore_dict) + restore_model_config(model.cfg, restore_dict) # If target model has TP > 1 or PP > 1 if tgt_pp_size > 1 or tgt_tp_size > 1: @@ -653,12 +906,11 @@ def main(): model.cfg, restore_dict = force_cpu_model(model.cfg) - from apex.transformer.pipeline_parallel.utils import _GLOBAL_NUM_MICROBATCHES_CALCULATOR + gbs = model.cfg.global_batch_size + mbs = model.cfg.micro_batch_size - _GLOBAL_NUM_MICROBATCHES_CALCULATOR.current_global_batch_size = 1 - _GLOBAL_NUM_MICROBATCHES_CALCULATOR.current_micro_batch_size = 1 - model.cfg.global_batch_size = 1 - model.cfg.micro_batch_size = 1 + model.cfg.global_batch_size = None + model.cfg.micro_batch_size = None model = MegatronMambaModel(model.cfg, trainer) model = model.to('cpu') @@ -666,6 +918,9 @@ def main(): model.freeze() model.to(dtype=dtype) + model.cfg.global_batch_size = gbs + model.cfg.micro_batch_size = mbs + restore_model_config(model.cfg, restore_dict) # Update global batch size diff --git a/examples/nlp/language_modeling/megatron_gpt_pretraining.py b/examples/nlp/language_modeling/megatron_gpt_pretraining.py index 422319a382c8..1cc3d0aae27d 100644 --- a/examples/nlp/language_modeling/megatron_gpt_pretraining.py +++ b/examples/nlp/language_modeling/megatron_gpt_pretraining.py @@ -44,9 +44,12 @@ def main(cfg) -> None: if cfg.model.get("restore_from_path") is not None: # Option 1: Restore only the model weights from a .nemo file logging.info(f"Continual training: loading weights from {cfg.model.restore_from_path}") + from nemo.collections.nlp.models.language_modeling.megatron_gpt_sft_model import MegatronGPTSFTModel + + model_cfg = MegatronGPTSFTModel.merge_cfg_with(cfg.model.restore_from_path, cfg) model = MegatronGPTModel.restore_from( restore_path=cfg.model.restore_from_path, - override_config_path=cfg.model, + override_config_path=model_cfg, trainer=trainer, save_restore_connector=NLPSaveRestoreConnector(), ) diff --git a/examples/nlp/language_modeling/megatron_mamba_eval.py 
b/examples/nlp/language_modeling/megatron_mamba_eval.py new file mode 100644 index 000000000000..ed12e4b904ac --- /dev/null +++ b/examples/nlp/language_modeling/megatron_mamba_eval.py @@ -0,0 +1,411 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import datetime +import json +import os +import threading +from functools import partial + +import torch +from omegaconf import OmegaConf, open_dict +from pytorch_lightning.trainer.trainer import Trainer +from torch.utils.data import DataLoader, Dataset + +from nemo.collections.nlp.models.language_modeling.megatron_mamba_model import MegatronMambaModel +from nemo.collections.nlp.modules.common.megatron.megatron_init import fake_initialize_model_parallel +from nemo.collections.nlp.modules.common.text_generation_server import MegatronServer +from nemo.collections.nlp.modules.common.text_generation_utils import generate +from nemo.collections.nlp.modules.common.transformer.text_generation import LengthParam, SamplingParam +from nemo.collections.nlp.parts.nlp_overrides import CustomProgressBar, NLPDDPStrategy, NLPSaveRestoreConnector +from nemo.collections.nlp.parts.utils_funcs import torch_dtype_from_precision +from nemo.core.config import hydra_runner +from nemo.utils.app_state import AppState +from nemo.utils.model_utils import inject_model_parallel_rank + +try: + from megatron.core import parallel_state + + HAVE_MEGATRON_CORE = True + +except (ImportError, ModuleNotFoundError): + + HAVE_MEGATRON_CORE = False + +""" +This is the script to run GPT text generation. + +Usage: + Assume the model has TP=1, PP=1 in the following use cases. + a. run greedy inference from a nemo file: + python megatron_gpt_eval.py \ + mamba_model_file=PATH_TO_MODEL \ + inference.greedy=True \ + inference.add_BOS=True \ + trainer.devices=1 \ + trainer.num_nodes=1 \ + tensor_model_parallel_size=-1 \ + pipeline_model_parallel_size=-1 \ + prompts=[prompt1,prompt2] + + b. run greedy inference from a PTL checkpoint file: + python megatron_gpt_eval.py \ + checkpoint_dir=PATH_TO_CHECKPOINT_FILE \ + checkpoint_name=CHECKPOINT_FILE_NAME \ + hparams_file=HPARAMS_FILE \ + inference.greedy=True \ + inference.add_BOS=True \ + trainer.devices=1 \ + trainer.num_nodes=1 \ + tensor_model_parallel_size=-1 \ + pipeline_model_parallel_size=-1 \ + prompts=[prompt1,prompt2] + + c. run top_p inference from a nemo file: + python megatron_gpt_eval.py \ + mamba_model_file=PATH_TO_MODEL \ + inference.greedy=False \ + inference.top_k=0 \ + inference.top_p=0.9 \ + inference.repetition_penalty=1.2 \ + inference.add_BOS=True \ + trainer.devices=1 \ + trainer.num_nodes=1 \ + tensor_model_parallel_size=-1 \ + pipeline_model_parallel_size=-1 \ + prompts=[prompt1,prompt2] + + d. 
If you don't need to generate tokens and need model to compute logprobs: + python megatron_gpt_eval.py \ + mamba_model_file=PATH_TO_MODEL \ + inference.compute_logprob=True \ + trainer.devices=1 \ + trainer.num_nodes=1 \ + tensor_model_parallel_size=-1 \ + pipeline_model_parallel_size=-1 \ + prompts=[text to get logprob] + + e. Launch the inference server + python megatron_gpt_eval.py \ + mamba_model_file=PATH_TO_MODEL \ + trainer.devices=1 \ + trainer.num_nodes=1 \ + tensor_model_parallel_size=-1 \ + pipeline_model_parallel_size=-1 \ + server=True + + To send a request to the server, here is one example code: + ```python + import json + import requests + + batch_size = 8 + port_num = 5555 + headers = {"Content-Type": "application/json"} + + + def request_data(data): + resp = requests.put('http://localhost:{}/generate'.format(port_num), + data=json.dumps(data), + headers=headers) + sentences = resp.json()['sentences'] + return sentences + + + data = { + "sentences": [""] * batch_size, + "tokens_to_generate": 300, + "temperature": 1.0, + "add_BOS": True, + "top_k": 0, + "top_p": 0.9, + "greedy": False, + "all_probs": False, + "repetition_penalty": 1.2, + "min_tokens_to_generate": 2, + } + + sentences = request_data(data) + ``` +""" + +if not torch.cuda.is_available(): + raise EnvironmentError("GPU is needed for the inference") + + +class RequestDataSet(Dataset): + def __init__(self, sentences): + super().__init__() + self.sentences = sentences + + def __len__( + self, + ): + return len(self.sentences) + + def __getitem__(self, idx): + return self.sentences[idx] + + +def remove_padded_prompts(response, nb_paddings): + result = {} + for k, v in response.items(): + if v != None and (type(v) is list or type(v) is torch.Tensor): + v = v[:-nb_paddings] + result[k] = v + return result + + +def load_model_from_config(trainer, cfg): + if cfg.mamba_model_file is not None: + if ( + cfg.tensor_model_parallel_size < 0 + or cfg.pipeline_model_parallel_size < 0 + or cfg.get('pipeline_model_parallel_split_rank', -1) < 0 + ): + save_restore_connector = NLPSaveRestoreConnector() + if os.path.isdir(cfg.mamba_model_file): + save_restore_connector.model_extracted_dir = cfg.mamba_model_file + model_config = MegatronMambaModel.restore_from( + restore_path=cfg.mamba_model_file, + trainer=trainer, + return_config=True, + save_restore_connector=save_restore_connector, + ) + + # with dist checkpointing we don't need to set this + if not model_config.get('mcore_gpt', False): + with open_dict(cfg): + cfg.tensor_model_parallel_size = model_config.get('tensor_model_parallel_size', 1) + cfg.pipeline_model_parallel_size = model_config.get('pipeline_model_parallel_size', 1) + cfg.pipeline_model_parallel_split_rank = model_config.get('pipeline_model_parallel_split_rank', 0) + + assert ( + cfg.trainer.devices * cfg.trainer.num_nodes + == cfg.tensor_model_parallel_size + * cfg.pipeline_model_parallel_size + * max(1, cfg.get('expert_model_parallel_size', 1)) + ), "devices * num_nodes should equal tensor_model_parallel_size * pipeline_model_parallel_size" + + if cfg.mamba_model_file: + save_restore_connector = NLPSaveRestoreConnector() + if os.path.isdir(cfg.mamba_model_file): + save_restore_connector.model_extracted_dir = cfg.mamba_model_file + + pretrained_cfg = MegatronMambaModel.restore_from( + restore_path=cfg.mamba_model_file, + trainer=trainer, + return_config=True, + save_restore_connector=save_restore_connector, + ) + OmegaConf.set_struct(pretrained_cfg, True) + with open_dict(pretrained_cfg): + 
pretrained_cfg.sequence_parallel = False + pretrained_cfg.activations_checkpoint_granularity = None + pretrained_cfg.activations_checkpoint_method = None + pretrained_cfg.precision = trainer.precision + pretrained_cfg["use_flash_attention"] = cfg.inference.get("use_flash_attention", False) + pretrained_cfg["apply_rope_fusion"] = False + if pretrained_cfg.get('mcore_gpt', False): + # with dist checkpointing we can use the model parallel config specified by the user + pretrained_cfg.tensor_model_parallel_size = cfg.tensor_model_parallel_size + pretrained_cfg.pipeline_model_parallel_size = cfg.pipeline_model_parallel_size + pretrained_cfg.expert_model_parallel_size = cfg.get('expert_model_parallel_size', 1) + pretrained_cfg.micro_batch_size = 1 + if trainer.precision == "16": + pretrained_cfg.megatron_amp_O2 = False + elif trainer.precision in ['bf16', 'bf16-mixed'] and cfg.get('megatron_amp_O2', False): + pretrained_cfg.megatron_amp_O2 = True + + model = MegatronMambaModel.restore_from( + restore_path=cfg.mamba_model_file, + trainer=trainer, + override_config_path=pretrained_cfg, + save_restore_connector=save_restore_connector, + map_location=f'cuda:{trainer.local_rank}', # map_location is needed for converted models + ) + elif cfg.checkpoint_dir: + app_state = AppState() + if ( + cfg.tensor_model_parallel_size > 1 + or cfg.pipeline_model_parallel_size > 1 + or cfg.get('expert_model_parallel_size', 1) > 1 + ): + app_state.model_parallel_size = ( + cfg.tensor_model_parallel_size + * cfg.pipeline_model_parallel_size + * cfg.get('expert_model_parallel_size', 1) + ) + app_state.tensor_model_parallel_size = cfg.tensor_model_parallel_size + app_state.pipeline_model_parallel_size = cfg.pipeline_model_parallel_size + app_state.expert_model_parallel_size = cfg.get('expert_model_parallel_size', 1) + ( + app_state.tensor_model_parallel_rank, + app_state.pipeline_model_parallel_rank, + app_state.expert_model_parallel_rank, + app_state.model_parallel_size, + app_state.data_parallel_size, + app_state.pipeline_model_parallel_split_rank, + app_state.virtual_pipeline_model_parallel_rank, + ) = fake_initialize_model_parallel( + world_size=app_state.model_parallel_size, + rank=trainer.global_rank, + tensor_model_parallel_size_=cfg.tensor_model_parallel_size, + pipeline_model_parallel_size_=cfg.pipeline_model_parallel_size, + pipeline_model_parallel_split_rank_=cfg.pipeline_model_parallel_split_rank, + expert_model_parallel_size_=cfg.get('expert_model_parallel_size', 1), + ) + checkpoint_path = os.path.join(cfg.checkpoint_dir, cfg.checkpoint_name) + # checkpoint_path is a dir in case of distributed checkpointing + if not os.path.isdir(checkpoint_path): + # legacy checkpoint needs model parallel rank injection + checkpoint_path = inject_model_parallel_rank(os.path.join(cfg.checkpoint_dir, cfg.checkpoint_name)) + model = MegatronMambaModel.load_from_checkpoint( + checkpoint_path, hparams_file=cfg.hparams_file, trainer=trainer + ) + else: + raise ValueError("need at least a nemo file or checkpoint dir") + + dtype = torch_dtype_from_precision(cfg.trainer.precision) + model = model.to(dtype=dtype) + return model + + +def load_prompts(cfg): + prompts = [] + if (cfg_prompts := getattr(cfg, 'prompts', None)) is not None: + prompts = OmegaConf.to_container(cfg_prompts) + if (prompts_jsonl := getattr(cfg, 'prompts_jsonl', None)) is not None: + with open(prompts_jsonl, 'rt') as fp: + try: + prompts += list(map(json.loads, map(str.rstrip, fp))) + except: + prompts += list(map(str.rstrip, fp)) + # Make sure non-empty input 
+ assert len(prompts) > 0, "Expected at least one prompt" + # Make sure all have the same type + assert all( + map(lambda x: isinstance(x, type(prompts[0])), prompts) + ), "Expected all prompts to have the same datatype" + return prompts + + +def round_to_mult(n, mult=8): + """ + Rounds number n to be a multiple of mult + """ + return ((n + mult - 1) // mult) * mult + + +@hydra_runner(config_path="conf", config_name="megatron_mamba_inference") +def main(cfg) -> None: + + callbacks = [] + # enable_progress_bar is True by default. If cfg.trainer.enable_progress_bar=False, CustomProgressBar is not appended to callbacks + if 'enable_progress_bar' not in cfg.trainer or cfg.trainer.enable_progress_bar: + callbacks.append(CustomProgressBar()) + # trainer required for restoring model parallel models + trainer = Trainer( + strategy=NLPDDPStrategy(timeout=datetime.timedelta(seconds=18000)), + **cfg.trainer, + callbacks=callbacks, + ) + + model = load_model_from_config(trainer, cfg) + model.freeze() + + # Have to turn off activations_checkpoint_method for inference + try: + model.model.language_model.encoder.activations_checkpoint_method = None + except AttributeError: + pass + + length_params: LengthParam = { + "max_length": cfg.inference.tokens_to_generate, + "min_length": cfg.inference.min_tokens_to_generate, + } + + sampling_params: SamplingParam = { + "use_greedy": cfg.inference.greedy, + "temperature": cfg.inference.temperature, + "top_k": cfg.inference.top_k, + "top_p": cfg.inference.top_p, + "repetition_penalty": cfg.inference.repetition_penalty, + "add_BOS": cfg.inference.add_BOS, + "all_probs": cfg.inference.all_probs, + "compute_logprob": cfg.inference.compute_logprob, + "end_strings": cfg.inference.end_strings, + } + + prompts = load_prompts(cfg) + + # First method of running text generation, call model.generate method + response = model.generate(inputs=prompts, length_params=length_params, sampling_params=sampling_params) + + print("***************************") + print(response) + print("***************************") + + # Second method of running text generation, call trainer.predict [recommended] + bs = 2 + ds = RequestDataSet(prompts) + request_dl = DataLoader(dataset=ds, batch_size=bs) + config = OmegaConf.to_container(cfg.inference) + model.set_inference_config(config) + response = trainer.predict(model, request_dl) + + print("***************************") + print(response) + print("***************************") + + # Third method of running text generation, use inference server + if cfg.server: + from nemo.collections.nlp.modules.common.megatron_web_server import get_chatbot_demo, get_demo + + if parallel_state.is_pipeline_first_stage() and parallel_state.get_tensor_model_parallel_rank() == 0: + if cfg.web_server: + if cfg.chat: + defaults = { + 'user': cfg.chatbot_config.user, + 'assistant': cfg.chatbot_config.assistant, + 'system': cfg.chatbot_config.system, + } + web_ui = partial( + get_chatbot_demo, + defaults=defaults, + value=cfg.chatbot_config.value, + attributes=cfg.chatbot_config.attributes, + ) + else: + web_ui = get_demo + loop = asyncio.new_event_loop() + thread = threading.Thread( + target=web_ui, + daemon=True, + args=(cfg.share, cfg.username, cfg.password, cfg.port, cfg.web_port, loop), + ) + thread.start() + server = MegatronServer(model.cuda()) + server.run("0.0.0.0", port=cfg.port) + + while True: + choice = torch.cuda.LongTensor(1) + torch.distributed.broadcast(choice, 0) + if choice[0].item() == 0: + generate(model.cuda()) + + +if __name__ == '__main__': + 
main() # noqa pylint: disable=no-value-for-parameter diff --git a/examples/nlp/language_modeling/tuning/conf/megatron_gpt_finetuning_config.yaml b/examples/nlp/language_modeling/tuning/conf/megatron_gpt_finetuning_config.yaml index 6517b62010b4..06551f46486c 100644 --- a/examples/nlp/language_modeling/tuning/conf/megatron_gpt_finetuning_config.yaml +++ b/examples/nlp/language_modeling/tuning/conf/megatron_gpt_finetuning_config.yaml @@ -158,6 +158,7 @@ model: index_mapping_dir: null # Path to a directory to write index mapping files. prompt_template: "{input} {output}" # fstring to use for assistant prompt. Example: "Q: {input}\nA: {output}" truncation_method: 'right' # Truncation from which position, Options: ['left', 'right'] + global_sample_mapping: False # Whether to shuffle the replicated data all together, or shuffle the dataset within each epoch validation_ds: file_names: ??? # Path to a list of JSONL files corresponding to the source data. Data format is identical to train_ds. names: null # Names of the corresponding datasets used to log metrics. @@ -181,6 +182,7 @@ model: prompt_template: ${model.data.train_ds.prompt_template} # fstring to use for assistant prompt. Example: "Q: {input}\nA: {output}" tokens_to_generate: 32 # decide how many tokens we want to generate to evaluate performance with string metrics truncation_method: 'right' # Truncation from which position, Options: ['left', 'right'] + global_sample_mapping: False # Whether to shuffle the replicated data all together, or shuffle the dataset within each epoch metric: name: "loss" # Name of the evaluation metric to use. Options: ['exact_string_match', 'loss'] average: null # Average the metric over the dataset. Options: ['macro', 'micro']. Works only for 'F1', 'accuracy' etc. Refer to torchmetrics for metrics where this is supported. @@ -208,6 +210,7 @@ model: prompt_template: ${model.data.train_ds.prompt_template} tokens_to_generate: 32 # decide how many tokens we want to generate to evaluate performance with string metrics truncation_method: 'right' # Truncation from which position, Options: ['left', 'right'] + global_sample_mapping: False # Whether to shuffle the replicated data all together, or shuffle the dataset within each epoch metric: name: "loss" # Name of the evaluation metric to use. Options: ['exact_string_match', 'loss'] average: null # Average the metric over the dataset. Options: ['macro', 'micro']. Works only for 'F1', 'accuracy' etc. Refer to torchmetrics for metrics where this is supported. 
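As a usage note for load_prompts() in megatron_mamba_eval.py above: prompts come either from the inline prompts list or from a prompts_jsonl file whose lines may be JSON objects or plain text. Below is a simplified per-line sketch of that fallback; the file contents are invented, and the actual script first tries json.loads on every line and only then falls back to raw text for the whole file:

```python
# Simplified per-line illustration of the prompts_jsonl fallback used by
# load_prompts(): parse each line as JSON if possible, else keep it as text.
import json
import tempfile

lines = ['{"input": "Q: How big is the universe?"}', "Q: Plain-text prompt"]
with tempfile.NamedTemporaryFile("w", suffix=".jsonl", delete=False) as fp:
    fp.write("\n".join(lines))
    path = fp.name

prompts = []
with open(path) as f:
    for line in map(str.rstrip, f):
        try:
            prompts.append(json.loads(line))  # structured prompt
        except json.JSONDecodeError:
            prompts.append(line)              # raw text prompt

print(prompts)  # [{'input': 'Q: How big is the universe?'}, 'Q: Plain-text prompt']
```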
diff --git a/examples/nlp/language_modeling/tuning/conf/megatron_mamba_finetuning_config.yaml b/examples/nlp/language_modeling/tuning/conf/megatron_mamba_finetuning_config.yaml index 33498540a3d5..447d46714f3d 100644 --- a/examples/nlp/language_modeling/tuning/conf/megatron_mamba_finetuning_config.yaml +++ b/examples/nlp/language_modeling/tuning/conf/megatron_mamba_finetuning_config.yaml @@ -21,7 +21,7 @@ exp_manager: explicit_log_dir: null exp_dir: null name: ${name} - create_wandb_logger: True + create_wandb_logger: False wandb_logger_kwargs: project: griffin name: sft-test @@ -82,7 +82,7 @@ model: ffn_dropout: 0.0 peft: - peft_scheme: "lora" # can be either adapter,ia3, lora, or ptuning + peft_scheme: "none" # can be either adapter,ia3, lora, or ptuning restore_from_path: null # Used for adapter peft training diff --git a/examples/nlp/language_modeling/tuning/conf/megatron_mamba_generate_config.yaml b/examples/nlp/language_modeling/tuning/conf/megatron_mamba_generate_config.yaml index fddfa16c8c09..ec856efe39a2 100644 --- a/examples/nlp/language_modeling/tuning/conf/megatron_mamba_generate_config.yaml +++ b/examples/nlp/language_modeling/tuning/conf/megatron_mamba_generate_config.yaml @@ -135,10 +135,10 @@ model: output_file_path_prefix: null # Prefix of the file to write predictions to. truncation_field: "input" # Options: keys in prompt_template index_mapping_dir: null # Path to a directory to write index mapping files. - prompt_template: "{input} {output}" + prompt_template: "{input} {output}" # fstring to use for assistant prompt. Example: "Q: {input}\nA: {output}" tokens_to_generate: 32 # decide how many tokens we want to generate to evaluate performance with string metrics truncation_method: 'right' # Truncation from which position, Options: ['left', 'right'] - ceil_to_power_2: True + ceil_to_power_2: False get_attention_mask_from_fusion: True pad_to_max_length: True diff --git a/examples/nlp/language_modeling/tuning/megatron_gpt_finetuning.py b/examples/nlp/language_modeling/tuning/megatron_gpt_finetuning.py index aaa087a46623..bfe8ea35960e 100644 --- a/examples/nlp/language_modeling/tuning/megatron_gpt_finetuning.py +++ b/examples/nlp/language_modeling/tuning/megatron_gpt_finetuning.py @@ -1,4 +1,4 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
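The prompt_template entries in the tuning configs above are plain Python format strings filled from each data record. A toy rendering of the default template and the Q/A variant mentioned in the comments (the record is invented, and the real dataset code additionally handles truncation and label masking):

```python
# Toy rendering of the prompt_template format strings from the configs above.
record = {"input": "What is the capital of France?", "output": "Paris."}

default_template = "{input} {output}"
qa_template = "Q: {input}\nA: {output}"

print(default_template.format(**record))  # -> What is the capital of France? Paris.
print(qa_template.format(**record))       # -> Q: ... / A: Paris. on two lines
```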
diff --git a/examples/nlp/machine_translation/nmt_transformer_infer_megatron.py b/examples/nlp/machine_translation/nmt_transformer_infer_megatron.py index c8ab668fc16c..fcf1fb8d1796 100644 --- a/examples/nlp/machine_translation/nmt_transformer_infer_megatron.py +++ b/examples/nlp/machine_translation/nmt_transformer_infer_megatron.py @@ -29,21 +29,12 @@ from nemo.collections.nlp.models.machine_translation.megatron_nmt_model import MegatronNMTModel from nemo.collections.nlp.modules.common.megatron.megatron_init import fake_initialize_model_parallel -from nemo.collections.nlp.modules.common.megatron.utils import ApexGuardDefaults from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy, NLPSaveRestoreConnector from nemo.core.config import hydra_runner from nemo.utils import logging from nemo.utils.app_state import AppState from nemo.utils.model_utils import inject_model_parallel_rank -try: - from apex.transformer.pipeline_parallel.utils import _reconfigure_microbatch_calculator - - HAVE_APEX = True -except (ImportError, ModuleNotFoundError): - ModelType = ApexGuardDefaults() - HAVE_APEX = False - @hydra_runner(config_path="conf", config_name="nmt_megatron_infer") def main(cfg) -> None: @@ -101,13 +92,19 @@ def main(cfg) -> None: src_text.append(line.strip()) if len(src_text) == cfg.batch_size: translations = model.translate( - text=src_text, source_lang=cfg.source_lang, target_lang=cfg.target_lang, + text=src_text, + source_lang=cfg.source_lang, + target_lang=cfg.target_lang, ) for translation in translations: tgt_f.write(translation + "\n") src_text = [] if len(src_text) > 0: - translations = model.translate(text=src_text, source_lang=cfg.source_lang, target_lang=cfg.target_lang,) + translations = model.translate( + text=src_text, + source_lang=cfg.source_lang, + target_lang=cfg.target_lang, + ) for translation in translations: tgt_f.write(translation + "\n") diff --git a/nemo/collections/asr/data/audio_to_text_lhotse_prompted.py b/nemo/collections/asr/data/audio_to_text_lhotse_prompted.py index e9e97d3d32d7..b1fc17bf8836 100644 --- a/nemo/collections/asr/data/audio_to_text_lhotse_prompted.py +++ b/nemo/collections/asr/data/audio_to_text_lhotse_prompted.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from dataclasses import dataclass from typing import Callable, Sequence import torch.utils.data @@ -18,6 +19,7 @@ from lhotse.cut import MixedCut, MonoCut from lhotse.dataset import AudioSamples from lhotse.dataset.collation import collate_vectors +from lhotse.utils import ifnone from nemo.collections.asr.data.audio_to_text_lhotse import TokenizerWrapper from nemo.collections.common.prompts.canary import CanaryPromptFormatter @@ -25,6 +27,26 @@ from nemo.collections.common.tokenizers.canary_tokenizer import CANARY_SPECIAL_TOKENIZER +@dataclass +class PromptedAudioToTextMiniBatch: + audio: torch.Tensor + audio_lens: torch.Tensor + transcript: torch.Tensor + transcript_lens: torch.Tensor + prompt: torch.Tensor + prompt_lens: torch.Tensor + prompted_transcript: torch.Tensor + prompted_transcript_lens: torch.Tensor + + def get_decoder_inputs_outputs(self) -> tuple[torch.Tensor, torch.Tensor]: + """ + Returns the inputs and outputs of transformer decoder for training. + The input is ``prompted_transcript`` (minus last token), + and the output is ``prompted_transcript`` (minus first token). 
+ """ + return self.prompted_transcript[:, :-1], self.prompted_transcript[:, 1:] + + class PromptedAudioToTextLhotseDataset(torch.utils.data.Dataset): """ This dataset is based on :class:`~nemo.collections.asr.data.audio_to_text_lhotse.LhotseSpeechToTextBpeDataset`. @@ -45,41 +67,46 @@ class PromptedAudioToTextLhotseDataset(torch.utils.data.Dataset): def __init__( self, tokenizer: TokenizerSpec, - prompt_format_fn: Callable[[CutSet, TokenizerWrapper, bool], Sequence[Sequence[int]]], - inference: bool = False, + prompt_format_fn: Callable[[CutSet, TokenizerWrapper], Sequence[Sequence[int]]], ): super().__init__() self.tokenizer = TokenizerWrapper(tokenizer) self.load_audio = AudioSamples(fault_tolerant=True) self.padding_value = self.tokenizer._tokenizer.pad_id self.prompt_format_fn = prompt_format_fn - self.inference = inference - def __getitem__(self, cuts: CutSet) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + def __getitem__(self, cuts: CutSet) -> PromptedAudioToTextMiniBatch: audio, audio_lens, cuts = self.load_audio(cuts) - prompts_with_answers, prompts = self.prompt_format_fn(cuts, self.tokenizer, inference=self.inference) - - prompts_with_answers = [torch.as_tensor(t) for t in prompts_with_answers] - prompts_with_answers_lens = torch.tensor([t.size(0) for t in prompts_with_answers], dtype=torch.long) - prompts_with_answers = collate_vectors(prompts_with_answers, padding_value=self.padding_value) - - if self.inference: - prompts = [torch.as_tensor(t) for t in prompts] - prompts_lens = torch.tensor([t.size(0) for t in prompts], dtype=torch.long) - prompts = collate_vectors(prompts, padding_value=self.padding_value) - else: - prompts = None - prompts_lens = None + prompts_with_answers, prompts, answers = self.prompt_format_fn(cuts, self.tokenizer) + + transcript, transcript_lens = self._collate_tokens(answers) + prompts_with_answers, prompts_with_answers_lens = self._collate_tokens(prompts_with_answers) + prompts, prompt_lens = self._collate_tokens(prompts) + + return PromptedAudioToTextMiniBatch( + audio=audio, + audio_lens=audio_lens, + transcript=transcript, + transcript_lens=transcript_lens, + prompt=prompts, + prompt_lens=prompt_lens, + prompted_transcript=prompts_with_answers, + prompted_transcript_lens=prompts_with_answers_lens, + ) - return audio, audio_lens, prompts_with_answers, prompts_with_answers_lens, prompts, prompts_lens + def _collate_tokens(self, tokens: list[list[int]]) -> tuple[torch.Tensor, torch.Tensor]: + tokens = [torch.as_tensor(t) for t in tokens] + token_lens = torch.tensor([t.size(0) for t in tokens], dtype=torch.long) + tokens = collate_vectors(tokens, padding_value=self.padding_value) + return tokens, token_lens # Mapping from a string name to a known prompt formatter function. PROMPT_FORMAT_FNS = {} -def registered_prompt_format_fn(prompt_fn: Callable[[CutSet, TokenizerWrapper, bool], Sequence[Sequence[int]]]): +def registered_prompt_format_fn(prompt_fn: Callable[[CutSet, TokenizerWrapper], Sequence[Sequence[int]]]): """ Decorator for registering prompt functions under a name. 
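To make the get_decoder_inputs_outputs() contract above concrete: it applies the usual teacher-forcing shift, feeding the decoder the prompted transcript without its last token and training it to predict the transcript without its first token. A toy tensor example (the token IDs are arbitrary):

```python
# Teacher-forcing shift performed by PromptedAudioToTextMiniBatch.get_decoder_inputs_outputs().
import torch

prompted_transcript = torch.tensor([[5, 11, 12, 13, 6]])  # e.g. <bos> t1 t2 t3 <eos>
decoder_input = prompted_transcript[:, :-1]   # [[5, 11, 12, 13]]
decoder_target = prompted_transcript[:, 1:]   # [[11, 12, 13, 6]]
print(decoder_input.tolist(), decoder_target.tolist())
```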
@@ -97,7 +124,7 @@ def registered_prompt_format_fn(prompt_fn: Callable[[CutSet, TokenizerWrapper, b return prompt_fn -def get_prompt_format_fn(name: str) -> Callable[[CutSet, TokenizerWrapper, bool], Sequence[Sequence[int]]]: +def get_prompt_format_fn(name: str) -> Callable[[CutSet, TokenizerWrapper], Sequence[Sequence[int]]]: if name not in PROMPT_FORMAT_FNS: raise ValueError( f"Unknown prompt format function name: {name} " f"(must be one of: {list(PROMPT_FORMAT_FNS.keys())}" @@ -107,8 +134,8 @@ def get_prompt_format_fn(name: str) -> Callable[[CutSet, TokenizerWrapper, bool] @registered_prompt_format_fn def canary( - cuts: CutSet, tokenizer: TokenizerWrapper, inference: bool = False -) -> tuple[list[torch.Tensor], list[torch.Tensor]]: + cuts: CutSet, tokenizer: TokenizerWrapper +) -> tuple[list[torch.Tensor], list[torch.Tensor], list[torch.Tensor]]: """ Prepend and append control tokens to the token sequence as per Canary format. @@ -137,7 +164,7 @@ def canary( ), "To use 'canary' prompt format, you must use the CanaryTokenizer." formatter = CanaryPromptFormatter(tokenizer._tokenizer) - prompts_with_answers, prompts = [], [] + prompts_with_answers, prompts, answers = [], [], [] for cut in cuts: if isinstance(cut, MixedCut): cut = cut._first_non_padding_cut @@ -160,28 +187,40 @@ def canary( f"Please ensure that every utterance in the input manifests contains these keys." ) - encoded = formatter.encode_dialog( - turns=[ - dict( - role="user", - slots={ - **{slot: cut.custom[slot] for slot in expected_slots}, - formatter.PROMPT_LANGUAGE_SLOT: CANARY_SPECIAL_TOKENIZER, - }, - ), + turns = [ + dict( + role="user", + slots={ + **{slot: cut.custom[slot] for slot in expected_slots}, + formatter.PROMPT_LANGUAGE_SLOT: CANARY_SPECIAL_TOKENIZER, + }, + ) + ] + if text := ' '.join(s.text for s in cut.supervisions if s.text is not None): + # Create answer_ids only if there is some transcript in the data. + turns.extend( dict( role="assistant", slots={ - "text": ' '.join(s.text for s in cut.supervisions), - formatter.PROMPT_LANGUAGE_SLOT: cut.custom["target_lang"], + "text": text, + formatter.PROMPT_LANGUAGE_SLOT: ifnone( + cut.supervisions[0].language, cut.custom.get("target_lang") + ), }, ), - ] - ) + ) + encoded = formatter.encode_dialog(turns) prompts_with_answers.append(encoded["input_ids"]) prompts.append(encoded["context_ids"]) + if "answer_ids" in encoded: + assert ( + encoded["answer_ids"][-1].item() == formatter.tokenizer.eos + ), f"Expected the last token in answer_ids to be EOS, but we got {encoded['answer_ids']=}" + answers.append(encoded["answer_ids"][:-1]) # Strip Canary's EOS + else: + answers.append([]) - return prompts_with_answers, prompts + return prompts_with_answers, prompts, answers class ProbablyIncorrectLanguageKeyError(RuntimeError): diff --git a/nemo/collections/asr/metrics/bleu.py b/nemo/collections/asr/metrics/bleu.py index 011e3efe0c6a..32bd25d952d4 100644 --- a/nemo/collections/asr/metrics/bleu.py +++ b/nemo/collections/asr/metrics/bleu.py @@ -34,12 +34,12 @@ def move_dimension_to_the_front(tensor, dim_index): # TODO: Add documentation class BLEU(SacreBLEUScore): """ - This metric computes numerator, denominator, hypotheses lengths, and target lengths for Overall Bilingual Evaluation Understudy (BLEU) - between prediction and reference texts. 
When doing distributed training/evaluation the result of + This metric computes numerator, denominator, hypotheses lengths, and target lengths for Overall Bilingual Evaluation Understudy (BLEU) + between prediction and reference texts. When doing distributed training/evaluation the result of ``res=BLEU.(predictions, predictions_lengths, targets, target_lengths)`` calls will be all-reduced between all workers using SUM operations. - If used with PytorchLightning LightningModule, include bleu_num bleur_den, bleu_pred_len, and bleu_target_len values inside + If used with PytorchLightning LightningModule, include bleu_num bleur_den, bleu_pred_len, and bleu_target_len values inside validation_step results. Then aggregate (sum) then at the end of validation epoch to correctly compute validation BLEUR. Example: @@ -99,7 +99,6 @@ def __init__( smooth=smooth, dist_sync_on_step=dist_sync_on_step, ) - self.has_spl_tokens = False self.decoding = decoding self.decode = None if isinstance(self.decoding, AbstractRNNTDecoding): @@ -113,7 +112,6 @@ def __init__( fold_consecutive=self.fold_consecutive, ) elif isinstance(self.decoding, AbstractMultiTaskDecoding): - self.has_spl_tokens = True self.decode = lambda predictions, prediction_lengths, predictions_mask, input_ids, targets: self.decoding.decode_predictions_tensor( encoder_hidden_states=predictions, encoder_input_mask=predictions_mask, @@ -165,10 +163,6 @@ def update( references.append(reference) hypotheses, _ = self.decode(predictions, predictions_lengths, predictions_mask, input_ids, targets) - if self.has_spl_tokens: - hypotheses = [self.decoding.strip_special_tokens(hyp) for hyp in hypotheses] - references = [self.decoding.strip_special_tokens(ref) for ref in references] - if self.log_prediction: logging.info(f"\n") logging.info(f"reference:{references[0]}") @@ -185,7 +179,7 @@ def compute(self, return_all_metrics=True, prefix="", suffix=""): only BLEU. Default: True. prefix: str to prepend to metric value keys. suffix: str to append to metric value keys. - + Returns: Dict: key-value pairs of BLEU metrics and values. Keys are prepended and appended with prefix and suffix flags, respectively. @@ -205,7 +199,11 @@ def compute(self, return_all_metrics=True, prefix="", suffix=""): # Adding wrapper to avoid imports and extra variables over the namespace def _compute_bleu( - self, predictions_lengths, targets_lengths, numerator, denominator, + self, + predictions_lengths, + targets_lengths, + numerator, + denominator, ): return _bleu_score_compute( predictions_lengths, targets_lengths, numerator, denominator, self.n_gram, self.weights, self.smooth diff --git a/nemo/collections/asr/metrics/wer.py b/nemo/collections/asr/metrics/wer.py index 1cb4cf06eaca..a135e5c51e84 100644 --- a/nemo/collections/asr/metrics/wer.py +++ b/nemo/collections/asr/metrics/wer.py @@ -148,8 +148,8 @@ def word_error_rate_detail( def word_error_rate_per_utt(hypotheses: List[str], references: List[str], use_cer=False) -> Tuple[List[float], float]: """ Computes Word Error Rate per utterance and the average WER - between two texts represented as corresponding lists of string. - + between two texts represented as corresponding lists of string. + Hypotheses and references must have same length. 
Args: @@ -263,7 +263,6 @@ def __init__( self.fold_consecutive = fold_consecutive self.batch_dim_index = batch_dim_index - self.has_spl_tokens = False self.decode = None if isinstance(self.decoding, AbstractRNNTDecoding): self.decode = lambda predictions, predictions_lengths, predictions_mask, input_ids, targets: self.decoding.rnnt_decoder_predictions_tensor( @@ -276,7 +275,6 @@ def __init__( fold_consecutive=self.fold_consecutive, ) elif isinstance(self.decoding, AbstractMultiTaskDecoding): - self.has_spl_tokens = True self.decode = lambda predictions, prediction_lengths, predictions_mask, input_ids, targets: self.decoding.decode_predictions_tensor( encoder_hidden_states=predictions, encoder_input_mask=predictions_mask, @@ -326,10 +324,6 @@ def update( references.append(reference) hypotheses, _ = self.decode(predictions, predictions_lengths, predictions_mask, input_ids, targets) - if self.has_spl_tokens: - hypotheses = [self.decoding.strip_special_tokens(hyp) for hyp in hypotheses] - references = [self.decoding.strip_special_tokens(ref) for ref in references] - if self.log_prediction: logging.info(f"\n") logging.info(f"reference:{references[0]}") diff --git a/nemo/collections/asr/models/aed_multitask_models.py b/nemo/collections/asr/models/aed_multitask_models.py index 5ec7a8298bee..301b2b0ad026 100644 --- a/nemo/collections/asr/models/aed_multitask_models.py +++ b/nemo/collections/asr/models/aed_multitask_models.py @@ -27,6 +27,7 @@ from nemo.collections.asr.data.audio_to_text_lhotse_prompted import ( PromptedAudioToTextLhotseDataset, + PromptedAudioToTextMiniBatch, get_prompt_format_fn, ) from nemo.collections.asr.metrics import BLEU, WER @@ -434,7 +435,7 @@ def change_prompt( prompt_cls = PromptFormatter.resolve(self.prompt_format) self.prompt = prompt_cls( tokenizer=self.tokenizer, - defaults=OmegaConf.to_container(pd) if (pd := self.cfg.prompt_defaults) is not None else None, + defaults=OmegaConf.to_container(pd) if (pd := self.cfg.get('prompt_defaults')) is not None else None, ) # Update config @@ -498,7 +499,7 @@ def transcribe( return super().transcribe(audio=audio, override_config=trcfg) - def _setup_dataloader_from_config(self, config: Optional[Dict], inference: bool = False): + def _setup_dataloader_from_config(self, config: Optional[Dict]): assert config.get("use_lhotse", False), ( "Multi-task model only supports dataloading with Lhotse. 
" "Please set config.{train,validation,test}_ds.use_lhotse=True" @@ -510,7 +511,6 @@ def _setup_dataloader_from_config(self, config: Optional[Dict], inference: bool dataset=PromptedAudioToTextLhotseDataset( tokenizer=self.tokenizer, prompt_format_fn=get_prompt_format_fn(self.prompt_format), - inference=inference, ), ) @@ -554,7 +554,7 @@ def setup_validation_data(self, val_data_config: Optional[Union[DictConfig, Dict # preserve config self._update_dataset_config(dataset_name='validation', config=val_data_config) - self._validation_dl = self._setup_dataloader_from_config(config=val_data_config, inference=True) + self._validation_dl = self._setup_dataloader_from_config(config=val_data_config) def setup_test_data(self, test_data_config: Optional[Union[DictConfig, Dict]]): """ @@ -570,7 +570,7 @@ def setup_test_data(self, test_data_config: Optional[Union[DictConfig, Dict]]): # preserve config self._update_dataset_config(dataset_name='test', config=test_data_config) - self._test_dl = self._setup_dataloader_from_config(config=test_data_config, inference=True) + self._test_dl = self._setup_dataloader_from_config(config=test_data_config) @property def input_types(self) -> Optional[Dict[str, NeuralType]]: @@ -664,20 +664,18 @@ def forward( return transf_log_probs, encoded_len, enc_states, enc_mask # PTL-specific methods - def training_step(self, batch, batch_nb): + def training_step(self, batch: PromptedAudioToTextMiniBatch, batch_nb): if batch is None: return torch.tensor([0.0]) - # During training prompt and prompt_len are null, ignore. - signal, signal_len, transcript, transcript_len, prompt, prompt_len = batch - input_ids, labels = transcript[:, :-1], transcript[:, 1:] + input_ids, labels = batch.get_decoder_inputs_outputs() transf_log_probs, encoded_len, enc_states, enc_mask = self.forward( - input_signal=signal, - input_signal_length=signal_len, + input_signal=batch.audio, + input_signal_length=batch.audio_lens, transcript=input_ids, - transcript_length=transcript_len, + transcript_length=batch.prompted_transcript_lens, ) audio_loss = self.loss(log_probs=transf_log_probs, labels=labels) @@ -689,16 +687,14 @@ def training_step(self, batch, batch_nb): return {'loss': audio_loss, 'log': tensorboard_logs} - def validation_pass(self, batch, batch_idx, dataloader_idx=0, eval_mode="val"): - # During inference, dataloader passes pure prompt without transcript text. 
- signal, signal_len, transcript, transcript_len, prompt, prompt_len = batch - input_ids, labels = transcript[:, :-1], transcript[:, 1:] + def validation_pass(self, batch: PromptedAudioToTextMiniBatch, batch_idx, dataloader_idx=0, eval_mode="val"): + input_ids, labels = batch.get_decoder_inputs_outputs() transf_log_probs, encoded_len, enc_states, enc_mask = self.forward( - input_signal=signal, - input_signal_length=signal_len, + input_signal=batch.audio, + input_signal_length=batch.audio_lens, transcript=input_ids, - transcript_length=transcript_len, + transcript_length=batch.prompted_transcript_lens, ) transf_loss = self.loss(log_probs=transf_log_probs, labels=labels) @@ -710,10 +706,10 @@ def validation_pass(self, batch, batch_idx, dataloader_idx=0, eval_mode="val"): self.wer.update( predictions=enc_states, predictions_lengths=encoded_len, - targets=transcript, - targets_lengths=transcript_len, + targets=batch.transcript, + targets_lengths=batch.transcript_lens, predictions_mask=enc_mask, - input_ids=prompt, + input_ids=batch.prompt, ) wer, wer_num, wer_denom = self.wer.compute() output_dict.update({"val_wer": wer, "val_wer_num": wer_num, "val_wer_denom": wer_denom}) @@ -722,10 +718,10 @@ def validation_pass(self, batch, batch_idx, dataloader_idx=0, eval_mode="val"): self.bleu.update( predictions=enc_states, predictions_lengths=encoded_len, - targets=transcript, - targets_lengths=transcript_len, + targets=batch.transcript, + targets_lengths=batch.transcript_lens, predictions_mask=enc_mask, - input_ids=prompt, + input_ids=batch.prompt, ) bleu_metrics = self.bleu.compute(prefix=f"{eval_mode}_") output_dict.update(bleu_metrics) @@ -823,7 +819,9 @@ def _transcribe_input_manifest_processing( return super()._transcribe_input_manifest_processing(audio_files, temp_dir, trcfg) - def _transcribe_forward(self, batch: Any, trcfg: MultiTaskTranscriptionConfig): + def _transcribe_forward( + self, batch: PromptedAudioToTextMiniBatch | tuple[torch.Tensor, ...], trcfg: MultiTaskTranscriptionConfig + ) -> dict: """ Internal function to perform the model's custom forward pass to return outputs that are processed by `_transcribe_output_processing()`. @@ -836,13 +834,25 @@ def _transcribe_forward(self, batch: Any, trcfg: MultiTaskTranscriptionConfig): Returns: The model's outputs that are processed by `_transcribe_output_processing()`. """ - log_probs, encoded_len, enc_states, enc_mask = self.forward( - input_signal=batch[0], input_signal_length=batch[1] - ) - if len(batch) == 6: - # Prompt provided by the dataloader. - decoder_input_ids = batch[4] + if isinstance(batch, PromptedAudioToTextMiniBatch): + # Handling regular Canary DataLoader + audio = batch.audio + audio_lens = batch.audio_lens + decoder_input_ids = batch.prompt else: + # Handling TensorDataset / external DataLoader + audio, audio_lens = batch[0], batch[1] + if len(batch) == 6: + # Prompt provided by the user. + decoder_input_ids = batch[4] + else: + # Prompt to be built dynamically. + decoder_input_ids = None + batch_size = audio.shape[0] + + log_probs, encoded_len, enc_states, enc_mask = self.forward(input_signal=audio, input_signal_length=audio_lens) + + if decoder_input_ids is None: # The dataloader provided only audio + audio_lens, so we # are constructing the prompt dynamically using TranscribeConfig. 
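# Minimal sketch of the two batch layouts accepted by the updated _transcribe_forward
# above. The tensors below are illustrative assumptions, not values from the patch:
# the Lhotse dataloader yields a PromptedAudioToTextMiniBatch, while an external
# dataloader may yield a plain tuple, optionally 6 elements long with the prompt
# (decoder input ids) at index 4.
import torch

audio = torch.randn(2, 16000)                   # (batch, samples)
audio_lens = torch.tensor([16000, 12000])
prompt = torch.tensor([[3, 4, 5], [3, 4, 5]])   # made-up prompt token ids

external_batch_no_prompt = (audio, audio_lens)                              # prompt built from TranscribeConfig
external_batch_with_prompt = (audio, audio_lens, None, None, prompt, None)  # prompt read from index 4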
@@ -877,17 +887,17 @@ def _transcribe_forward(self, batch: Any, trcfg: MultiTaskTranscriptionConfig): decoder_input_ids = ( self.prompt.encode_dialog(turns=turns)["context_ids"] .unsqueeze(0) - .repeat(batch[0].shape[0], 1) + .repeat(batch_size, 1) .to(trcfg._internal.device) ) - output = dict( + + return dict( log_probs=log_probs, encoded_lengths=encoded_len, encoder_states=enc_states, encoder_mask=enc_mask, decoder_input_ids=decoder_input_ids, ) - return output def _transcribe_output_processing(self, outputs, trcfg: MultiTaskTranscriptionConfig) -> GenericTranscriptionType: """ @@ -918,19 +928,6 @@ def _transcribe_output_processing(self, outputs, trcfg: MultiTaskTranscriptionCo return_hypotheses=trcfg.return_hypotheses, ) - if trcfg.return_hypotheses: - for hyp in best_hypotheses: - hyp.text = self.decoding.strip_special_tokens(hyp.text) - if all_hypotheses is not None: - for i in range(len(all_hypotheses)): - for j in range(len(all_hypotheses[i])): - all_hypotheses[i][j].text = self.decoding.strip_special_tokens(all_hypotheses[i][j].text) - else: - best_hypotheses = [self.decoding.strip_special_tokens(text) for text in best_hypotheses] - if all_hypotheses is not None: - for i in range(len(all_hypotheses)): - all_hypotheses[i] = [self.decoding.strip_special_tokens(text) for text in all_hypotheses[i]] - del enc_states, enc_mask, decoder_input_ids if all_hypotheses is None: return best_hypotheses @@ -967,7 +964,7 @@ def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLo 'channel_selector': config.get('channel_selector', None), } - temporary_datalayer = self._setup_dataloader_from_config(config=DictConfig(dl_config), inference=True) + temporary_datalayer = self._setup_dataloader_from_config(config=DictConfig(dl_config)) return temporary_datalayer def _transcribe_on_end(self, trcfg: MultiTaskTranscriptionConfig): @@ -979,7 +976,7 @@ def _transcribe_on_end(self, trcfg: MultiTaskTranscriptionConfig): """ super()._transcribe_on_end(trcfg) - self.transf_decoder.unfreeze() + self.transf_decoder.unfreeze(partial=True) def _may_be_make_dict_and_fix_paths(self, json_items, manifest_path, trcfg: MultiTaskTranscriptionConfig): """ @@ -1002,13 +999,10 @@ def _may_be_make_dict_and_fix_paths(self, json_items, manifest_path, trcfg: Mult entry = { 'audio_filepath': item, 'duration': 100000, - trcfg.text_field: 'nothing', } elif isinstance(item, dict): entry = item entry['audio_filepath'] = get_full_path(entry['audio_filepath'], manifest_file=manifest_path) - if trcfg.text_field not in entry: - entry[trcfg.text_field] = 'nothing' else: raise ValueError(f"Expected str or dict, got {type(item)}") default_turn = [t for t in trcfg.prompt if t["role"] == "user"] @@ -1030,34 +1024,36 @@ def get_transcribe_config(cls) -> MultiTaskTranscriptionConfig: """ return MultiTaskTranscriptionConfig() - def predict_step(self, batch, batch_idx=0, dataloader_idx=0, has_processed_signal=False): - signal, signal_len, _, _, prompt, prompt_len = batch - - processed_signal = None - processed_signal_length = None + def predict_step( + self, batch: PromptedAudioToTextMiniBatch, batch_idx=0, dataloader_idx=0, has_processed_signal=False + ): if has_processed_signal: - processed_signal = signal - processed_signal_length = signal_len + processed_signal = batch.audio + processed_signal_length = batch.audio_lens signal = None signal_len = None + else: + processed_signal = None + processed_signal_length = None + signal = batch.audio + signal_len = batch.audio_lens transf_log_probs, encoded_len, enc_states, 
enc_mask = self.forward( input_signal=signal, input_signal_length=signal_len, processed_signal=processed_signal, processed_signal_length=processed_signal_length, - transcript=prompt, - transcript_length=prompt_len, + transcript=batch.prompt, + transcript_length=batch.prompt_lens, ) text = self.decoding.decode_predictions_tensor( encoder_hidden_states=enc_states, encoder_input_mask=enc_mask, - decoder_input_ids=prompt, + decoder_input_ids=batch.prompt, return_hypotheses=False, )[0] - text = [self.decoding.strip_special_tokens(t) for t in text] return text @property diff --git a/nemo/collections/asr/models/ctc_models.py b/nemo/collections/asr/models/ctc_models.py index b6d8945b6c6b..76233d57622b 100644 --- a/nemo/collections/asr/models/ctc_models.py +++ b/nemo/collections/asr/models/ctc_models.py @@ -665,20 +665,6 @@ def test_dataloader(self): """ Transcription related methods """ - def _transcribe_on_begin(self, audio, trcfg: TranscribeConfig): - super()._transcribe_on_begin(audio, trcfg) - - # Freeze the encoder and decoder modules - self.encoder.freeze() - self.decoder.freeze() - - def _transcribe_on_end(self, trcfg: TranscribeConfig): - super()._transcribe_on_end(trcfg) - - # Unfreeze the encoder and decoder modules - self.encoder.unfreeze() - self.decoder.unfreeze() - def _transcribe_forward(self, batch: Any, trcfg: TranscribeConfig): logits, logits_len, greedy_predictions = self.forward(input_signal=batch[0], input_signal_length=batch[1]) output = dict(logits=logits, logits_len=logits_len) diff --git a/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py b/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py index c7c09739be64..f161454c9bae 100644 --- a/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py +++ b/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py @@ -157,7 +157,7 @@ def _transcribe_on_end(self, trcfg: TranscribeConfig): super()._transcribe_on_end(trcfg) if hasattr(self, 'ctc_decoder'): - self.ctc_decoder.unfreeze() + self.ctc_decoder.unfreeze(partial=True) def _transcribe_forward(self, batch: Any, trcfg: TranscribeConfig): if self.cur_decoder == "rnnt": diff --git a/nemo/collections/asr/models/transformer_bpe_models.py b/nemo/collections/asr/models/transformer_bpe_models.py index 79de83f1d4a1..9970b4970236 100644 --- a/nemo/collections/asr/models/transformer_bpe_models.py +++ b/nemo/collections/asr/models/transformer_bpe_models.py @@ -633,4 +633,4 @@ def _transcribe_on_end(self, trcfg: TranscribeConfig): super()._transcribe_on_end(trcfg) # Unfreeze the encoder and decoder modules - self.transf_decoder.unfreeze() + self.transf_decoder.unfreeze(partial=True) diff --git a/nemo/collections/asr/modules/transformer/transformer_generators.py b/nemo/collections/asr/modules/transformer/transformer_generators.py index 1a38e7fa4b6c..e6775a48f635 100644 --- a/nemo/collections/asr/modules/transformer/transformer_generators.py +++ b/nemo/collections/asr/modules/transformer/transformer_generators.py @@ -15,7 +15,9 @@ from contextlib import contextmanager import torch +from torch.distributions import Categorical +from nemo.collections.asr.parts.submodules.token_classifier import TokenClassifier from nemo.collections.common.parts import NEG_INF, mask_padded_tokens __all__ = [ @@ -30,12 +32,13 @@ class GreedySequenceGenerator: """ Greedy sequence generator based on the decoder followed by log_softmax. + Optionally supports temperature sampling with ``n_samples`` and ``temperature`` options. 
Args: embedding: nn.Module, transforms input_ids into vector embeddings decoder: nn.Module, takes embeddings and produces hidden_states - log_softmax: nn.Module, takes hidden_states and produces log_probs - which correspond to probability distribution of tokens (ids) + classifier: nn.Module, takes hidden_states and produces + logits or log-probability distribution of tokens (ids) pad: index of padding token in the vocabulary bos: index of beginning of sequence token in the vocabulary eos: index of end of sequence token in the vocabulary @@ -45,28 +48,35 @@ class GreedySequenceGenerator: source sequences plus max_delta_length batch_size: size of the batch of generated sequences if neither source nor target starting sequences are provided + n_samples: number of sequences to generate (requires ``temperature`` to be set) + temperature: temperature for temperature sampling. Even with ``n_samples`` set to 1, + enabling temperature will sample hypotheses instead of returning the best ones. """ def __init__( self, embedding, decoder, - log_softmax, + classifier: TokenClassifier, pad=0, bos=1, eos=2, max_sequence_length=512, max_delta_length=20, batch_size=1, + n_samples=1, + temperature=None, ): super().__init__() self.embedding = embedding self.decoder = decoder - self.log_softmax = log_softmax + self.classifier = classifier self.pad, self.bos, self.eos = pad, bos, eos self.max_seq_length = max_sequence_length self.max_delta_len = max_delta_length self.batch_size = batch_size + self.n_samples = n_samples + self.temperature = temperature def _one_step_forward( self, @@ -75,6 +85,7 @@ def _one_step_forward( encoder_input_mask=None, decoder_mems_list=None, pos=0, + return_scores: bool = True, ): """ One step of autoregressive output generation. @@ -107,8 +118,9 @@ def _one_step_forward( decoder_mems_list = self.decoder.forward( decoder_hidden_states, decoder_input_mask, decoder_mems_list, return_mems=True ) - log_probs = self.log_softmax.forward(hidden_states=decoder_mems_list[-1][:, -1:]) - return log_probs, decoder_mems_list + with self.classifier.with_log_softmax_enabled(return_scores) as clf: + logits = clf.forward(hidden_states=decoder_mems_list[-1][:, -1:]) + return logits, decoder_mems_list def _prepare_for_search(self, decoder_input_ids=None, encoder_hidden_states=None): """ @@ -145,30 +157,57 @@ def _forward( self, decoder_input_ids=None, encoder_hidden_states=None, encoder_input_mask=None, return_beam_scores=False ): assert not return_beam_scores + is_sampling = self.temperature is not None and self.n_samples > 1 + tgt, batch_size, max_generation_length = self._prepare_for_search(decoder_input_ids, encoder_hidden_states) + if is_sampling: + tgt = torch.repeat_interleave(tgt, self.n_samples, dim=0) + encoder_hidden_states = torch.repeat_interleave(encoder_hidden_states, self.n_samples, dim=0) + encoder_input_mask = torch.repeat_interleave(encoder_input_mask, self.n_samples, dim=0) + orig_batch_size = batch_size + batch_size = batch_size * self.n_samples # pad profile tracks sequences ending with token to replace # everything after with token decoder_parameter = next(self.decoder.parameters()) - pad_profile = torch.zeros(batch_size, 1).long().to(decoder_parameter.device) + pad_profile = torch.zeros(batch_size).long().to(decoder_parameter.device) decoder_mems_list = None for i in range(max_generation_length): - log_probs, decoder_mems_list = self._one_step_forward( - tgt[:, -1:], encoder_hidden_states, encoder_input_mask, decoder_mems_list, i + if i == 0: + input_ids = tgt + else: + input_ids 
= tgt[:, -1:] + + logits, decoder_mems_list = self._one_step_forward( + input_ids, + encoder_hidden_states, + encoder_input_mask, + decoder_mems_list, + i, + return_scores=return_beam_scores, ) - next_tokens = torch.argmax(log_probs[:, -1], dim=-1, keepdim=True) + if self.temperature is None: # Greedy decoding + next_tokens = torch.argmax(logits[:, -1], dim=-1) + else: # Temperature sampling + next_tokens = Categorical(logits=logits[:, -1] / self.temperature).sample() + next_tokens = self.pad * pad_profile + next_tokens * (1 - pad_profile) pad_profile = torch.max(pad_profile, (next_tokens == self.eos).long()) - tgt = torch.cat((tgt, next_tokens), dim=-1) + tgt = torch.cat((tgt, next_tokens.unsqueeze(1)), dim=-1) # abort generation if all sequences end with if pad_profile.sum() == batch_size: break - return tgt + samples = None + if is_sampling: + samples = list(tgt.view(orig_batch_size, self.n_samples, -1)) + tgt = tgt[:: self.n_samples] + + return tgt, samples def __call__( self, decoder_input_ids=None, encoder_hidden_states=None, encoder_input_mask=None, return_beam_scores=False @@ -195,9 +234,9 @@ def freeze(self) -> None: for param in self.decoder.parameters(): param.requires_grad = False self.decoder.eval() - for param in self.log_softmax.parameters(): + for param in self.classifier.parameters(): param.requires_grad = False - self.log_softmax.eval() + self.classifier.eval() def unfreeze(self) -> None: """Unfreeze weights of embedding, decoder, and classification layers.""" @@ -207,14 +246,14 @@ def unfreeze(self) -> None: for param in self.decoder.parameters(): param.requires_grad = True self.decoder.train() - for param in self.log_softmax.parameters(): + for param in self.classifier.parameters(): param.requires_grad = True - self.log_softmax.train() + self.classifier.train() @contextmanager def as_frozen(self): """ - Context manager which temporarily freezes embedding, decoder, and log_softmax modules, + Context manager which temporarily freezes embedding, decoder, and classifier modules, yields control and finally unfreezes the modules. """ self.freeze() @@ -252,9 +291,15 @@ def _one_step_forward( encoder_input_mask=None, decoder_mems_list=None, pos=0, + return_scores: bool = True, ): log_probs, decoder_mems_list = super()._one_step_forward( - decoder_input_ids, encoder_hidden_states, encoder_input_mask, decoder_mems_list, pos + decoder_input_ids, + encoder_hidden_states, + encoder_input_mask, + decoder_mems_list, + pos, + return_scores=return_scores, ) batch_size, seq_len, vocab_size = log_probs.size() diff --git a/nemo/collections/asr/parts/mixins/transcription.py b/nemo/collections/asr/parts/mixins/transcription.py index b6238cad4534..84dd4aed9bce 100644 --- a/nemo/collections/asr/parts/mixins/transcription.py +++ b/nemo/collections/asr/parts/mixins/transcription.py @@ -17,7 +17,7 @@ import tempfile from abc import ABC, abstractmethod from collections.abc import Iterable -from dataclasses import dataclass +from dataclasses import dataclass, fields, is_dataclass from functools import partial from typing import Any, Dict, List, Optional, Tuple, Union @@ -69,13 +69,18 @@ class TranscribeConfig: def move_to_device(batch, device, non_blocking=False): """ Recursively move all tensors in `batch` to `device`. + Supports tensors, lists, tuples, dictionaries, and dataclasses. 
""" if isinstance(batch, torch.Tensor): return batch.to(device, non_blocking=non_blocking) elif isinstance(batch, (list, tuple)): - return [move_to_device(x, device, non_blocking) for x in batch] + return type(batch)(move_to_device(x, device, non_blocking) for x in batch) elif isinstance(batch, dict): return {k: move_to_device(v, device, non_blocking) for k, v in batch.items()} + elif is_dataclass(batch): + return type(batch)( + **{field.name: move_to_device(getattr(batch, field.name), device, non_blocking) for field in fields(batch)} + ) else: return batch # do nothing if not supported type @@ -770,13 +775,13 @@ def _transcribe_on_end(self, trcfg: TranscribeConfig): # Unfreeze the encoder and decoder modules if hasattr(self, 'encoder'): - self.encoder.unfreeze() + self.encoder.unfreeze(partial=True) if hasattr(self, 'decoder'): - self.decoder.unfreeze() + self.decoder.unfreeze(partial=True) if hasattr(self, 'joint'): - self.joint.unfreeze() + self.joint.unfreeze(partial=True) @classmethod def get_transcribe_config(cls) -> TranscribeConfig: diff --git a/nemo/collections/asr/parts/submodules/multitask_beam_decoding.py b/nemo/collections/asr/parts/submodules/multitask_beam_decoding.py index c6dc28a47480..de2d63cd99de 100644 --- a/nemo/collections/asr/parts/submodules/multitask_beam_decoding.py +++ b/nemo/collections/asr/parts/submodules/multitask_beam_decoding.py @@ -102,8 +102,7 @@ class TransformerAEDBeamInfer(AEDBeamInfer, Typing): @property def input_types(self): - """Returns definitions of module input ports. - """ + """Returns definitions of module input ports.""" # Input can be of dimention - # ('B', 'T', 'D') [Log probs] or ('B', 'T') [Labels] @@ -116,8 +115,7 @@ def input_types(self): @property def output_types(self): - """Returns definitions of module output ports. 
- """ + """Returns definitions of module output ports.""" return {"predictions": [NeuralType(elements_type=HypothesisType())]} def __init__( @@ -141,15 +139,18 @@ def __init__( preserve_alignments=preserve_alignments, ) self.beam_size = beam_size + self.bos = tokenizer.bos + self.pad = tokenizer.pad + self.eos = tokenizer.eos self.beam_search = BeamSearchSequenceGenerator( embedding=transformer_decoder.embedding, decoder=transformer_decoder.decoder, log_softmax=log_softmax_module, max_sequence_length=transformer_decoder.max_sequence_length, beam_size=beam_size, - bos=tokenizer.bos_id, - pad=tokenizer.pad_id, - eos=tokenizer.eos_id, + bos=self.bos, + pad=self.pad, + eos=self.eos, len_pen=length_penalty, max_delta_length=max_generation_delta, ) @@ -196,9 +197,9 @@ def forward( for i in range(len(topk_hypotheses)): hypotheses = [Hypothesis(score=0.0, y_sequence=[], timestep=[]) for _ in range(self.beam_size)] # Pack results into Hypotheses - packed_result.append( - NBestHypotheses(pack_hypotheses(hypotheses, topk_hypotheses[i], beam_scores[i])) - ) + hypotheses = pack_hypotheses(hypotheses, topk_hypotheses[i], beam_scores[i]) + self.format_hypotheses(hypotheses, decoder_input_ids) + packed_result.append(NBestHypotheses(hypotheses)) else: beam_scores = [None for _ in range(len(best_hypo))] best_hypo = best_hypo.detach().cpu() @@ -207,9 +208,38 @@ def forward( ] # Pack results into Hypotheses packed_result = pack_hypotheses(hypotheses, best_hypo, beam_scores) + self.format_hypotheses(packed_result, decoder_input_ids) return (packed_result,) + def format_hypotheses(self, packed_result: List[Hypothesis], decoder_input_ids: torch.Tensor | None) -> None: + """ + For each hypothesis in the mini-batch: + * Remove the decoder input ids (prompt) from the predictions + * Remove BOS, EOS, and PAD ids from the predictions. + Modifies results in-place. 
+ """ + if decoder_input_ids is not None: + assert ( + len(packed_result) == decoder_input_ids.shape[0] + ), f"Mismatching number of examples {len(packed_result)=} {decoder_input_ids.shape[0]=}" + decoder_input_ids = decoder_input_ids.detach().cpu() + for hyp, prefix in zip(packed_result, decoder_input_ids): + assert ( + hyp.y_sequence[: prefix.shape[0]] == prefix + ).all(), f"The decoder input IDs were not found at the beginning of prediction: {hyp.y_sequence=} {prefix=})" + hyp.y_sequence = hyp.y_sequence[prefix.shape[0] :] + for hyp in packed_result: + ids = hyp.y_sequence + ids_len = ids.shape[0] + pos = -1 + while ids[pos] == self.pad or ids[pos] == self.eos: + pos -= 1 + if ids_len + pos == -1: + break # empty sequence + if pos < -1: + hyp.y_sequence = ids[: pos + 1] + @dataclass class AEDBeamInferConfig: diff --git a/nemo/collections/asr/parts/submodules/multitask_decoding.py b/nemo/collections/asr/parts/submodules/multitask_decoding.py index c336ae7d4170..715ee7168037 100644 --- a/nemo/collections/asr/parts/submodules/multitask_decoding.py +++ b/nemo/collections/asr/parts/submodules/multitask_decoding.py @@ -25,6 +25,10 @@ AEDBeamInferConfig, TransformerAEDBeamInfer, ) +from nemo.collections.asr.parts.submodules.multitask_greedy_decoding import ( + AEDGreedyInferConfig, + TransformerAEDGreedyInfer, +) from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis, NBestHypotheses from nemo.collections.common.tokenizers.aggregate_tokenizer import AggregateTokenizer from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec @@ -60,11 +64,9 @@ class AbstractMultiTaskDecoding(ABC): The config may further contain the following sub-dictionaries: "greedy": - max_symbols: int, describing the maximum number of target tokens to decode per - timestep during greedy decoding. Setting to larger values allows longer sentences - to be decoded, at the cost of increased execution time. - preserve_frame_confidence: Same as above, overrides above value. - confidence_method_cfg: Same as above, overrides confidence_cfg.method_cfg. + temperature: None (disabled) or float, specifying this enables temperature sampling instead of greedy decoding. + max_generation_delta: int = -1 # -1 means up to the max length of the decoder + preserve_alignments: bool = False (unsupported) "beam": beam_size: int, defining the beam size for beam search. Must be >= 1. 
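# Minimal configuration sketch for the new greedy/temperature decoding path, using the
# AEDGreedyInferConfig fields added later in this patch. The `model` variable and the
# change_decoding_strategy() entry point are assumptions for illustration; only the
# config field names come from the patch.
from omegaconf import OmegaConf

decoding_cfg = OmegaConf.create(
    {
        "strategy": "greedy",
        "greedy": {
            "temperature": 0.7,          # None keeps pure argmax greedy decoding
            "n_samples": 4,              # sample several hypotheses per utterance
            "max_generation_delta": -1,  # -1: generate up to the decoder's max length
            "preserve_alignments": False,
        },
    }
)
# model.change_decoding_strategy(decoding_cfg)  # assumed model-level entry point, not shown in this diff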
@@ -103,34 +105,47 @@ def __init__( self.preserve_alignments = self.cfg.get('preserve_alignments', None) self.compute_langs = self.cfg.get('compute_langs', False) self.compute_hypothesis_token_set = self.cfg.get('compute_hypothesis_token_set', False) + self.transformer_decoder = transformer_decoder + self.log_softmax_module = log_softmax_module + self.tokenizer = tokenizer + self.change_strategy(self.cfg.strategy) + + def change_strategy(self, strategy: str) -> "AbstractMultiTaskDecoding": possible_strategies = ['greedy', 'greedy_batch', 'beam'] - if self.cfg.strategy not in possible_strategies: - raise ValueError(f"Decoding strategy must be one of {possible_strategies}") + if strategy not in possible_strategies: + raise ValueError(f"Decoding strategy must be one of {possible_strategies}" f"but was provided {strategy}") # Update preserve alignments if self.preserve_alignments is None: - if self.cfg.strategy in ['greedy', 'greedy_batch']: + if strategy in ['greedy', 'greedy_batch']: self.preserve_alignments = self.cfg.greedy.get('preserve_alignments', False) - elif self.cfg.strategy in ['beam']: + elif strategy in ['beam']: self.preserve_alignments = self.cfg.beam.get('preserve_alignments', False) - if self.cfg.strategy == 'greedy' or self.cfg.strategy == 'greedy_batch': + if strategy in ['greedy', 'greedy_batch']: - # self.decoding = None - raise NotImplementedError("Greedy decoding is not implemented yet.") + self.decoding = TransformerAEDGreedyInfer( + transformer_decoder=self.transformer_decoder, + log_softmax_module=self.log_softmax_module, + tokenizer=self.tokenizer, + max_generation_delta=self.cfg.greedy.get('max_generation_delta', -1), + preserve_alignments=self.preserve_alignments, + temperature=self.cfg.greedy.temperature, + n_samples=self.cfg.greedy.n_samples, + ) - elif self.cfg.strategy == 'beam': + elif strategy == 'beam': self.decoding = TransformerAEDBeamInfer( - transformer_decoder=transformer_decoder, - log_softmax_module=log_softmax_module, - tokenizer=tokenizer, + transformer_decoder=self.transformer_decoder, + log_softmax_module=self.log_softmax_module, + tokenizer=self.tokenizer, search_type=self.cfg.beam.get('search_type', 'default'), beam_size=self.cfg.beam.beam_size, length_penalty=self.cfg.beam.get('length_penalty', 0.0), - max_generation_delta=self.cfg.beam.get('max_generation_delta', 50), + max_generation_delta=self.cfg.beam.get('max_generation_delta', -1), return_best_hypothesis=self.cfg.beam.get('return_best_hypothesis', True), preserve_alignments=self.preserve_alignments, ) @@ -139,7 +154,7 @@ def __init__( raise ValueError( f"Incorrect decoding strategy provided. 
Must be one of {possible_strategies}\n" - f"but was provided {self.cfg.strategy}" + f"but was provided {strategy}" ) def decode_predictions_tensor( @@ -295,17 +310,6 @@ def decode_ids_to_langs(self, tokens: List[int]) -> List[str]: """ raise NotImplementedError() - def strip_special_tokens(self, text: str): - """ - assuming all special tokens are of format - Note that if any label/pred is of format , it will be stripped - """ - assert isinstance(text, str), f"Expected str, got {type(text)}" - text = re.sub(r'<[^>]+>', '', text) - # strip spaces at the beginning and end; - # this is training data artifact, will be fixed in future (@kpuvvada) - return text.strip() - class MultiTaskDecoding(AbstractMultiTaskDecoding): """ @@ -476,9 +480,7 @@ class MultiTaskDecodingConfig: compute_langs: bool = False # greedy decoding config - # greedy: rnnt_greedy_decoding.GreedyBatchedRNNTInferConfig = field( - # default_factory=rnnt_greedy_decoding.GreedyBatchedRNNTInferConfig - # ) + greedy: AEDGreedyInferConfig = field(default_factory=AEDGreedyInferConfig) # beam decoding config beam: AEDBeamInferConfig = field(default_factory=lambda: AEDBeamInferConfig(beam_size=1)) diff --git a/nemo/collections/asr/parts/submodules/multitask_greedy_decoding.py b/nemo/collections/asr/parts/submodules/multitask_greedy_decoding.py new file mode 100644 index 000000000000..891d003bd001 --- /dev/null +++ b/nemo/collections/asr/parts/submodules/multitask_greedy_decoding.py @@ -0,0 +1,242 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import List, Optional + +import torch + +from nemo.collections.asr.modules.transformer import GreedySequenceGenerator +from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis, NBestHypotheses +from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec +from nemo.core import Typing, typecheck +from nemo.core.neural_types import ChannelType, HypothesisType, LabelsType, MaskType, NeuralType +from nemo.utils import logging + + +def pack_hypotheses( + hypotheses: List[Hypothesis], beam_hypotheses: torch.Tensor, scores: List[Optional[float]] +) -> List[Hypothesis]: + + for idx, hyp in enumerate(hypotheses): # type: Hypothesis + if scores[idx] is not None: + hyp.score = scores[idx] + + hypi = beam_hypotheses[idx] + if torch.is_tensor(hypi): + hyp.y_sequence = hypi.long() + else: + hyp.y_sequence = torch.tensor(hypi, dtype=torch.long) + + if hyp.dec_state is not None: + hyp.dec_state = _states_to_device(hyp.dec_state) + + return hypotheses + + +def _states_to_device(dec_state, device='cpu'): + if torch.is_tensor(dec_state): + dec_state = dec_state.to(device) + + elif isinstance(dec_state, (list, tuple)): + dec_state = tuple(_states_to_device(dec_i, device) for dec_i in dec_state) + + return dec_state + + +class AEDGreedyInfer(ABC): + def __init__( + self, + transformer_decoder: torch.nn.Module, + log_softmax_module: torch.nn.Module, + tokenizer: TokenizerSpec, + search_type: str = 'default', + preserve_alignments: bool = False, + ): + super().__init__() + + self.transformer_decoder = transformer_decoder + self.log_softmax_module = log_softmax_module + self.tokenizer = tokenizer + self.search_type = search_type + + self.preserve_alignments = preserve_alignments + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + @abstractmethod + def forward( + self, + encoder_hidden_states: torch.Tensor, + encoder_input_mask: torch.Tensor, + decoder_input_ids: Optional[torch.Tensor] = None, + partial_hypotheses: Optional[List[Hypothesis]] = None, + ): + raise NotImplementedError() + + def set_decoding_type(self, decoding_type: str): + self.decoding_type = decoding_type + + +class TransformerAEDGreedyInfer(AEDGreedyInfer, Typing): + """ + A greedy decoder engine for AED Transformer models with support for temperature sampling. 
+ """ + + @property + def input_types(self): + """Returns definitions of module input ports.""" + # Input can be of dimension - + # ('B', 'T', 'D') [Log probs] or ('B', 'T') [Labels] + + return { + "encoder_hidden_states": NeuralType(tuple(('B', 'T', 'D')), ChannelType()), + "encoder_input_mask": NeuralType(tuple(('B', 'T')), MaskType()), + "decoder_input_ids": NeuralType(('B', 'T'), LabelsType()), + "partial_hypotheses": NeuralType(optional=True), + } + + @property + def output_types(self): + """Returns definitions of module output ports.""" + return {"predictions": [NeuralType(elements_type=HypothesisType())]} + + def __init__( + self, + transformer_decoder: torch.nn.Module, + log_softmax_module: torch.nn.Module, + tokenizer: TokenizerSpec, + temperature: float | None = None, + max_generation_delta: int = 50, + preserve_alignments: bool = False, + n_samples: int = 1, + ): + super().__init__( + transformer_decoder=transformer_decoder, + log_softmax_module=log_softmax_module, + tokenizer=tokenizer, + preserve_alignments=preserve_alignments, + ) + self.temperature = temperature + self.n_samples = n_samples + self.bos = tokenizer.bos + self.pad = tokenizer.pad + self.eos = tokenizer.eos + self.greedy_search = GreedySequenceGenerator( + embedding=transformer_decoder.embedding, + decoder=transformer_decoder.decoder, + classifier=log_softmax_module, + max_sequence_length=transformer_decoder.max_sequence_length, + bos=self.bos, + pad=self.pad, + eos=self.eos, + max_delta_length=max_generation_delta, + temperature=self.temperature, + n_samples=n_samples, + ) + + self.preserve_alignments = preserve_alignments + if self.preserve_alignments: + logging.info( + "Preservation of alignments was requested but {} does not implement it.".format( + self.__class__.__name__ + ) + ) + + @typecheck() + def forward( + self, + encoder_hidden_states: torch.Tensor, + encoder_input_mask: torch.Tensor, + decoder_input_ids: Optional[torch.Tensor] = None, + partial_hypotheses: Optional[List[Hypothesis]] = None, + ): + """Returns a list of hypotheses given an input batch of the encoder hidden embedding. + Output tokens are generated auto-regressively. + + Args: + encoder_hidden_states: A tensor of size (batch, timesteps, features) with the encoder outputs. + encoder_input_mask: A mask tensor of size (batch, timesteps) for the encoder outputs. + decoder_input_ids: An optional tensor of prompt token ids used to seed the decoder. + partial_hypotheses: An optional list of partial hypotheses. + + Returns: + packed list containing batch number of sentences (Hypotheses).
+ """ + with torch.inference_mode(): + best_hypo, topk_hypotheses = self.greedy_search( + encoder_hidden_states=encoder_hidden_states, + encoder_input_mask=encoder_input_mask, + decoder_input_ids=decoder_input_ids, + ) + + if topk_hypotheses is not None: + topk_hypotheses = [x.detach().cpu() for x in topk_hypotheses] # each item is [beam, seq_len] + beam_scores = [[None] * self.n_samples for _ in topk_hypotheses] # each item is [beam,] + packed_result = [] + for i in range(len(topk_hypotheses)): + # Pack results into Hypotheses + hypotheses = [Hypothesis(score=0.0, y_sequence=[], timestep=[]) for _ in range(self.n_samples)] + self.format_hypotheses(hypotheses, decoder_input_ids) + packed_result.append( + NBestHypotheses(pack_hypotheses(hypotheses, topk_hypotheses[i], beam_scores[i])) + ) + else: + beam_scores = [None for _ in range(len(best_hypo))] + best_hypo = best_hypo.cpu() + hypotheses = [ + Hypothesis(score=0.0, y_sequence=[], timestep=[]) for _ in range(encoder_hidden_states.shape[0]) + ] + # Pack results into Hypotheses + packed_result = pack_hypotheses(hypotheses, best_hypo, beam_scores) + self.format_hypotheses(packed_result, decoder_input_ids) + + return (packed_result,) + + def format_hypotheses(self, packed_result: List[Hypothesis], decoder_input_ids: torch.Tensor | None) -> None: + """ + For each hypothesis in the mini-batch: + * Remove the decoder input ids (prompt) from the predictions + * Remove BOS, EOS, and PAD ids from the predictions. + Modifies results in-place. + """ + if decoder_input_ids is not None: + assert ( + len(packed_result) == decoder_input_ids.shape[0] + ), f"Mismatching number of examples {len(packed_result)=} {decoder_input_ids.shape[0]=}" + decoder_input_ids = decoder_input_ids.detach().cpu() + for hyp, prefix in zip(packed_result, decoder_input_ids): + assert ( + hyp.y_sequence[: prefix.shape[0]] == prefix + ).all(), f"The decoder input IDs were not found at the beginning of prediction: {hyp.y_sequence=} {prefix=})" + hyp.y_sequence = hyp.y_sequence[prefix.shape[0] :] + for hyp in packed_result: + ids = hyp.y_sequence + ids_len = ids.shape[0] + pos = -1 + while ids[pos] == self.pad or ids[pos] == self.eos: + pos -= 1 + if ids_len + pos == -1: + break # empty sequence + if pos < -1: + hyp.y_sequence = ids[: pos + 1] + + +@dataclass +class AEDGreedyInferConfig: + temperature: float | None = None + max_generation_delta: int = -1 # -1 means up to the max length of the decoder + preserve_alignments: bool = False + n_samples: int = 1 diff --git a/nemo/collections/asr/parts/submodules/token_classifier.py b/nemo/collections/asr/parts/submodules/token_classifier.py index 4061d19d9015..cc435308fcae 100644 --- a/nemo/collections/asr/parts/submodules/token_classifier.py +++ b/nemo/collections/asr/parts/submodules/token_classifier.py @@ -11,16 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- +from contextlib import contextmanager from dataclasses import dataclass from typing import Dict, Optional +import torch from torch import nn as nn from nemo.collections.asr.parts.submodules.classifier import Classifier from nemo.collections.common.parts import MultiLayerPerceptron from nemo.core.classes import typecheck -from nemo.core.neural_types import LogitsType, LogprobsType, NeuralType +from nemo.core.neural_types import ChannelType, FloatType, LogitsType, LogprobsType, NeuralType __all__ = ['BertPretrainingTokenClassifier', 'TokenClassifier'] @@ -42,11 +43,17 @@ class TokenClassifier(Classifier): """ @property - def output_types(self) -> Optional[Dict[str, NeuralType]]: + def input_types(self) -> Dict[str, NeuralType]: + return { + "hidden_states": NeuralType(('B', 'T', 'D'), ChannelType()), + } + + @property + def output_types(self) -> Dict[str, NeuralType]: """ Returns definitions of module output ports. """ - if not self.log_softmax: + if not self.mlp.log_softmax: return {"logits": NeuralType(('B', 'T', 'C'), LogitsType())} else: return {"log_probs": NeuralType(('B', 'T', 'C'), LogprobsType())} @@ -61,7 +68,6 @@ def __init__( dropout: float = 0.0, use_transformer_init: bool = True, ) -> None: - """ Initializes the Token Classifier module. @@ -75,14 +81,24 @@ def __init__( use_transformer_init: whether to initialize the weights of the classifier head with the same approach used in Transformer """ super().__init__(hidden_size=hidden_size, dropout=dropout) - self.log_softmax = log_softmax self.mlp = MultiLayerPerceptron( hidden_size, num_classes, num_layers=num_layers, activation=activation, log_softmax=log_softmax ) self.post_init(use_transformer_init=use_transformer_init) + @property + def log_softmax(self) -> bool: + return self.mlp.log_softmax + + @contextmanager + def with_log_softmax_enabled(self, value: bool) -> "TokenClassifier": + prev = self.mlp.log_softmax + self.mlp.log_softmax = value + yield self + self.mlp.log_softmax = prev + @typecheck() - def forward(self, hidden_states): + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: """ Performs the forward step of the module. Args: @@ -100,12 +116,18 @@ class BertPretrainingTokenClassifier(Classifier): A module to perform token level classification tasks for Bert pretraining. """ + @property + def input_types(self) -> Dict[str, NeuralType]: + return { + "hidden_states": NeuralType(('B', 'T', 'D'), ChannelType()), + } + @property def output_types(self) -> Optional[Dict[str, NeuralType]]: """ Returns definitions of module output ports. """ - if not self.log_softmax: + if not self.mlp.log_softmax: return {"logits": NeuralType(('B', 'T', 'C'), LogitsType())} else: return {"log_probs": NeuralType(('B', 'T', 'C'), LogprobsType())} @@ -120,7 +142,6 @@ def __init__( dropout: float = 0.0, use_transformer_init: bool = True, ) -> None: - """ Initializes the Token Classifier module. 
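# Minimal usage sketch of the new with_log_softmax_enabled() context manager: it
# temporarily flips the underlying MLP's log_softmax flag so the same classification
# head can return raw logits (e.g. for temperature sampling) or log-probabilities,
# mirroring how GreedySequenceGenerator._one_step_forward uses it elsewhere in this
# patch. The hidden size and class count below are arbitrary illustration values.
import torch
from nemo.collections.asr.parts.submodules.token_classifier import TokenClassifier

clf = TokenClassifier(hidden_size=16, num_classes=8, log_softmax=True)
hidden = torch.randn(2, 5, 16)  # (B, T, D)

log_probs = clf(hidden_states=hidden)       # log-softmax applied (default behaviour)
with clf.with_log_softmax_enabled(False) as c:
    logits = c(hidden_states=hidden)        # raw logits while inside the context
assert clf.log_softmax is True              # original setting restored on exit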
@@ -135,8 +156,6 @@ def __init__( """ super().__init__(hidden_size=hidden_size, dropout=dropout) - self.log_softmax = log_softmax - if activation not in ACT2FN: raise ValueError(f'activation "{activation}" not found') self.dense = nn.Linear(hidden_size, hidden_size) @@ -147,8 +166,19 @@ def __init__( ) self.post_init(use_transformer_init=use_transformer_init) + @property + def log_softmax(self) -> bool: + return self.mlp.log_softmax + + @contextmanager + def with_log_softmax_enabled(self, value: bool) -> "TokenClassifier": + prev = self.mlp.log_softmax + self.mlp.log_softmax = value + yield self + self.mlp.log_softmax = prev + @typecheck() - def forward(self, hidden_states): + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: """ Performs the forward step of the module. Args: diff --git a/nemo/collections/asr/parts/utils/transcribe_utils.py b/nemo/collections/asr/parts/utils/transcribe_utils.py index c270e5c3a0f7..c26fa6f4984d 100644 --- a/nemo/collections/asr/parts/utils/transcribe_utils.py +++ b/nemo/collections/asr/parts/utils/transcribe_utils.py @@ -289,7 +289,7 @@ def prepare_audio_data(cfg: DictConfig) -> Tuple[List[str], bool]: with open(cfg.dataset_manifest, "rt") as fh: for line in fh: item = json.loads(line) - item["audio_filepath"] = get_full_path(item["audio_filepath"], cfg.dataset_manifest) + item[audio_key] = get_full_path(item[audio_key], cfg.dataset_manifest) if item.get("duration") is None and cfg.presort_manifest: raise ValueError( f"Requested presort_manifest=True, but line {line} in manifest {cfg.dataset_manifest} lacks a 'duration' field." diff --git a/nemo/collections/audio/losses/__init__.py b/nemo/collections/audio/losses/__init__.py index b2968b7b1ad0..f4a1a42ff20b 100644 --- a/nemo/collections/audio/losses/__init__.py +++ b/nemo/collections/audio/losses/__init__.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -from nemo.collections.audio.losses.audio import MSELoss, SDRLoss +from nemo.collections.audio.losses.audio import MAELoss, MSELoss, SDRLoss diff --git a/nemo/collections/audio/losses/audio.py b/nemo/collections/audio/losses/audio.py index 635b02c5d1fe..ce6b82875e6b 100644 --- a/nemo/collections/audio/losses/audio.py +++ b/nemo/collections/audio/losses/audio.py @@ -584,3 +584,168 @@ def forward( mse = self.reduce(mse) return mse + + +def calculate_mae_batch( + estimate: torch.Tensor, + target: torch.Tensor, + input_length: Optional[torch.Tensor] = None, + mask: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """Calculate mean absolute error (MAE) per channel. + + MAE = ||estimate - target||_1 / input_length + + Args: + estimate: estimated signal, shape (B, C, T) or (B, C, D, T) + target: target signal, shape (B, C, T) or (B, C, D, T) + input_length: Optional, length of valid samples, shape (B,) + mask: Optional, temporal mask, same shape as signals + + Returns: + MAE for each channel, shape (B, C) + """ + assert ( + estimate.shape == target.shape + ), f'Estimate shape ({estimate.shape}) not matching target shape ({target.shape})' + + if input_length is not None: + if mask is not None: + raise RuntimeError( + 'Argument `input_length` is mutually exclusive with `mask`. Both cannot be used at the same time.' 
+ ) + + # Construct a binary mask + mask = make_seq_mask_like(lengths=input_length, like=estimate, time_dim=-1, valid_ones=True) + mask = mask.expand_as(estimate) + + # error + err = estimate - target + + # dimensions for averaging + if estimate.ndim == 3: + # average across time + dim = -1 + elif estimate.ndim == 4: + # average across time and features + dim = (-2, -1) + else: + raise RuntimeError(f'Unexpected dimension of the input: {estimate.shape}') + + # calculate masked mean + mse = calculate_mean(torch.abs(err), mask=mask, dim=dim) + + return mse + + +class MAELoss(Loss, Typing): + """ + Computes the mean absolute error (MAE) loss with weighted average across channels. + + Args: + weight: weight for loss of each output channel, used for averaging the loss across channels. Defaults to `None` (averaging). + reduction: batch reduction. Defaults to `mean` over the batch. + ndim: Number of dimensions for the input signal + """ + + def __init__( + self, + weight: Optional[List[float]] = None, + reduction: str = 'mean', + ndim: int = 3, + ): + super().__init__() + + # weight buffer + if weight is not None: + if any([w <= 0 for w in weight]): + raise ValueError(f'Weight must be positive! Current value: {weight}') + elif not np.isclose(sum(weight), 1, atol=1e-6): + raise ValueError(f'Weight should add to one, current weight: {weight}') + weight = torch.tensor(weight).reshape(1, -1) + logging.info(f'Channel weight set to %s', weight) + self.register_buffer('weight', weight) + self.weight: Optional[Tensor] + + # Batch reduction + self.reduction = reduction + if reduction == 'mean': + self.reduce = torch.mean + else: + raise ValueError(f'Unexpected reduction mode {reduction}.') + + # Input dimension + self.ndim = ndim + + if self.ndim == 3: + # Time-domain input + self.signal_shape = ('B', 'C', 'T') + elif self.ndim == 4: + # Spectral-domain input + self.signal_shape = ('B', 'C', 'D', 'T') + else: + raise ValueError(f'Unexpected input dimension: {self.ndim}') + + logging.debug('Initialized %s with', self.__class__.__name__) + logging.debug('\tweight: %s', self.weight) + logging.debug('\treduction: %s', self.reduction) + logging.debug('\tndim: %s', self.ndim) + logging.debug('\tsignal_shape: %s', self.signal_shape) + + @property + def input_types(self): + """Input types definitions for MAELoss.""" + return { + "estimate": NeuralType(self.signal_shape, VoidType()), + "target": NeuralType(self.signal_shape, VoidType()), + "input_length": NeuralType(tuple('B'), LengthsType(), optional=True), + "mask": NeuralType(self.signal_shape, MaskType(), optional=True), + } + + @property + def output_types(self): + """Output types definitions for MAELoss. + loss: + NeuralType(None) + """ + return {"loss": NeuralType(elements_type=LossType())} + + @typecheck() + def forward( + self, + estimate: torch.Tensor, + target: torch.Tensor, + input_length: Optional[torch.Tensor] = None, + mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """For input batch of multi-channel signals, calculate MAE between estimate and target for each channel, + perform averaging across channels (weighting optional), and apply reduction across the batch. + + Args: + estimate: Estimate of the target signal + target: Target signal + input_length: Length of each example in the batch + mask: Mask for each signal + + Returns: + Scalar loss. 
+ """ + mae = calculate_mae_batch( + estimate=estimate, + target=target, + input_length=input_length, + mask=mask, + ) + + # channel averaging + if self.weight is None: + mae = torch.mean(mae, dim=1) + else: + # weighting across channels + mae = mae * self.weight + mae = torch.sum(mae, dim=1) + + # reduction + mae = self.reduce(mae) + + return mae diff --git a/nemo/collections/audio/metrics/__init__.py b/nemo/collections/audio/metrics/__init__.py index d9155f923f18..20c8fd2fa4e2 100644 --- a/nemo/collections/audio/metrics/__init__.py +++ b/nemo/collections/audio/metrics/__init__.py @@ -11,3 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from nemo.collections.audio.metrics.audio import AudioMetricWrapper +from nemo.collections.audio.metrics.squim import SquimMOSMetric, SquimObjectiveMetric diff --git a/nemo/collections/audio/metrics/audio.py b/nemo/collections/audio/metrics/audio.py index 096700eff24a..0f8b5bee0fd2 100644 --- a/nemo/collections/audio/metrics/audio.py +++ b/nemo/collections/audio/metrics/audio.py @@ -21,6 +21,7 @@ from torchmetrics.audio.sdr import ScaleInvariantSignalDistortionRatio, SignalDistortionRatio from torchmetrics.audio.snr import ScaleInvariantSignalNoiseRatio, SignalNoiseRatio from torchmetrics.audio.stoi import ShortTimeObjectiveIntelligibility +from nemo.collections.audio.metrics.squim import SquimMOSMetric, SquimObjectiveMetric from nemo.utils import logging @@ -34,6 +35,8 @@ SignalNoiseRatio, PerceptualEvaluationSpeechQuality, ShortTimeObjectiveIntelligibility, + SquimMOSMetric, + SquimObjectiveMetric, ] diff --git a/nemo/collections/audio/metrics/squim.py b/nemo/collections/audio/metrics/squim.py new file mode 100644 index 000000000000..c20be43f79f8 --- /dev/null +++ b/nemo/collections/audio/metrics/squim.py @@ -0,0 +1,197 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +import torch +from torchmetrics import Metric +from nemo.utils import logging + +try: + import torchaudio + + HAVE_TORCHAUDIO = True +except ModuleNotFoundError: + HAVE_TORCHAUDIO = False + + +class SquimMOSMetric(Metric): + """A metric calculating the average Torchaudio Squim MOS. 
+ + Args: + fs: sampling rate of the input signals + """ + + sample_rate: int = 16000 # sample rate of the model + mos_sum: torch.Tensor + num_examples: torch.Tensor + higher_is_better: bool = True + + def __init__(self, fs: int, **kwargs: Any): + super().__init__(**kwargs) + + if not HAVE_TORCHAUDIO: + raise ModuleNotFoundError(f"{self.__class__.__name__} metric needs `torchaudio`.") + + if fs != self.sample_rate: + # Resampler: kaiser_best + self._squim_mos_metric_resampler = torchaudio.transforms.Resample( + orig_freq=fs, + new_freq=self.sample_rate, + lowpass_filter_width=64, + rolloff=0.9475937167399596, + resampling_method='sinc_interp_kaiser', + beta=14.769656459379492, + ) + logging.warning('Input signals will be resampled from fs=%d to %d Hz', fs, self.sample_rate) + self.fs = fs + + # MOS model + self._squim_mos_metric_model = torchaudio.pipelines.SQUIM_SUBJECTIVE.get_model() + + self.add_state('mos_sum', default=torch.tensor(0.0), dist_reduce_fx='sum') + self.add_state('num_examples', default=torch.tensor(0), dist_reduce_fx='sum') + logging.debug('Setup metric %s with input fs=%s', self.__class__.__name__, self.fs) + + def update(self, preds: torch.Tensor, target: torch.Tensor) -> None: + """Update the metric by calculating the MOS score for the current batch. + + Args: + preds: tensor with predictions, shape (B, T) + target: tensor with target signals, shape (B, T). Target can be a non-matching reference. + """ + if self.fs != self.sample_rate: + preds = self._squim_mos_metric_resampler(preds) + target = self._squim_mos_metric_resampler(target) + + if preds.ndim == 1: + # Unsqueeze batch dimension + preds = preds.unsqueeze(0) + target = target.unsqueeze(0) + elif preds.ndim > 2: + raise ValueError(f'Expected 1D or 2D signals, got {preds.ndim}D signals') + + mos_batch = self._squim_mos_metric_model(preds, target) + + self.mos_sum += mos_batch.sum() + self.num_examples += mos_batch.numel() + + def compute(self) -> torch.Tensor: + """Compute the underlying metric.""" + return self.mos_sum / self.num_examples + + def state_dict(self, *args, **kwargs): + """Do not save the MOS model and resampler in the state dict.""" + state_dict = super().state_dict(*args, **kwargs) + # Do not include resampler or mos_model in the state dict + remove_keys = [ + key + for key in state_dict.keys() + if '_squim_mos_metric_resampler' in key or '_squim_mos_metric_model' in key + ] + for key in remove_keys: + del state_dict[key] + return state_dict + + +class SquimObjectiveMetric(Metric): + """A metric calculating the average Torchaudio Squim objective metric. + + Args: + fs: sampling rate of the input signals + metric: the objective metric to calculate. 
One of 'stoi', 'pesq', 'si_sdr' + """ + + sample_rate: int = 16000 # sample rate of the model + metric_sum: torch.Tensor + num_examples: torch.Tensor + higher_is_better: bool = True + + def __init__(self, fs: int, metric: str, **kwargs: Any): + super().__init__(**kwargs) + + if not HAVE_TORCHAUDIO: + raise ModuleNotFoundError(f"{self.__class__.__name__} needs `torchaudio`.") + + if fs != self.sample_rate: + # Resampler: kaiser_best + self._squim_objective_metric_resampler = torchaudio.transforms.Resample( + orig_freq=fs, + new_freq=self.sample_rate, + lowpass_filter_width=64, + rolloff=0.9475937167399596, + resampling_method='sinc_interp_kaiser', + beta=14.769656459379492, + ) + logging.warning('Input signals will be resampled from fs=%d to %d Hz', fs, self.sample_rate) + self.fs = fs + + if metric not in ['stoi', 'pesq', 'si_sdr']: + raise ValueError(f'Unsupported metric {metric}. Supported metrics are "stoi", "pesq", "si_sdr".') + + self.metric = metric + + # Objective model + self._squim_objective_metric_model = torchaudio.pipelines.SQUIM_OBJECTIVE.get_model() + + self.add_state('metric_sum', default=torch.tensor(0.0), dist_reduce_fx='sum') + self.add_state('num_examples', default=torch.tensor(0), dist_reduce_fx='sum') + logging.debug('Setup %s with metric=%s, input fs=%s', self.__class__.__name__, self.metric, self.fs) + + def update(self, preds: torch.Tensor, target: Any = None) -> None: + """Update the metric by calculating the selected metric score for the current batch. + + Args: + preds: tensor with predictions, shape (B, T) + target: None, not used. Keeping for interfacfe compatibility with other metrics. + """ + if self.fs != self.sample_rate: + preds = self._squim_objective_metric_resampler(preds) + + if preds.ndim == 1: + # Unsqueeze batch dimension + preds = preds.unsqueeze(0) + elif preds.ndim > 2: + raise ValueError(f'Expected 1D or 2D signals, got {preds.ndim}D signals') + + stoi_batch, pesq_batch, si_sdr_batch = self._squim_objective_metric_model(preds) + + if self.metric == 'stoi': + metric_batch = stoi_batch + elif self.metric == 'pesq': + metric_batch = pesq_batch + elif self.metric == 'si_sdr': + metric_batch = si_sdr_batch + else: + raise ValueError(f'Unknown metric {self.metric}') + + self.metric_sum += metric_batch.sum() + self.num_examples += metric_batch.numel() + + def compute(self) -> torch.Tensor: + """Compute the underlying metric.""" + return self.metric_sum / self.num_examples + + def state_dict(self, *args, **kwargs): + """Do not save the MOS model and resampler in the state dict.""" + state_dict = super().state_dict(*args, **kwargs) + # Do not include resampler or mos_model in the state dict + remove_keys = [ + key + for key in state_dict.keys() + if '_squim_objective_metric_resampler' in key or '_squim_objective_metric_model' in key + ] + for key in remove_keys: + del state_dict[key] + return state_dict diff --git a/nemo/collections/audio/models/audio_to_audio.py b/nemo/collections/audio/models/audio_to_audio.py index b12f9ce73cbe..ef9ce648f1a2 100644 --- a/nemo/collections/audio/models/audio_to_audio.py +++ b/nemo/collections/audio/models/audio_to_audio.py @@ -46,7 +46,11 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): def _setup_loss(self): """Setup loss for this model.""" - self.loss = AudioToAudioModel.from_config_dict(self._cfg.loss) + if 'loss' in self._cfg: + self.loss = AudioToAudioModel.from_config_dict(self._cfg.loss) + else: + logging.warning('No loss function is defined in the config.') + self.loss = None def 
_get_num_dataloaders(self, tag: str = 'val'): if tag == 'val': diff --git a/nemo/collections/audio/models/enhancement.py b/nemo/collections/audio/models/enhancement.py index f60553704183..e7fbc9023117 100644 --- a/nemo/collections/audio/models/enhancement.py +++ b/nemo/collections/audio/models/enhancement.py @@ -25,7 +25,12 @@ from nemo.core.neural_types import AudioSignal, LengthsType, LossType, NeuralType from nemo.utils import logging -__all__ = ['EncMaskDecAudioToAudioModel', 'ScoreBasedGenerativeAudioToAudioModel', 'PredictiveAudioToAudioModel'] +__all__ = [ + 'EncMaskDecAudioToAudioModel', + 'ScoreBasedGenerativeAudioToAudioModel', + 'PredictiveAudioToAudioModel', + 'SchroedingerBridgeAudioToAudioModel', +] class EncMaskDecAudioToAudioModel(AudioToAudioModel): @@ -433,14 +438,13 @@ def forward(self, input_signal, input_length=None): - decoder to transform the sampler output into the time domain Args: - input_signal: Tensor that represents a batch of raw audio signals, - of shape [B, T] or [B, T, C]. T here represents timesteps, with 1 second of audio represented as + input_signal: Tensor that represents a batch of time-domain audio signals, + of shape [B, C, T]. T here represents timesteps, with 1 second of audio represented as `self.sample_rate` number of floating point values. - input_signal_length: Vector of length B, that contains the individual lengths of the audio - sequences. + input_signal_length: Vector of length B, contains the individual lengths of the audio sequences. Returns: - Output signal `output` in the time domain and the length of the output signal `output_length`. + Output `output_signal` in the time domain and the length of the output signal `output_length`. """ batch_length = input_signal.size(-1) @@ -612,3 +616,353 @@ def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str = self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32)) return {f'{tag}_loss': loss} + + +class SchroedingerBridgeAudioToAudioModel(AudioToAudioModel): + """This models is using a Schrödinger Bridge process to generate + an encoded representation of the enhanced signal. 
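As an illustrative aside before the component list below: a hypothetical end-to-end usage of this model. The checkpoint name and signal shapes are assumptions, not part of the patch; restore_from is the standard NeMo model-loading API.

    import torch

    # Hypothetical checkpoint path, shown only for illustration
    model = SchroedingerBridgeAudioToAudioModel.restore_from('sb_enhancement.nemo')
    model.eval()

    noisy = torch.randn(1, 1, 16000)  # (B, C, T) noisy input signal
    lengths = torch.tensor([16000])

    enhanced, enhanced_len = model(input_signal=noisy, input_length=lengths)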
+ + The model consists of the following blocks: + - encoder: transforms input audio signal into an encoded representation (analysis transform) + - estimator: neural model, estimates the coefficients for the SB process + - noise_schedule: defines the path between the clean and noisy signals + - sampler: sampler for the reverse process, estimates coefficients of the target signal + - decoder: transforms sampler output into the time domain (synthesis transform) + + References: + Schrödinger Bridge for Generative Speech Enhancement, https://arxiv.org/abs/2407.16074 + """ + + def __init__(self, cfg: DictConfig, trainer: Trainer = None): + super().__init__(cfg=cfg, trainer=trainer) + self.sample_rate = self._cfg.sample_rate + + # Setup processing modules + self.encoder = self.from_config_dict(self._cfg.encoder) + self.decoder = self.from_config_dict(self._cfg.decoder) + + # Neural estimator + self.estimator = self.from_config_dict(self._cfg.estimator) + self.estimator_output = self._cfg.estimator_output + + # Noise schedule + self.noise_schedule = self.from_config_dict(self._cfg.noise_schedule) + + # Sampler + self.sampler = hydra.utils.instantiate( + self._cfg.sampler, + noise_schedule=self.noise_schedule, + estimator=self.estimator, + estimator_output=self.estimator_output, + ) + + # Normalization + self.normalize_input = self._cfg.get('normalize_input', False) + + # Metric evaluation + self.max_utts_evaluation_metrics = self._cfg.get('max_utts_evaluation_metrics') + + if self.max_utts_evaluation_metrics is not None: + logging.warning( + 'Metrics will be evaluated on first %d examples of the evaluation datasets.', + self.max_utts_evaluation_metrics, + ) + + # Loss in the encoded domain + if 'loss_encoded' in self._cfg: + self.loss_encoded = self.from_config_dict(self._cfg.loss_encoded) + self.loss_encoded_weight = self._cfg.get('loss_encoded_weight', 1.0) + else: + self.loss_encoded = None + self.loss_encoded_weight = 0.0 + + # Loss in the time domain + if 'loss_time' in self._cfg: + self.loss_time = self.from_config_dict(self._cfg.loss_time) + self.loss_time_weight = self._cfg.get('loss_time_weight', 1.0) + else: + self.loss_time = None + self.loss_time_weight = 0.0 + + if self.loss is not None and (self.loss_encoded is not None or self.loss_time is not None): + raise ValueError('Either ``loss`` or ``loss_encoded`` and ``loss_time`` should be defined, not both.') + + # Term added to the denominator to improve numerical stability + self.eps = self._cfg.get('eps', 1e-8) + + # Setup optional optimization flags + self.setup_optimization_flags() + + logging.debug('Initialized %s', self.__class__.__name__) + logging.debug('\testimator_output: %s', self.estimator_output) + logging.debug('\tnormalize_input: %s', self.normalize_input) + logging.debug('\tloss: %s', self.loss) + logging.debug('\tloss_encoded: %s', self.loss_encoded) + logging.debug('\tloss_encoded_weight: %s', self.loss_encoded_weight) + logging.debug('\tloss_time: %s', self.loss_time) + logging.debug('\tloss_time_weight: %s', self.loss_time_weight) + logging.debug('\teps: %s', self.eps) + + @property + def input_types(self) -> Dict[str, NeuralType]: + # time-domain input + return { + "input_signal": NeuralType(('B', 'C', 'T'), AudioSignal(freq=self.sample_rate)), + "input_length": NeuralType(tuple('B'), LengthsType(), optional=True), + } + + @property + def output_types(self) -> Dict[str, NeuralType]: + # time-domain output + return { + "output_signal": NeuralType(('B', 'C', 'T'), AudioSignal(freq=self.sample_rate)), + "output_length": 
NeuralType(tuple('B'), LengthsType(), optional=True), + } + + @typecheck() + @torch.inference_mode() + def forward(self, input_signal, input_length=None): + """Forward pass of the model. + + Forward pass of the model consists of the following steps + - encoder to obtain the encoded representation of the input signal + - sampler to generate the estimated coefficients of the target signal + - decoder to transform the estimated output into the time domain + + Args: + input_signal: Tensor that represents a batch of time-domain audio signals, + of shape [B, C, T]. T here represents timesteps, with 1 second of audio represented as + `self.sample_rate` number of floating point values. + input_signal_length: Vector of length B, contains the individual lengths of the audio sequences. + + Returns: + Output `output_signal` in the time domain and the length of the output signal `output_length`. + """ + batch_length = input_signal.size(-1) + + if self.normalize_input: + # max for each example in the batch + norm_scale = torch.amax(input_signal.abs(), dim=(-1, -2), keepdim=True) + # scale input signal + input_signal = input_signal / (norm_scale + self.eps) + + # Encoder + encoded, encoded_length = self.encoder(input=input_signal, input_length=input_length) + + # Sampler + generated, generated_length = self.sampler( + prior_mean=encoded, estimator_condition=encoded, state_length=encoded_length + ) + + # Decoder + output, output_length = self.decoder(input=generated, input_length=generated_length) + + if self.normalize_input: + # rescale to the original scale + output = output * norm_scale + + # Trim or pad the estimated signal to match input length + output = self.match_batch_length(input=output, batch_length=batch_length) + + return output, output_length + + @typecheck( + input_types={ + "target_signal": NeuralType(('B', 'C', 'T'), AudioSignal()), + "input_signal": NeuralType(('B', 'C', 'T'), AudioSignal()), + "input_length": NeuralType(tuple('B'), LengthsType()), + }, + output_types={ + "loss": NeuralType(None, LossType()), + "loss_encoded": NeuralType(None, LossType()), + "loss_time": NeuralType(None, LossType()), + }, + ) + def _step(self, target_signal, input_signal, input_length=None): + """Randomly generate time step for each example in the batch, run neural estimator + to estimate the target and calculate the loss. 
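In equation form, an editorial restatement of the marginal sampling implemented in the body below (ignoring the eps regularization); sigma_T denotes sigma at t_max, x_0 the target encoding, x_T the input encoding, and z a standard normal sample:

    mean_x = (alpha_t * sigma_bar_t**2 / sigma_T**2) * x_0 + (alpha_bar_t * sigma_t**2 / sigma_T**2) * x_T
    std_x  = alpha_t * sigma_t * sigma_bar_t / sigma_T
    x_t    = mean_x + std_x * z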
+ """ + batch_size = target_signal.size(0) + + if self.normalize_input: + # max for each example in the batch + norm_scale = torch.amax(input_signal.abs(), dim=(-1, -2), keepdim=True) + # scale input signal + input_signal = input_signal / (norm_scale + self.eps) + # scale the target signal + target_signal = target_signal / (norm_scale + self.eps) + + # Apply encoder to both target and the input + # For example, if the encoder is STFT, then _enc is the complex-valued STFT of the corresponding signal + input_enc, input_enc_len = self.encoder(input=input_signal, input_length=input_length) + target_enc, _ = self.encoder(input=target_signal, input_length=input_length) + + # Generate random time steps + process_time = self.noise_schedule.generate_time(size=batch_size, device=input_enc.device) + + # Prepare necessary info from the noise schedule + alpha_t, alpha_bar_t, alpha_t_max = self.noise_schedule.get_alphas(time=process_time) + sigma_t, sigma_bar_t, sigma_t_max = self.noise_schedule.get_sigmas(time=process_time) + + # Marginal distribution + weight_target = alpha_t * sigma_bar_t**2 / (sigma_t_max**2 + self.eps) + weight_input = alpha_bar_t * sigma_t**2 / (sigma_t_max**2 + self.eps) + # view weights as [B, C, D, T] + weight_target = weight_target.view(-1, 1, 1, 1) + weight_input = weight_input.view(-1, 1, 1, 1) + # mean + mean_x = weight_target * target_enc + weight_input * input_enc + # standard deviation + std_x = alpha_t * sigma_bar_t * sigma_t / (sigma_t_max + self.eps) + # view as [B, C, D, T] + std_x = std_x.view(-1, 1, 1, 1) + + # Generate a random sample from a standard normal distribution + z_norm = torch.randn_like(input_enc) + + # Generate a random sample from the marginal distribution + x_t = mean_x + std_x * z_norm + + # Estimator is conditioned on the generated sample and the original input (prior) + estimator_input = torch.cat([x_t, input_enc], dim=-3) + + # Neural estimator + # Estimator input is the same data type as the encoder output + # For example, if the encoder is STFT, then the estimator input and output are complex-valued coefficients + estimate, estimate_len = self.estimator( + input=estimator_input, input_length=input_enc_len, condition=process_time + ) + + # Prepare output target and calculate loss + if self.estimator_output == 'data_prediction': + if self.loss is not None: + # Single loss in the encoded domain + loss = self.loss(estimate=estimate, target=target_enc, input_length=estimate_len) + loss_encoded = loss_time = None + else: + # Weighted loss between encoded and time domain + loss = 0.0 + + # Loss in the encoded domain + if self.loss_encoded is not None: + # Loss between the estimate and the target in the encoded domain + loss_encoded = self.loss_encoded(estimate=estimate, target=target_enc, input_length=estimate_len) + # Weighting + loss += self.loss_encoded_weight * loss_encoded + else: + loss_encoded = None + + # Loss in the time domain + if self.loss_time is not None: + # Convert the estimate to the time domain + with typecheck.disable_checks(): + # Note: stimate is FloatType, decoder requires SpectrogramType + estimate_signal, _ = self.decoder(input=estimate, input_length=estimate_len) + + # Match estimate length + batch_length = input_signal.size(-1) + estimate_signal = self.match_batch_length(input=estimate_signal, batch_length=batch_length) + + # Loss between the estimate and the target in the time domain + loss_time = self.loss_time( + estimate=estimate_signal, target=target_signal, input_length=input_length + ) + # Weighting + loss += 
self.loss_time_weight * loss_time + else: + loss_time = None + else: + raise NotImplementedError(f'Output type {self.estimator_output} is not implemented') + + return loss, loss_encoded, loss_time + + # PTL-specific methods + def training_step(self, batch, batch_idx): + + if isinstance(batch, dict): + # lhotse batches are dictionaries + input_signal = batch['input_signal'] + input_length = batch['input_length'] + target_signal = batch['target_signal'] + else: + input_signal, input_length, target_signal, _ = batch + + # For consistency, the model uses multi-channel format, even if the channel dimension is 1 + if input_signal.ndim == 2: + input_signal = einops.rearrange(input_signal, 'B T -> B 1 T') + if target_signal.ndim == 2: + target_signal = einops.rearrange(target_signal, 'B T -> B 1 T') + + # Calculate the loss + loss, loss_encoded, loss_time = self._step( + target_signal=target_signal, input_signal=input_signal, input_length=input_length + ) + + # Logs + self.log('train_loss', loss) + self.log('learning_rate', self._optimizer.param_groups[0]['lr']) + self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32)) + + if loss_encoded is not None: + self.log('train_loss_encoded', loss_encoded) + + if loss_time is not None: + self.log('train_loss_time', loss_time) + + return loss + + def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str = 'val'): + + if isinstance(batch, dict): + # lhotse batches are dictionaries + input_signal = batch['input_signal'] + input_length = batch['input_length'] + target_signal = batch['target_signal'] + else: + input_signal, input_length, target_signal, _ = batch + + # For consistency, the model uses multi-channel format, even if the channel dimension is 1 + if input_signal.ndim == 2: + input_signal = einops.rearrange(input_signal, 'B T -> B 1 T') + if target_signal.ndim == 2: + target_signal = einops.rearrange(target_signal, 'B T -> B 1 T') + + # Calculate loss + loss, *_ = self._step(target_signal=target_signal, input_signal=input_signal, input_length=input_length) + + # Update metrics + update_metrics = False + if self.max_utts_evaluation_metrics is None: + # Always update if max is not configured + update_metrics = True + # Number of examples to process + num_examples = input_signal.size(0) # batch size + else: + # Check how many examples have been used for metric calculation + first_metric_name = next(iter(self.metrics[tag][dataloader_idx])) + num_examples_evaluated = self.metrics[tag][dataloader_idx][first_metric_name].num_examples + # Update metrics if some examples were not processed + update_metrics = num_examples_evaluated < self.max_utts_evaluation_metrics + # Number of examples to process + num_examples = min(self.max_utts_evaluation_metrics - num_examples_evaluated, input_signal.size(0)) + + if update_metrics: + # Generate output signal + output_signal, _ = self.forward( + input_signal=input_signal[:num_examples, ...], input_length=input_length[:num_examples] + ) + + # Update metrics + if hasattr(self, 'metrics') and tag in self.metrics: + # Update metrics for this (tag, dataloader_idx) + for name, metric in self.metrics[tag][dataloader_idx].items(): + metric.update( + preds=output_signal, + target=target_signal[:num_examples, ...], + input_length=input_length[:num_examples], + ) + + # Log global step + self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32)) + + return {f'{tag}_loss': loss} diff --git a/nemo/collections/audio/modules/transforms.py 
b/nemo/collections/audio/modules/transforms.py index ecbdca88e22b..6839ae0f7598 100644 --- a/nemo/collections/audio/modules/transforms.py +++ b/nemo/collections/audio/modules/transforms.py @@ -14,6 +14,7 @@ from typing import Dict, Optional, Tuple import torch +from einops import rearrange from nemo.collections.asr.parts.preprocessing.features import make_seq_mask_like from nemo.core.classes import NeuralModule, typecheck @@ -43,6 +44,157 @@ class AudioToSpectrogram(NeuralModule): scale: Positive scaling of the spectrogram. """ + def __init__(self, fft_length: int, hop_length: int, magnitude_power: float = 1.0, scale: float = 1.0): + super().__init__() + + # For now, assume FFT length is divisible by two + if fft_length % 2 != 0: + raise ValueError(f'fft_length = {fft_length} must be divisible by 2') + + self.fft_length = fft_length + self.hop_length = hop_length + self.pad_mode = 'constant' + window = torch.hann_window(self.win_length) + self.register_buffer('window', window) + + self.num_subbands = fft_length // 2 + 1 + + if magnitude_power <= 0: + raise ValueError(f'Magnitude power needs to be positive: current value {magnitude_power}') + self.magnitude_power = magnitude_power + + if scale <= 0: + raise ValueError(f'Scale needs to be positive: current value {scale}') + self.scale = scale + + logging.debug('Initialized %s with:', self.__class__.__name__) + logging.debug('\tfft_length: %s', fft_length) + logging.debug('\thop_length: %s', hop_length) + logging.debug('\tmagnitude_power: %s', magnitude_power) + logging.debug('\tscale: %s', scale) + + @property + def win_length(self) -> int: + return self.fft_length + + def stft(self, x: torch.Tensor): + """Apply STFT as in torchaudio.transforms.Spectrogram(power=None) + + Args: + x_spec: Input time-domain signal, shape (..., T) + + Returns: + Time-domain signal ``x_spec = STFT(x)``, shape (..., F, N). + """ + # pack batch + B, C, T = x.size() + x = rearrange(x, 'B C T -> (B C) T') + + x_spec = torch.stft( + input=x, + n_fft=self.fft_length, + hop_length=self.hop_length, + win_length=self.win_length, + window=self.window, + center=True, + pad_mode=self.pad_mode, + normalized=False, + onesided=True, + return_complex=True, + ) + + # unpack batch + x_spec = rearrange(x_spec, '(B C) F N -> B C F N', B=B, C=C) + + return x_spec + + @property + def input_types(self) -> Dict[str, NeuralType]: + """Returns definitions of module output ports.""" + return { + "input": NeuralType(('B', 'C', 'T'), AudioSignal()), + "input_length": NeuralType(('B',), LengthsType(), optional=True), + } + + @property + def output_types(self) -> Dict[str, NeuralType]: + """Returns definitions of module output ports.""" + return { + "output": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()), + "output_length": NeuralType(('B',), LengthsType()), + } + + @typecheck() + def forward( + self, input: torch.Tensor, input_length: Optional[torch.Tensor] = None + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Convert a batch of C-channel input signals + into a batch of complex-valued spectrograms. + + Args: + input: Time-domain input signal with C channels, shape (B, C, T) + input_length: Length of valid entries along the time dimension, shape (B,) + + Returns: + Output spectrogram with F subbands and N time frames, shape (B, C, F, N) + and output length with shape (B,). 
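A brief sketch of the expected shapes and lengths; the STFT parameters and signal sizes are assumed values, not defaults from the patch:

    import torch

    audio2spec = AudioToSpectrogram(fft_length=512, hop_length=128)

    x = torch.randn(2, 1, 16000)         # (B, C, T)
    x_len = torch.tensor([16000, 8000])  # valid samples per example

    spec, spec_len = audio2spec(input=x, input_length=x_len)
    # spec: (B, C, F, N) with F = 512 // 2 + 1 = 257 subbands
    # spec_len: valid frames per example, T // hop_length + 1 = [126, 63]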
+ """ + B, T = input.size(0), input.size(-1) + input = input.view(B, -1, T) + + # STFT output (B, C, F, N) + with torch.cuda.amp.autocast(enabled=False): + output = self.stft(input.float()) + + if self.magnitude_power != 1: + # apply power on the magnitude + output = torch.pow(output.abs(), self.magnitude_power) * torch.exp(1j * output.angle()) + + if self.scale != 1: + # apply scaling of the coefficients + output = self.scale * output + + if input_length is not None: + # Mask padded frames + output_length = self.get_output_length(input_length=input_length) + + length_mask: torch.Tensor = make_seq_mask_like( + lengths=output_length, like=output, time_dim=-1, valid_ones=False + ) + output = output.masked_fill(length_mask, 0.0) + else: + # Assume all frames are valid for all examples in the batch + output_length = output.size(-1) * torch.ones(B, device=output.device).long() + + return output, output_length + + def get_output_length(self, input_length: torch.Tensor) -> torch.Tensor: + """Get length of valid frames for the output. + + Args: + input_length: number of valid samples, shape (B,) + + Returns: + Number of valid frames, shape (B,) + """ + # centered STFT results in (T // hop_length + 1) frames for T samples (cf. torch.stft) + output_length = input_length.div(self.hop_length, rounding_mode='floor').add(1).long() + return output_length + + +class AudioToSpectrogramTA(NeuralModule): + """Transform a batch of input multi-channel signals into a batch of + STFT-based spectrograms. Using torchaudio. + + Args: + fft_length: length of FFT + hop_length: length of hops/shifts of the sliding window + power: exponent for magnitude spectrogram. Default `None` will + return a complex-valued spectrogram + magnitude_power: Transform magnitude of the spectrogram as x^magnitude_power. + scale: Positive scaling of the spectrogram. + """ + def __init__(self, fft_length: int, hop_length: int, magnitude_power: float = 1.0, scale: float = 1.0): if not HAVE_TORCHAUDIO: logging.error('Could not import torchaudio. Some features might not work.') @@ -62,7 +214,7 @@ def __init__(self, fft_length: int, hop_length: int, magnitude_power: float = 1. ) # number of subbands - self.F = fft_length // 2 + 1 + self.num_subbands = fft_length // 2 + 1 if magnitude_power <= 0: raise ValueError(f'Magnitude power needs to be positive: current value {magnitude_power}') @@ -78,10 +230,6 @@ def __init__(self, fft_length: int, hop_length: int, magnitude_power: float = 1. logging.debug('\tmagnitude_power: %s', magnitude_power) logging.debug('\tscale: %s', scale) - @property - def num_subbands(self) -> int: - return self.F - @property def input_types(self) -> Dict[str, NeuralType]: """Returns definitions of module output ports.""" @@ -166,6 +314,157 @@ class SpectrogramToAudio(NeuralModule): scale: Spectrogram will be scaled with 1/scale before the inverse transform. 
""" + def __init__(self, fft_length: int, hop_length: int, magnitude_power: float = 1.0, scale: float = 1.0): + super().__init__() + + # For now, assume FFT length is divisible by two + if fft_length % 2 != 0: + raise ValueError(f'fft_length = {fft_length} must be divisible by 2') + + self.fft_length = fft_length + self.hop_length = hop_length + window = torch.hann_window(self.win_length) + self.register_buffer('window', window) + + self.num_subbands = fft_length // 2 + 1 + + if magnitude_power <= 0: + raise ValueError(f'Magnitude power needs to be positive: current value {magnitude_power}') + self.magnitude_power = magnitude_power + + if scale <= 0: + raise ValueError(f'Scale needs to be positive: current value {scale}') + self.scale = scale + + logging.debug('Initialized %s with:', self.__class__.__name__) + logging.debug('\tfft_length: %s', fft_length) + logging.debug('\thop_length: %s', hop_length) + logging.debug('\tmagnitude_power: %s', magnitude_power) + logging.debug('\tscale: %s', scale) + + @property + def win_length(self) -> int: + return self.fft_length + + def istft(self, x_spec: torch.Tensor): + """Apply iSTFT as in torchaudio.transforms.InverseSpectrogram + + Args: + x_spec: Input complex-valued spectrogram, shape (..., F, N) + + Returns: + Time-domain signal ``x = iSTFT(x_spec)``, shape (..., T). + """ + if not x_spec.is_complex(): + raise ValueError("Expected `x_spec` to be complex dtype.") + + # pack batch + B, C, F, N = x_spec.size() + x_spec = rearrange(x_spec, 'B C F N -> (B C) F N') + + x = torch.istft( + input=x_spec, + n_fft=self.fft_length, + hop_length=self.hop_length, + win_length=self.win_length, + window=self.window, + center=True, + normalized=False, + onesided=True, + length=None, + return_complex=False, + ) + + # unpack batch + x = rearrange(x, '(B C) T -> B C T', B=B, C=C) + + return x + + @property + def input_types(self) -> Dict[str, NeuralType]: + """Returns definitions of module output ports.""" + return { + "input": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()), + "input_length": NeuralType(('B',), LengthsType(), optional=True), + } + + @property + def output_types(self) -> Dict[str, NeuralType]: + """Returns definitions of module output ports.""" + return { + "output": NeuralType(('B', 'C', 'T'), AudioSignal()), + "output_length": NeuralType(('B',), LengthsType()), + } + + @typecheck() + def forward(self, input: torch.Tensor, input_length: Optional[torch.Tensor] = None) -> torch.Tensor: + """Convert input complex-valued spectrogram to a time-domain + signal. Multi-channel IO is supported. + + Args: + input: Input spectrogram for C channels, shape (B, C, F, N) + input_length: Length of valid entries along the time dimension, shape (B,) + + Returns: + Time-domain signal with T time-domain samples and C channels, (B, C, T) + and output length with shape (B,). 
+ """ + B, F, N = input.size(0), input.size(-2), input.size(-1) + assert F == self.num_subbands, f'Number of subbands F={F} not matching self.num_subbands={self.num_subbands}' + input = input.view(B, -1, F, N) + + # iSTFT output (B, C, T) + with torch.cuda.amp.autocast(enabled=False): + output = input.cfloat() + + if self.scale != 1: + # apply 1/scale on the coefficients + output = output / self.scale + + if self.magnitude_power != 1: + # apply 1/power on the magnitude + output = torch.pow(output.abs(), 1 / self.magnitude_power) * torch.exp(1j * output.angle()) + output = self.istft(output) + + if input_length is not None: + # Mask padded samples + output_length = self.get_output_length(input_length=input_length) + + length_mask: torch.Tensor = make_seq_mask_like( + lengths=output_length, like=output, time_dim=-1, valid_ones=False + ) + output = output.masked_fill(length_mask, 0.0) + else: + # Assume all frames are valid for all examples in the batch + output_length = output.size(-1) * torch.ones(B, device=output.device).long() + + return output, output_length + + def get_output_length(self, input_length: torch.Tensor) -> torch.Tensor: + """Get length of valid samples for the output. + + Args: + input_length: number of valid frames, shape (B,) + + Returns: + Number of valid samples, shape (B,) + """ + # centered STFT results in ((N-1) * hop_length) time samples for N frames (cf. torch.istft) + output_length = input_length.sub(1).mul(self.hop_length).long() + return output_length + + +class SpectrogramToAudioTA(NeuralModule): + """Transform a batch of input multi-channel spectrograms into a batch of + time-domain multi-channel signals. Using torchaudio. + + Args: + fft_length: length of FFT + hop_length: length of hops/shifts of the sliding window + magnitude_power: Transform magnitude of the spectrogram as x^(1/magnitude_power). + scale: Spectrogram will be scaled with 1/scale before the inverse transform. + """ + def __init__(self, fft_length: int, hop_length: int, magnitude_power: float = 1.0, scale: float = 1.0): if not HAVE_TORCHAUDIO: logging.error('Could not import torchaudio. Some features might not work.') @@ -184,7 +483,7 @@ def __init__(self, fft_length: int, hop_length: int, magnitude_power: float = 1. n_fft=fft_length, hop_length=hop_length, pad_mode='constant' ) - self.F = fft_length // 2 + 1 + self.num_subbands = fft_length // 2 + 1 if magnitude_power <= 0: raise ValueError(f'Magnitude power needs to be positive: current value {magnitude_power}') @@ -200,10 +499,6 @@ def __init__(self, fft_length: int, hop_length: int, magnitude_power: float = 1. logging.debug('\tmagnitude_power: %s', magnitude_power) logging.debug('\tscale: %s', scale) - @property - def num_subbands(self) -> int: - return self.F - @property def input_types(self) -> Dict[str, NeuralType]: """Returns definitions of module output ports.""" @@ -234,7 +529,7 @@ def forward(self, input: torch.Tensor, input_length: Optional[torch.Tensor] = No and output length with shape (B,). 
""" B, F, N = input.size(0), input.size(-2), input.size(-1) - assert F == self.F, f'Number of subbands F={F} not matching self.F={self.F}' + assert F == self.num_subbands, f'Number of subbands F={F} not matching self.num_subbands={self.num_subbands}' input = input.view(B, -1, F, N) # iSTFT output (B, C, T) diff --git a/nemo/collections/audio/parts/submodules/diffusion.py b/nemo/collections/audio/parts/submodules/diffusion.py index c8b3e803e373..2c9e08fc30fd 100644 --- a/nemo/collections/audio/parts/submodules/diffusion.py +++ b/nemo/collections/audio/parts/submodules/diffusion.py @@ -18,7 +18,7 @@ import numpy as np import torch -from nemo.collections.tts.parts.utils.helpers import mask_sequence_tensor +from nemo.collections.common.parts.utils import mask_sequence_tensor from nemo.core.classes import NeuralModule, typecheck from nemo.core.neural_types import FloatType, LengthsType, NeuralType, SpectrogramType, VoidType from nemo.utils import logging diff --git a/nemo/collections/audio/parts/submodules/ncsnpp.py b/nemo/collections/audio/parts/submodules/ncsnpp.py index adbeccc0dc02..543e29fc7847 100644 --- a/nemo/collections/audio/parts/submodules/ncsnpp.py +++ b/nemo/collections/audio/parts/submodules/ncsnpp.py @@ -20,8 +20,7 @@ import torch import torch.nn.functional as F -from nemo.collections.common.parts.utils import activation_registry -from nemo.collections.tts.parts.utils.helpers import mask_sequence_tensor +from nemo.collections.common.parts.utils import activation_registry, mask_sequence_tensor from nemo.core.classes import NeuralModule, typecheck from nemo.core.neural_types import FloatType, LengthsType, NeuralType, SpectrogramType, VoidType from nemo.utils import logging diff --git a/nemo/collections/audio/parts/submodules/schroedinger_bridge.py b/nemo/collections/audio/parts/submodules/schroedinger_bridge.py new file mode 100644 index 000000000000..07bfc2f88011 --- /dev/null +++ b/nemo/collections/audio/parts/submodules/schroedinger_bridge.py @@ -0,0 +1,607 @@ +# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import math
+from abc import ABC, abstractmethod
+from typing import Optional
+
+import torch
+
+from nemo.collections.common.parts.utils import mask_sequence_tensor
+from nemo.core.classes import NeuralModule, typecheck
+from nemo.core.neural_types import LengthsType, NeuralType, SpectrogramType
+from nemo.utils import logging
+
+
+class SBNoiseSchedule(NeuralModule, ABC):
+    """Noise schedule for the Schrödinger Bridge
+
+    Args:
+        time_min: minimum time for the process
+        time_max: maximum time for the process
+        num_steps: number of steps for the process
+        eps: small regularization
+
+    References:
+        Schrödinger Bridge for Generative Speech Enhancement, https://arxiv.org/abs/2407.16074
+    """
+
+    def __init__(
+        self,
+        time_min: float = 0.0,
+        time_max: float = 1.0,
+        num_steps: int = 100,
+        eps: float = 1e-8,
+    ):
+        super().__init__()
+
+        # min and max time
+        if time_min < 0:
+            raise ValueError(f'time_min should be non-negative, current value {time_min}')
+
+        if time_max <= time_min:
+            raise ValueError(f'time_max should be larger than time_min, current max {time_max} and min {time_min}')
+
+        self.time_min = time_min
+        self.time_max = time_max
+
+        if num_steps <= 0:
+            raise ValueError(f'Expected num_steps > 0, got {num_steps}')
+
+        self.num_steps = num_steps
+
+        if eps <= 0:
+            raise ValueError(f'Expected eps > 0, got {eps}')
+
+        self.eps = eps
+
+        logging.debug('Initialized %s with', self.__class__.__name__)
+        logging.debug('\ttime_min: %s', self.time_min)
+        logging.debug('\ttime_max: %s', self.time_max)
+        logging.debug('\tnum_steps: %s', self.num_steps)
+        logging.debug('\teps: %s', self.eps)
+
+    @property
+    def dt(self) -> float:
+        """Time step for the process."""
+        return self.time_max / self.num_steps
+
+    @property
+    def time_delta(self) -> float:
+        """Time range for the process."""
+        return self.time_max - self.time_min
+
+    def generate_time(self, size: int, device: torch.device) -> torch.Tensor:
+        """Generate random time steps in the valid range."""
+        time = torch.rand(size, device=device) * self.time_delta + self.time_min
+        return time
+
+    @property
+    def alpha_t_max(self):
+        """Return alpha_t at t_max."""
+        t_max = torch.tensor([self.time_max])
+        return self.alpha(t_max)
+
+    @property
+    def sigma_t_max(self):
+        """Return sigma_t at t_max."""
+        t_max = torch.tensor([self.time_max])
+        return self.sigma(t_max)
+
+    @abstractmethod
+    def f(self, time: torch.Tensor) -> torch.Tensor:
+        """Drift scaling f(t).
+
+        Args:
+            time: tensor with time steps
+
+        Returns:
+            Tensor the same size as time, representing drift scaling.
+        """
+        pass
+
+    @abstractmethod
+    def g(self, time: torch.Tensor) -> torch.Tensor:
+        """Diffusion scaling g(t).
+
+        Args:
+            time: tensor with time steps
+
+        Returns:
+            Tensor the same size as time, representing diffusion scaling.
+        """
+        pass
+
+    @abstractmethod
+    def alpha(self, time: torch.Tensor) -> torch.Tensor:
+        """Return alpha for SB noise schedule.
+
+        alpha_t = exp( int_0^t f(s) ds )
+
+        Args:
+            time: tensor with time steps
+
+        Returns:
+            Tensor the same size as time, representing alpha for each time.
+        """
+        pass
+
+    def alpha_bar_from_alpha(self, alpha: torch.Tensor) -> (torch.Tensor, torch.Tensor):
+        """Return alpha_bar for SB.
+
+        alpha_bar = alpha_t / alpha_t_max
+
+        Args:
+            alpha: tensor with alpha values
+
+        Returns:
+            Tensors the same size as alpha, representing alpha_bar and alpha_t_max.
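To illustrate how a concrete schedule is queried during training and sampling (the VE schedule and its parameter values below are assumptions; see SBNoiseScheduleVE defined further down):

    import torch

    schedule = SBNoiseScheduleVE(k=2.6, c=0.4, num_steps=50)

    # Draw random process times for a batch of 8 examples
    time = schedule.generate_time(size=8, device=torch.device('cpu'))

    alpha, alpha_bar, alpha_t_max = schedule.get_alphas(time)
    sigma, sigma_bar, sigma_t_max = schedule.get_sigmas(time)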
+ """ + alpha_t_max = self.alpha(torch.tensor([self.time_max], device=alpha.device)) + alpha_bar = alpha / (alpha_t_max + self.eps) + return alpha_bar, alpha_t_max + + def get_alphas(self, time: torch.Tensor) -> (torch.Tensor, torch.Tensor, torch.Tensor): + """Return alpha, alpha_bar and alpha_t_max for SB. + + Args: + time: tensor with time steps + + Returns: + Tuple of tensors with alpha, alpha_bar and alpha_t_max. + """ + alpha = self.alpha(time) + alpha_bar, alpha_t_max = self.alpha_bar_from_alpha(alpha) + return alpha, alpha_bar, alpha_t_max + + @abstractmethod + def sigma(self, time: torch.Tensor) -> torch.Tensor: + """Return sigma_t for SB. + + sigma_t^2 = int_0^s g^2(s) / alpha_s^2 ds + + Args: + time: tensor with time steps + + Returns: + Tensor the same size as time, representing sigma for each time. + """ + pass + + def sigma_bar_from_sigma(self, sigma: torch.Tensor) -> (torch.Tensor, torch.Tensor): + """Return sigma_bar_t for SB. + + sigma_bar_t^2 = sigma_t_max^2 - sigma_t^2 + + Args: + sigma: tensor with sigma values + + Returns: + Tensors the same size as sigma, representing sigma_bar and sigma_t_max. + """ + sigma_t_max = self.sigma(torch.tensor([self.time_max], device=sigma.device)) + sigma_bar_sq = sigma_t_max**2 - sigma**2 + return torch.sqrt(sigma_bar_sq + self.eps), sigma_t_max + + def get_sigmas(self, time: torch.Tensor) -> (torch.Tensor, torch.Tensor, torch.Tensor): + """Return sigma, sigma_bar and sigma_t_max for SB. + + Args: + time: tensor with time steps + + Returns: + Tuple of tensors with sigma, sigma_bar and sigma_t_max. + """ + sigma = self.sigma(time) + sigma_bar, sigma_t_max = self.sigma_bar_from_sigma(sigma) + return sigma, sigma_bar, sigma_t_max + + @abstractmethod + def copy(self): + """Return a copy of the noise schedule.""" + pass + + def __repr__(self): + desc = f'{self.__class__.__name__}(time_min={self.time_min}, time_max={self.time_max}, num_steps={self.num_steps})' + desc += f'\n\tdt: {self.dt}' + desc += f'\n\ttime_delta: {self.time_delta}' + return desc + + +class SBNoiseScheduleVE(SBNoiseSchedule): + """Variance exploding noise schedule for the Schrödinger Bridge. 
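For reference, the closed-form quantities implemented by this schedule follow directly from the definitions in SBNoiseSchedule above:

    f(t) = 0               =>  alpha_t = exp( int_0^t f(s) ds ) = 1
    g(t) = sqrt(c) * k**t  =>  sigma_t**2 = int_0^t g(s)**2 / alpha_s**2 ds = c * (k**(2*t) - 1) / (2 * ln(k))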
+
+    Args:
+        k: defines the base for the exponential diffusion coefficient
+        c: scaling for the diffusion coefficient
+        time_min: minimum time for the process
+        time_max: maximum time for the process
+        num_steps: number of steps for the process
+        eps: small regularization
+
+    References:
+        Schrödinger Bridge for Generative Speech Enhancement, https://arxiv.org/abs/2407.16074
+    """
+
+    def __init__(
+        self,
+        k: float,
+        c: float,
+        time_min: float = 0.0,
+        time_max: float = 1.0,
+        num_steps: int = 100,
+        eps: float = 1e-8,
+    ):
+        super().__init__(time_min=time_min, time_max=time_max, num_steps=num_steps, eps=eps)
+
+        # Shape parameters
+        if k <= 1:
+            raise ValueError(f'Expected k > 1, got {k}')
+
+        if c <= 0:
+            raise ValueError(f'Expected c > 0, got {c}')
+
+        self.c = c
+        self.k = k
+
+        logging.debug('Initialized %s with', self.__class__.__name__)
+        logging.debug('\tk: %s', self.k)
+        logging.debug('\tc: %s', self.c)
+        logging.debug('\ttime_min: %s', self.time_min)
+        logging.debug('\ttime_max: %s', self.time_max)
+        logging.debug('\tnum_steps: %s', self.num_steps)
+        logging.debug('\teps: %s', self.eps)
+
+    def f(self, time: torch.Tensor) -> torch.Tensor:
+        return torch.zeros_like(time)
+
+    def g(self, time: torch.Tensor) -> torch.Tensor:
+        # Diffusion coefficient g(t) = sqrt(c) * k^t
+        return math.sqrt(self.c) * self.k**time
+
+    def alpha(self, time: torch.Tensor) -> torch.Tensor:
+        return torch.ones_like(time)
+
+    def sigma(self, time: torch.Tensor) -> torch.Tensor:
+        sigma_sq = self.c * (self.k ** (2 * time) - 1) / (2 * math.log(self.k) + self.eps)
+        return torch.sqrt(sigma_sq)
+
+    def copy(self):
+        return SBNoiseScheduleVE(
+            k=self.k,
+            c=self.c,
+            time_min=self.time_min,
+            time_max=self.time_max,
+            num_steps=self.num_steps,
+            eps=self.eps,
+        )
+
+    def __repr__(self):
+        desc = super().__repr__()
+        desc += f'\n\tk: {self.k}'
+        desc += f'\n\tc: {self.c}'
+        return desc
+
+
+class SBNoiseScheduleVP(SBNoiseSchedule):
+    """Variance preserving noise schedule for the Schrödinger Bridge.
+ + Args: + beta_0: defines the lower bound for diffusion coefficient + beta_1: defines upper bound for diffusion coefficient + c: scaling for the diffusion coefficient + time_min: minimum time for the process + time_max: maximum time for the process + num_steps: number of steps for the process + eps: small regularization + """ + + def __init__( + self, + beta_0: float, + beta_1: float, + c: float = 1.0, + time_min: float = 0.0, + time_max: float = 1.0, + num_steps: int = 100, + eps: float = 1e-8, + ): + super().__init__(time_min=time_min, time_max=time_max, num_steps=num_steps, eps=eps) + + # Shape parameters + if beta_0 < 0: + raise ValueError(f'Expected beta_0 >= 0, got {beta_0}') + + if beta_1 < 0: + raise ValueError(f'Expected beta_1 >= 0, got {beta_1}') + + if beta_0 >= beta_1: + raise ValueError(f'Expected beta_0 < beta_1, got beta_0={beta_0} and beta_1={beta_1}') + + if c <= 0: + raise ValueError(f'Expected c > 0, got {c}') + + self.beta_0 = beta_0 + self.beta_1 = beta_1 + self.c = c + + logging.debug('Initialized %s with', self.__class__.__name__) + logging.debug('\tbeta_0: %s', self.beta_0) + logging.debug('\tbeta_1: %s', self.beta_1) + logging.debug('\tc: %s', self.c) + logging.debug('\ttime_min: %s', self.time_min) + logging.debug('\ttime_max: %s', self.time_max) + logging.debug('\tnum_steps: %s', self.num_steps) + logging.debug('\teps: %s', self.eps) + + def f(self, time: torch.Tensor) -> torch.Tensor: + return -0.5 * (self.beta_0 + time * (self.beta_1 - self.beta_0)) + + def g(self, time: torch.Tensor) -> torch.Tensor: + g_sq = self.c * (self.beta_0 + time * (self.beta_1 - self.beta_0)) + return torch.sqrt(g_sq) + + def alpha(self, time: torch.Tensor) -> torch.Tensor: + tmp = self.beta_0 * time + (self.beta_1 - self.beta_0) / 2 * time**2 + return torch.exp(-0.5 * tmp) + + def sigma(self, time: torch.Tensor) -> torch.Tensor: + sigma_sq = self.beta_0 * time + (self.beta_1 - self.beta_0) / 2 * time**2 + sigma_sq = torch.exp(sigma_sq) - 1 + sigma_sq = self.c * sigma_sq + return torch.sqrt(sigma_sq) + + def copy(self): + return SBNoiseScheduleVP( + beta_0=self.beta_0, + beta_1=self.beta_1, + c=self.c, + time_min=self.time_min, + time_max=self.time_max, + num_steps=self.num_steps, + eps=self.eps, + ) + + def __repr__(self): + desc = super().__repr__() + desc += f'\n\tbeta_0: {self.beta_0}' + desc += f'\n\tbeta_1: {self.beta_1}' + desc += f'\n\tc: {self.c}' + return desc + + +class SBSampler(NeuralModule): + """Schrödinger Bridge sampler. + + Args: + noise_schedule: noise schedule for the bridge + estimator: neural estimator + estimator_output: defines the output of the estimator, e.g., data_prediction + estimator_time: time for conditioning the estimator, e.g., 'current' + or 'previous'. Default is 'previous'. 
+ process: defines the process, e.g., sde or ode + time_max: maximum time for the process + time_min: minimum time for the process + num_steps: number of steps for the process + eps: small regularization to prevent division by zero + + References: + Schrödinger Bridge for Generative Speech Enhancement, https://arxiv.org/abs/2407.16074 + Schrodinger Bridges Beat Diffusion Models on Text-to-Speech Synthesis, https://arxiv.org/abs/2312.03491 + """ + + def __init__( + self, + noise_schedule: SBNoiseSchedule, + estimator: NeuralModule, # neural estimator + estimator_output: str, + estimator_time: str = 'previous', # time for the estimator + process: str = 'sde', + time_max: Optional[float] = None, + time_min: Optional[float] = None, + num_steps: int = 50, + eps: float = 1e-8, + ): + super().__init__() + # Create a copy of the noise schedule + self.noise_schedule = noise_schedule.copy() + + # Update sampling parameters + if time_max is not None: + self.noise_schedule.time_max = time_max + logging.info('noise_schedule.time_max set to: %s', self.noise_schedule.time_max) + + if time_min is not None: + self.noise_schedule.time_min = time_min + logging.info('noise_schedule.time_min set to: %s', self.noise_schedule.time_min) + + self.noise_schedule.num_steps = num_steps + logging.info('noise_schedule.num_steps set to: %s', self.noise_schedule.num_steps) + + # Estimator + self.estimator = estimator + self.estimator_output = estimator_output + self.estimator_time = estimator_time + + # Sampling process + self.process = process + + # Small regularization + if eps <= 0: + raise ValueError(f'Expected eps > 0, got {eps}') + self.eps = eps + + logging.debug('Initialized %s with', self.__class__.__name__) + logging.debug('\testimator_output: %s', self.estimator_output) + logging.debug('\testimator_time: %s', self.estimator_time) + logging.debug('\tprocess: %s', self.process) + logging.debug('\ttime_min: %s', self.time_min) + logging.debug('\ttime_max: %s', self.time_max) + logging.debug('\tnum_steps: %s', self.num_steps) + logging.debug('\teps: %s', self.eps) + + @property + def time_max(self): + return self.noise_schedule.time_max + + @time_max.setter + def time_max(self, value: float): + self.noise_schedule.time_max = value + logging.debug('noise_schedule.time_max set to: %s', self.noise_schedule.time_max) + + @property + def time_min(self): + return self.noise_schedule.time_min + + @time_min.setter + def time_min(self, value: float): + self.noise_schedule.time_min = value + logging.debug('noise_schedule.time_min set to: %s', self.noise_schedule.time_min) + + @property + def num_steps(self): + return self.noise_schedule.num_steps + + @num_steps.setter + def num_steps(self, value: int): + self.noise_schedule.num_steps = value + logging.debug('noise_schedule.num_steps set to: %s', self.noise_schedule.num_steps) + + @property + def process(self): + return self._process + + @process.setter + def process(self, value: str): + if value not in ['sde', 'ode']: + raise ValueError(f'Unexpected process: {value}') + self._process = value + logging.info('process set to: %s', self._process) + + @property + def estimator_time(self): + return self._estimator_time + + @estimator_time.setter + def estimator_time(self, value: str): + if value not in ['current', 'previous']: + raise ValueError(f'Unexpected estimator time: {value}') + self._estimator_time = value + logging.info('estimator time set to: %s', self._estimator_time) + + @typecheck( + input_types={ + "prior_mean": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()), + 
"estimator_condition": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType(), optional=True), + "state_length": NeuralType(tuple('B'), LengthsType(), optional=True), + }, + output_types={ + "sample": NeuralType(('B', 'C', 'D', 'T'), SpectrogramType()), + "state_length": NeuralType(tuple('B'), LengthsType(), optional=True), + }, + ) + @torch.inference_mode() + def forward( + self, prior_mean: torch.Tensor, estimator_condition: torch.Tensor, state_length: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """Takes prior mean and generates a sample.""" + # SB starts from the prior mean + state = prior_mean + + if state_length is not None: + state = mask_sequence_tensor(state, state_length) + + # Time steps for sampling + time_steps = torch.linspace(self.time_max, self.time_min, self.num_steps + 1, device=state.device) + + # Initial values + time_prev = time_steps[0] * torch.ones(state.shape[0], device=state.device) + alpha_prev, _, alpha_t_max = self.noise_schedule.get_alphas(time_prev) + sigma_prev, sigma_bar_prev, sigma_t_max = self.noise_schedule.get_sigmas(time_prev) + + # Sampling + # Sample at the initial time step (`self.time_max`) is exactly the prior_mean. + # We do not need to estimate it, but we need to pass it to the next time step. + # We iterate through the following time steps to generate the sample at the final time (`self.time_min`). + for t in time_steps[1:]: + + # Prepare time steps for the whole batch + time = t * torch.ones(state.shape[0], device=state.device) + + # Prepare input for estimator, concatenate conditioning along the channel dimension + estimator_input = state if estimator_condition is None else torch.cat([state, estimator_condition], dim=1) + estimator_time = time if self.estimator_time == 'current' else time_prev + + # Estimator + if self.estimator_output == 'data_prediction': + current_estimate, _ = self.estimator( + input=estimator_input, input_length=state_length, condition=estimator_time + ) + else: + raise NotImplementedError(f'Unexpected estimator output: {self.estimator_output}') + + # Get noise schedule for current time + alpha_t, alpha_bar_t, _ = self.noise_schedule.get_alphas(time) + sigma_t, sigma_bar_t, _ = self.noise_schedule.get_sigmas(time) + + if self.process == 'sde': + # Calculate scaling for the first-order discretization from the paper + weight_prev = alpha_t * sigma_t**2 / (alpha_prev * sigma_prev**2 + self.eps) + tmp = 1 - sigma_t**2 / (sigma_prev**2 + self.eps) + weight_estimate = alpha_t * tmp + weight_z = alpha_t * sigma_t * torch.sqrt(tmp) + + # View as [B, C, D, T] + weight_prev = weight_prev.view(-1, 1, 1, 1) + weight_estimate = weight_estimate.view(-1, 1, 1, 1) + weight_z = weight_z.view(-1, 1, 1, 1) + + # Random sample + z_norm = torch.randn_like(state) + + # Update state: weighted sum of previous state, current estimate and noise + state = weight_prev * state + weight_estimate * current_estimate + weight_z * z_norm + elif self.process == 'ode': + # Calculate scaling for the first-order discretization from the paper + weight_prev = alpha_t * sigma_t * sigma_bar_t / (alpha_prev * sigma_prev * sigma_bar_prev + self.eps) + weight_estimate = ( + alpha_t + / (sigma_t_max**2 + self.eps) + * (sigma_bar_t**2 - sigma_bar_prev * sigma_t * sigma_bar_t / (sigma_prev + self.eps)) + ) + weight_prior_mean = ( + alpha_t + / (alpha_t_max * sigma_t_max**2 + self.eps) + * (sigma_t**2 - sigma_prev * sigma_t * sigma_bar_t / (sigma_bar_prev + self.eps)) + ) + + # View as [B, C, D, T] + weight_prev = weight_prev.view(-1, 1, 1, 1) + weight_estimate = 
weight_estimate.view(-1, 1, 1, 1) + weight_prior_mean = weight_prior_mean.view(-1, 1, 1, 1) + + # Update state: weighted sum of previous state, current estimate and prior + state = weight_prev * state + weight_estimate * current_estimate + weight_prior_mean * prior_mean + else: + raise RuntimeError(f'Unexpected process: {self.process}') + + # Save previous values + time_prev = time + alpha_prev = alpha_t + sigma_prev = sigma_t + sigma_bar_prev = sigma_bar_t + + # Final output + if state_length is not None: + state = mask_sequence_tensor(state, state_length) + + return state, state_length diff --git a/nemo/collections/common/data/lhotse/cutset.py b/nemo/collections/common/data/lhotse/cutset.py index 775395400d8e..7e7fdbc95a61 100644 --- a/nemo/collections/common/data/lhotse/cutset.py +++ b/nemo/collections/common/data/lhotse/cutset.py @@ -128,6 +128,7 @@ def read_dataset_config(config) -> tuple[CutSet, bool]: "text_field": config.text_field, "lang_field": config.lang_field, "metadata_only": config.metadata_only, + "force_finite": config.force_finite, "max_open_streams": config.max_open_streams, } input_cfg = config.input_cfg @@ -244,7 +245,7 @@ def parse_and_combine_datasets( weights=weights if weights else None, max_open_streams=propagate_attrs["max_open_streams"], seed=propagate_attrs["shard_seed"], - metadata_only=propagate_attrs["metadata_only"], + force_finite=propagate_attrs["force_finite"] or propagate_attrs["metadata_only"], ) else: (cuts,) = cuts @@ -269,6 +270,7 @@ def read_lhotse_manifest(config, is_tarred: bool) -> CutSet: # This is mostly useful for unit testing or debugging. shard_seed = config.shard_seed metadata_only = config.metadata_only + force_finite = config.force_finite if config.get("cuts_path") is not None: warnings.warn("Note: lhotse.cuts_path will be ignored because lhotse.shar_path was provided.") if isinstance(config.shar_path, (str, Path)): @@ -276,7 +278,7 @@ def read_lhotse_manifest(config, is_tarred: bool) -> CutSet: cuts = CutSet.from_shar( **_resolve_shar_inputs(config.shar_path, metadata_only), shuffle_shards=True, seed=shard_seed ) - if not metadata_only: + if not metadata_only and not force_finite: cuts = cuts.repeat() else: # Multiple datasets in Lhotse Shar format: we will dynamically multiplex them @@ -313,7 +315,7 @@ def read_lhotse_manifest(config, is_tarred: bool) -> CutSet: weights=weights, max_open_streams=config.max_open_streams, seed=config.shard_seed, - metadata_only=metadata_only, + force_finite=force_finite, ) else: # Regular Lhotse manifest points to individual audio files (like native NeMo manifest). @@ -383,6 +385,7 @@ def read_nemo_manifest(config, is_tarred: bool) -> CutSet: "lang_field": config.lang_field, "shuffle_shards": config.shuffle, "shard_seed": config.shard_seed, + "extra_fields": config.get("extra_fields", None), } # The option below is to allow a special case of NeMo manifest iteration as Lhotse CutSet # without performing any I/O. NeMo manifests typically don't have sampling_rate information required by Lhotse, @@ -392,6 +395,7 @@ def read_nemo_manifest(config, is_tarred: bool) -> CutSet: # and other data statistics. 
notar_kwargs = {"metadata_only": config.metadata_only} metadata_only = config.metadata_only + force_finite = config.force_finite if isinstance(config.manifest_filepath, (str, Path)): logging.info(f"Initializing Lhotse CutSet from a single NeMo manifest (tarred): '{config.manifest_filepath}'") if is_tarred and not metadata_only: @@ -401,7 +405,9 @@ def read_nemo_manifest(config, is_tarred: bool) -> CutSet: tar_paths=config.tarred_audio_filepaths, **common_kwargs, ) - ).repeat() + ) + if not force_finite: + cuts = cuts.repeat() else: cuts = CutSet(LazyNeMoIterator(config.manifest_filepath, **notar_kwargs, **common_kwargs)) else: @@ -467,7 +473,7 @@ def read_nemo_manifest(config, is_tarred: bool) -> CutSet: weights=weights, max_open_streams=config.max_open_streams, seed=config.shard_seed, - metadata_only=metadata_only, + force_finite=force_finite or metadata_only, ) return cuts @@ -477,7 +483,7 @@ def mux( weights: list[int | float], max_open_streams: int | None = None, seed: str | int = "trng", - metadata_only: bool = False, + force_finite: bool = False, ) -> CutSet: """ Helper function to call the right multiplexing method flavour in lhotse. @@ -485,10 +491,10 @@ def mux( it will select a more appropriate multiplexing strategy. """ if max_open_streams is not None: - assert not metadata_only, "max_open_streams and metadata_only options are not compatible" + assert not force_finite, "max_open_streams and metadata_only/force_finite options are not compatible" cuts = CutSet.infinite_mux(*cutsets, weights=weights, seed=seed, max_open_streams=max_open_streams) else: - if not metadata_only: + if not force_finite: cutsets = [cs.repeat() for cs in cutsets] cuts = CutSet.mux(*cutsets, weights=weights, seed=seed) return cuts diff --git a/nemo/collections/common/data/lhotse/dataloader.py b/nemo/collections/common/data/lhotse/dataloader.py index 5533b50922f8..d073e432d7ee 100644 --- a/nemo/collections/common/data/lhotse/dataloader.py +++ b/nemo/collections/common/data/lhotse/dataloader.py @@ -132,6 +132,11 @@ class LhotseDataLoadingConfig: # Enables iteration of NeMo non-tarred manifests that don't have a "sampling_rate" key without performing any I/O. # Note that this will not allow actual dataloading; it's only for manifest iteration as Lhotse objects. metadata_only: bool = False + # Forces the resulting CutSet to be finite, so that the iteration will end after a full single epoch. + # Do not turn this on unless you're sure that you know what you're doing. + # In most cases (such as regular multi-GPU training) it will result in a deadlock due to + # a different number of steps on different DDP ranks. 
+ force_finite: bool = False def get_lhotse_dataloader_from_config( diff --git a/nemo/collections/common/data/lhotse/nemo_adapters.py b/nemo/collections/common/data/lhotse/nemo_adapters.py index d24ce794da5a..55663a04a5a0 100644 --- a/nemo/collections/common/data/lhotse/nemo_adapters.py +++ b/nemo/collections/common/data/lhotse/nemo_adapters.py @@ -16,18 +16,22 @@ import random import re import tarfile +from collections.abc import Mapping, Sequence from io import BytesIO from pathlib import Path from typing import Generator, Iterable, List, Literal +import lhotse.serialization import soundfile from cytoolz import groupby from lhotse import AudioSource, Recording, SupervisionSegment +from lhotse.audio.backend import LibsndfileBackend from lhotse.cut import Cut from lhotse.dataset.dataloading import resolve_seed from lhotse.lazy import LazyIteratorChain, LazyJsonlIterator from lhotse.serialization import open_best from lhotse.utils import compute_num_samples + from nemo.collections.common.parts.preprocessing.manifest import get_full_path @@ -56,16 +60,33 @@ class LazyNeMoIterator: Example:: >>> cuts = lhotse.CutSet(LazyNeMoIterator("nemo_manifests/train.json")) + + We allow attaching custom metadata to cuts from files other than the manifest via ``extra_fields`` argument. + In the example below, we'll iterate file "questions.txt" together with the manifest and attach each line + under ``cut.question`` using the field type ``text_iter``:: + + >>> cuts = lhotse.CutSet(LazyNeMoIterator( + ... "nemo_manifests/train.json", + ... extra_fields=[{"type": "text_iter", "name": "question", "path": "questions.txt"}], + ... )) + + We also support random sampling of lines with field type ``text_sample``:: + + >>> cuts = lhotse.CutSet(LazyNeMoIterator( + ... "nemo_manifests/train.json", + ... extra_fields=[{"type": "text_sample", "name": "question", "path": "questions.txt"}], + ... )) """ def __init__( self, - path: str | Path, + path: str | Path | list[str], text_field: str = "text", lang_field: str = "lang", metadata_only: bool = False, shuffle_shards: bool = False, shard_seed: int | Literal["randomized", "trng"] = "trng", + extra_fields: list[dict[str, str]] | None = None, ) -> None: self.path = path self.shuffle_shards = shuffle_shards @@ -80,8 +101,13 @@ def __init__( self.text_field = text_field self.lang_field = lang_field self.metadata_only = metadata_only + self.extra_fields = extra_fields + validate_extra_fields(self.extra_fields) def __iter__(self) -> Generator[Cut, None, None]: + seed = resolve_seed(self.shard_seed) + # Propagate the random seed + extra_fields = [ExtraField.from_dict({"seed": seed, **field_cfg}) for field_cfg in self.extra_fields or ()] for data in self.source: audio_path = get_full_path(str(data.pop("audio_filepath")), str(self.path)) duration = data.pop("duration") @@ -104,6 +130,8 @@ def __iter__(self) -> Generator[Cut, None, None]: ) ) cut.custom = data + for extra_field in extra_fields: + extra_field.attach_to(cut) yield cut def __len__(self) -> int: @@ -180,20 +208,39 @@ class LazyNeMoTarredIterator: Example of CutSet with inter-shard shuffling enabled:: >>> cuts = lhotse.CutSet(LazyNeMoTarredIterator( - ... manifest_path="nemo_manifests/train.json", + ... manifest_path=["nemo_manifests/sharded_manifests/manifest_0.json", ...], ... tar_paths=["nemo_manifests/audio_0.tar", ...], ... shuffle_shards=True, ... )) + + We allow attaching custom metadata to cuts from files other than the manifest via ``extra_fields`` argument. 
+ In the example below, we'll iterate file "questions.txt" together with the manifest and attach each line + under ``cut.question`` using the field type ``text_iter``:: + + >>> cuts = lhotse.CutSet(LazyNeMoTarredIterator( + ... manifest_path=["nemo_manifests/sharded_manifests/manifest_0.json", ...], + ... tar_paths=["nemo_manifests/audio_0.tar", ...], + ... extra_fields=[{"type": "text_iter", "name": "question", "path": "questions.txt"}], + ... )) + + We also support random sampling of lines with field type ``text_sample``:: + + >>> cuts = lhotse.CutSet(LazyNeMoTarredIterator( + ... manifest_path=["nemo_manifests/sharded_manifests/manifest_0.json", ...], + ... tar_paths=["nemo_manifests/audio_0.tar", ...], + ... extra_fields=[{"type": "text_sample", "name": "question", "path": "questions.txt"}], + ... )) """ def __init__( self, - manifest_path: str | Path, + manifest_path: str | Path | list[str], tar_paths: str | list, shuffle_shards: bool = False, shard_seed: int | Literal["trng", "randomized"] = "trng", text_field: str = "text", lang_field: str = "lang", + extra_fields: list[dict[str, str]] | None = None, ) -> None: self.shard_id_to_manifest: dict[int, Iterable[dict]] self.paths = expand_sharded_filepaths(manifest_path) @@ -235,6 +282,7 @@ def __init__( self.shard_seed = shard_seed self.text_field = text_field self.lang_field = lang_field + self.extra_fields = extra_fields self._validate() def to_shards(self) -> List["LazyNeMoTarredIterator"]: @@ -266,6 +314,7 @@ def _validate(self) -> None: f"* JSON manifest(s) indicate(s) IDs: {sorted(shard_ids_manifest)}\n" f"* Tar path(s) indicate(s) IDs: {sorted(shard_ids_tars)}\n" ) + validate_extra_fields(self.extra_fields) @property def shard_ids(self) -> List[int]: @@ -274,13 +323,26 @@ def shard_ids(self) -> List[int]: def __iter__(self) -> Generator[Cut, None, None]: shard_ids = self.shard_ids + seed = resolve_seed(self.shard_seed) if self.shuffle_shards: - seed = resolve_seed(self.shard_seed) random.Random(seed).shuffle(shard_ids) + # Propagate the random seed + extra_fields = [ExtraField.from_dict({"seed": seed, **field_cfg}) for field_cfg in self.extra_fields or ()] + + # Handle NeMo tarred manifests with offsets. + # They have multiple JSONL entries where audio paths end with '-sub1', '-sub2', etc. for each offset. + offset_pattern = re.compile(r'^.+(-sub\d+)$') + for sid in shard_ids: manifest_path = self.paths[sid] if len(self.paths) > 1 else self.paths[0] - shard_manifest = {data["audio_filepath"]: data for data in self.shard_id_to_manifest[sid]} + + def basename(d: dict) -> str: + return ( + k[: -len(m.group(1))] if (m := offset_pattern.match(k := d["audio_filepath"])) is not None else k + ) + + shard_manifest: dict[str, list[dict]] = groupby(basename, self.shard_id_to_manifest[sid]) tar_path = self.shard_id_to_tar_path[sid] with tarfile.open(fileobj=open_best(tar_path, mode="rb"), mode="r|*") as tar: for tar_info in tar: @@ -288,7 +350,6 @@ def __iter__(self) -> Generator[Cut, None, None]: f"Mismatched entry between JSON manifest ('{manifest_path}') and tar file ('{tar_path}'). 
" f"Cannot locate JSON entry for tar file '{tar_info.name}'" ) - data = shard_manifest[tar_info.name] raw_audio = tar.extractfile(tar_info).read() # Note: Lhotse has a Recording.from_bytes() utility that we won't use here because # the profiling indicated significant overhead in torchaudio ffmpeg integration @@ -302,19 +363,29 @@ def __iter__(self) -> Generator[Cut, None, None]: num_samples=meta.frames, duration=meta.duration, ) - cut = recording.to_cut() - cut.supervisions.append( - SupervisionSegment( - id=cut.id, - recording_id=cut.recording_id, - start=0, - duration=cut.duration, - text=data.get(self.text_field), - language=data.get(self.lang_field), + cuts_for_recording = [] + for data in sorted(shard_manifest[tar_info.name], key=lambda d: d["audio_filepath"]): + # Cut the recording into corresponding segment and discard audio data outside the segment. + cut = make_cut_with_subset_inmemory_recording( + recording, offset=data.get("offset", 0.0), duration=data.get("duration") ) - ) - cut.custom = _to_custom_attr_dict(data) - yield cut + cut.supervisions.append( + SupervisionSegment( + id=cut.id, + recording_id=cut.recording_id, + start=0, + duration=cut.duration, + text=data.get(self.text_field), + language=data.get(self.lang_field), + ) + ) + cut.custom = _to_custom_attr_dict(data) + for extra_field in extra_fields: + extra_field.attach_to(cut) + cuts_for_recording.append(cut) + del recording # free the memory - helps with very large audio files + del raw_audio + yield from cuts_for_recording def __len__(self) -> int: return len(self.source) @@ -323,11 +394,145 @@ def __add__(self, other): return LazyIteratorChain(self, other) -def expand_sharded_filepaths(path: str | Path) -> list[str]: +def make_cut_with_subset_inmemory_recording( + recording: Recording, offset: float = 0.0, duration: float | None = None +) -> Cut: + """ + This method is built specifically to optimize CPU memory usage during dataloading + when reading tarfiles containing very long recordings (1h+). + Normally each cut would hold a reference to the long in-memory recording and load + the necessary subset of audio (there wouldn't be a separate copy of the long recording for each cut). + This is fairly efficient already, but we don't actually need to hold the unused full recording in memory. + Instead, we re-create each cut so that it only holds a reference to the subset of recording necessary. + This allows us to discard unused data which would otherwise be held in memory as part of sampling buffering. + """ + + # Fast path: no offset and (almost) matching duration (within 200ms; leeway for different audio codec behavior). + cut = recording.to_cut() + if offset == 0.0 and duration is None or abs(duration - recording.duration) < 0.2: + return cut + + # Otherwise, apply the memory optimization. 
+ cut = cut.truncate(offset=offset, duration=duration, preserve_id=True) + audiobytes = BytesIO() + LibsndfileBackend().save_audio(audiobytes, cut.load_audio(), sampling_rate=cut.sampling_rate, format="wav") + audiobytes.seek(0) + new_recording = Recording( + id=recording.id, + sampling_rate=recording.sampling_rate, + num_samples=cut.num_samples, + duration=cut.duration, + sources=[ + AudioSource( + type="memory", + channels=recording.channel_ids, + source=audiobytes.getvalue(), + ) + ], + ) + return new_recording.to_cut() + + +class ExtraField: + TYPE = None + SUPPORTED_TYPES = {} + + def attach_to(self, cut): + raise NotImplementedError() + + def __init_subclass__(cls, **kwargs): + if cls.__name__ not in ExtraField.SUPPORTED_TYPES: + ExtraField.SUPPORTED_TYPES[cls.TYPE] = cls + super().__init_subclass__(**kwargs) + + @staticmethod + def from_dict(data: dict) -> "ExtraField": + assert data["type"] in ExtraField.SUPPORTED_TYPES, f"Unknown transform type: {data['type']}" + return ExtraField.SUPPORTED_TYPES[data["type"]](**{k: v for k, v in data.items() if k != 'type'}) + + @classmethod + def is_supported(cls, field_type: str) -> bool: + return field_type in cls.SUPPORTED_TYPES + + @classmethod + def supported_types(cls) -> list[str]: + return list(cls.SUPPORTED_TYPES) + + +class TextIteratorExtraField(ExtraField): + TYPE = "text_iter" + + def __init__(self, name: str, path: str, seed=None): + self.name = name + self.path = path + self.iterator = None + + def _maybe_init(self): + if self.iterator is None: + self.iterator = iter(map(str.strip, open_best(self.path))) + + def attach_to(self, cut): + self._maybe_init() + try: + attached_value = next(self.iterator) + except StopIteration: + raise RuntimeError(f"Not enough lines in file {self.path} to attach to cuts under field {self.name}.") + setattr(cut, self.name, attached_value) + return cut + + +class TextSampleExtraField(ExtraField): + TYPE = "text_sample" + + def __init__(self, name: str, path: str, seed: int | str): + self.name = name + self.path = path + self.seed = seed + self.population = None + self.rng = None + + def _maybe_init(self): + if self.population is None: + self.population = list(map(str.strip, open_best(self.path))) + self.rng = random.Random(resolve_seed(self.seed)) + + def attach_to(self, cut): + self._maybe_init() + attached_value = self.rng.choice(self.population) + setattr(cut, self.name, attached_value) + return cut + + +def validate_extra_fields(extra_fields): + if extra_fields is None: + return + assert isinstance( + extra_fields, Sequence + ), f"The argument provided to 'extra_fields' must be a list of dicts. We received {extra_fields=}" + for field in extra_fields: + assert isinstance( + field, Mapping + ), f"Each item in 'extra_fields' must be a dict. We received {field=} in {extra_fields=}" + field_type = field.get("type") + assert ExtraField.is_supported(field_type), ( + f"Each item in 'extra_fields' must contain a 'type' field with one of " + f"the supported values ({ExtraField.supported_types()}). " + f"We got {field_type=} in {extra_fields=}" + ) + assert "name" in field, ( + f"Each item in 'extra_fields' must contain a 'name' field so that the field is available under cut.." 
+ f"We found {field=} in {extra_fields=}" + ) + + +def expand_sharded_filepaths(paths: str | Path | list[str]) -> list[str]: # local import to avoid circular imports from nemo.collections.asr.data.audio_to_text import expand_sharded_filepaths as _expand_sharded_filepaths - return _expand_sharded_filepaths(str(path), shard_strategy="replicate", world_size=1, global_rank=0) + if isinstance(paths, Path): + paths = str(paths) + + return _expand_sharded_filepaths(paths, shard_strategy="replicate", world_size=1, global_rank=0) def _to_custom_attr_dict(d: dict, _excluded_fields: set[str] = {"duration", "audio_filepath"}) -> dict: diff --git a/nemo/collections/common/metrics/__init__.py b/nemo/collections/common/metrics/__init__.py index 9e21d93816a9..81f1a181beae 100644 --- a/nemo/collections/common/metrics/__init__.py +++ b/nemo/collections/common/metrics/__init__.py @@ -19,4 +19,5 @@ MetricStringToTorchMetric, TextMetricsSet, ) +from nemo.collections.common.metrics.perf_metrics import FLOPsMeasurementCallback from nemo.collections.common.metrics.perplexity import Perplexity diff --git a/nemo/collections/common/metrics/perf_metrics.py b/nemo/collections/common/metrics/perf_metrics.py new file mode 100644 index 000000000000..5722f52d0e7c --- /dev/null +++ b/nemo/collections/common/metrics/perf_metrics.py @@ -0,0 +1,233 @@ +from typing import Any, Dict, List, Optional + +import numpy as np +from pytorch_lightning.callbacks import Callback + +from nemo.collections.common.parts.perf_metrics_utils import LLM_VOCAB_SIZE_MAP, read_tb_log +from nemo.utils import logging + +__all__ = ["FLOPsMeasurementCallback"] + + +class FLOPsMeasurementCallback(Callback): + """ + Calculate FLOPs per second after last train step for a given job run. + + Args: + model_config (Dict[str, Any]): params for running the experiment/job. + Expects a nested dictionary with parent keys + 1. run- for assessing model name (Eg. 'gpt3', 'llama2', etc.) from sub-key 'name'. + 'name' usually has value like- train_gpt3_5b_*, which is matched to model name 'gpt3'. + 2. exp_manager- for accessing 'explicit_log_dir'. tensorboard log file is stored here, + used for accessing step time needed for calculating TFLOPs per sec per GPU + 3. trainer- for accessing 'num_nodes' and 'devices' needed for calculating + TFLOPs per sec per GPU + 4. model- Hyperparams for the model. Specifically- global batch size, sequence length, + hidden size, ffn hidden size, num_layers, num_attention_heads, num_query_groups, + moe_router_topk. (list might increase with new models as required) + log_dir (Optional[str]): Directory with tenbsorboard log file. If present, will overrride + 'explicit_log_dir' in model_config. Defaults to None. + model_name (Optional[str]): If present, will override 'name' under 'run' in model_config. + Defaults to None. 
+ """ + + higher_is_better = True + + def __init__( + self, + model_config: Dict[str, Any], + log_dir: Optional[str] = None, + model_name: Optional[str] = None, + ): + self.cfg = model_config + + self.run_cfg = self.cfg.get('run', {}) + self.exp_cfg = self.cfg.get('exp_manager', {}) + self.train_cfg = self.cfg.get('trainer', {}) + self.model_cfg = self.cfg.get('model', {}) + + # use config params only when NOT provided explicitly + self.model = self.run_cfg.get('name', "") if model_name is None else model_name + self.log_dir = self.exp_cfg.get('explicit_log_dir', None) if log_dir is None else log_dir + + self.num_nodes = self.train_cfg.get('num_nodes', None) + self.num_gpus_per_node = self.train_cfg.get('devices', None) + + self.gbs = self.model_cfg.get('global_batch_size', None) + self.enc_seq_len = self.model_cfg.get('encoder_seq_length', None) + self.hs = self.model_cfg.get('hidden_size', None) + self.layers = self.model_cfg.get('num_layers', None) + self.ffn_hs = self.model_cfg.get('ffn_hidden_size', None) + self.attention_heads = self.model_cfg.get('num_attention_heads', None) + self.moe_router_topk = self.model_cfg.get('moe_router_topk', None) + + # this handles both- 1. key is present, value is None; 2. key is absent + self.query_groups = self.model_cfg.get('num_query_groups', None) + if self.query_groups is None: + self.query_groups = self.attention_heads + + self.model = self.model.lower() if self.model is not None else self.model + + def on_train_end(self, trainer, pl_module): + """ + PyTorch Lightning callback hook to calculate TFLOPs per sec per GPU after training + """ + tflops_per_sec_per_gpu = -1 + + try: + if "peft" in self.cfg["model"]: + raise NotImplementedError("FLOPs measurement not supported for finetuning jobs") + + step_time_list = read_tb_log(self.log_dir, "train_step_timing in s") + tflops_per_sec_per_gpu = self.eval_tflops_per_sec_per_gpu(step_time_list) + except Exception as exc: + logging.error(f"Failed to calculate TFLOPs per sec per GPU.\n{exc}") + + logging.info(f"TFLOPs per sec per GPU={tflops_per_sec_per_gpu:.2f}") + pl_module.logger.experiment.add_scalar("tflops_per_sec_per_gpu", tflops_per_sec_per_gpu) + + def eval_tflops_per_sec_per_gpu(self, train_step_time: List | float | int) -> float: + """ + Args: + train_step_time (Any[List, float, int]): Train step time (in seconds). 
+ Step time will be less stable for initial steps (~10 steps)- less + accurate measurement + Use average step time over several steps for higher accuracy + Returns: + (float): Model TFLOPs per sec per gpu + """ + total_flops, flops_per_gpu = self.eval_model_flops() + + if not isinstance(train_step_time, list): + train_step_time = [train_step_time] + # efficient mean computation if num train steps is very large + step_time_arr = np.array(train_step_time) + train_step_time = np.mean(step_time_arr[len(step_time_arr) // 2 :]) + + return flops_per_gpu / (1e12 * train_step_time) + + def eval_model_flops(self): + """ + Calculate model FLOPs for a given model + """ + + model_flops_map = { + "gpt3": self._gpt3, + "llama2": self._llama2, + "llama3": self._llama3, + "nemotron": self._nemotron, + "mixtral": self._mixtral, + "bert": self._bert, + } + + if self.model is not None: + model_matches = [model for model in model_flops_map if model in self.model] + self.model = model_matches[0] if len(model_matches) > 0 else self.model + if self.model not in model_flops_map: + logging.info(f"FLOPs measurement supported for {list(model_flops_map.keys())}") + raise KeyError(f"Failed to extract valid model name from or missing FLOPs calculations for {self.model}") + + total_flops = model_flops_map[self.model]() + flops_per_gpu = total_flops / (self.num_nodes * self.num_gpus_per_node) + + return total_flops, flops_per_gpu + + def _gpt3(self): + """Model FLOPs for GPT3 family""" + + vocab_size = LLM_VOCAB_SIZE_MAP["gpt3"] + + return ( + 24 * self.gbs * self.enc_seq_len * self.hs * self.hs + + 4 * self.gbs * self.enc_seq_len * self.enc_seq_len * self.hs + ) * (3 * self.layers) + (6 * self.gbs * self.enc_seq_len * self.hs * vocab_size) + + def _llama2(self): + """Model FLOPs for llama2 family""" + vocab_size = LLM_VOCAB_SIZE_MAP["llama2"] + + return ( + self.gbs + * self.enc_seq_len + * self.layers + * self.hs + * self.hs + * ( + 12 + + (12 * self.query_groups / self.attention_heads) + + (18 * self.ffn_hs / self.hs) + + (12 * self.enc_seq_len / self.hs) + + (6 * vocab_size / (self.layers * self.hs)) + ) + ) + + def _llama3(self): + """Model FLOPs for llama3 family""" + vocab_size = LLM_VOCAB_SIZE_MAP["llama3"] + + return ( + self.gbs + * self.enc_seq_len + * self.layers + * self.hs + * self.hs + * ( + 12 + + (12 * self.query_groups / self.attention_heads) + + (18 * self.ffn_hs / self.hs) + + (12 * self.enc_seq_len / self.hs) + + (6 * vocab_size / (self.layers * self.hs)) + ) + ) + + def _nemotron(self): + """Model FLOPs for nemotron family""" + vocab_size = LLM_VOCAB_SIZE_MAP["nemotron"] + + return ( + self.gbs + * self.enc_seq_len + * self.layers + * self.hs + * self.hs + * ( + 12 + + (12 * self.query_groups / self.attention_heads) + + (12 * self.ffn_hs / self.hs) + + (12 * self.enc_seq_len / self.hs) + + (6 * vocab_size / (self.layers * self.hs)) + ) + ) + + def _mixtral(self): + """Model FLOPs for mixtral family""" + vocab_size = LLM_VOCAB_SIZE_MAP["mixtral"] + + return ( + self.gbs + * self.enc_seq_len + * self.layers + * self.hs + * self.hs + * ( + 12 + + (12 * self.query_groups / self.attention_heads) + + (18 * self.moe_router_topk * self.ffn_hs / self.hs) + + (12 * self.enc_seq_len / self.hs) + + (6 * vocab_size / (self.layers * self.hs)) + ) + ) + + def _bert(self): + """Model FLOPs for BERT family""" + vocab_size = LLM_VOCAB_SIZE_MAP["bert"] + + return ( + 72 + * self.gbs + * self.layers + * self.enc_seq_len + * self.hs + * self.hs + * (1 + (self.enc_seq_len / (6 * self.hs)) + (vocab_size / (12 * self.hs * 
self.layers))) + ) diff --git a/nemo/collections/common/parts/perf_metrics_utils.py b/nemo/collections/common/parts/perf_metrics_utils.py new file mode 100644 index 000000000000..41273797e035 --- /dev/null +++ b/nemo/collections/common/parts/perf_metrics_utils.py @@ -0,0 +1,46 @@ +import glob +import os +from typing import List + +from tensorboard.backend.event_processing import event_accumulator + +from nemo.utils import logging + +LLM_VOCAB_SIZE_MAP = { + "gpt3": 51200, + "llama2": 32000, + "llama3": 128256, + "nemotron": 256000, + "bert": 29000, + "mixtral": 32000, +} + + +def read_tb_log(path: str, summary_name: str) -> List: + """ + Reads a TensorBoard Events file from the input path, and returns the + summary specified. + + Args: + path: str, path to the dir where the events file is located. + summary_name: str, name of the summary to read from the TB logs. + Returns: + summary_list: list, the values in the read summary list, formatted as a list. + """ + + files = glob.glob(f"{path}/events*tfevents*") + files.sort(key=lambda x: os.path.getmtime(os.path.join(path, x))) + if len(files) == 0 or not os.path.isfile(files[0]): + raise FileNotFoundError(f"Missing TensorBoard log file.") + + events_file = files[0] + try: + ea = event_accumulator.EventAccumulator(events_file) + ea.Reload() + summary = ea.Scalars(summary_name) + summary_list = [round(x.value, 2) for x in summary] + logging.info(f"{summary_name}: {summary_list}") + except KeyError: + raise KeyError(f"{summary_name} not found in {events_file}") + + return summary_list diff --git a/nemo/collections/common/parts/utils.py b/nemo/collections/common/parts/utils.py index c22c433bdfdf..e8eb1b999292 100644 --- a/nemo/collections/common/parts/utils.py +++ b/nemo/collections/common/parts/utils.py @@ -16,6 +16,8 @@ import os from typing import Iterable, List +import einops +import torch import torch.nn as nn __all__ = ['if_exist', '_compute_softmax', 'flatten'] @@ -105,3 +107,27 @@ def extend_instance(obj, mixin): obj.__class__ = type( base_cls_name, (mixin, base_cls), {} ) # mixin needs to go first for our forward() logic to work + + +def mask_sequence_tensor(tensor: torch.Tensor, lengths: torch.Tensor): + """ + For tensors containing sequences, zero out out-of-bound elements given lengths of every element in the batch. 
+ + tensor: tensor of shape (B, L), (B, D, L) or (B, D1, D2, L), + lengths: LongTensor of shape (B,) + """ + batch_size, *_, max_lengths = tensor.shape + + if len(tensor.shape) == 2: + mask = torch.ones(batch_size, max_lengths).cumsum(dim=-1).type_as(lengths) + mask = mask <= einops.rearrange(lengths, 'B -> B 1') + elif len(tensor.shape) == 3: + mask = torch.ones(batch_size, 1, max_lengths).cumsum(dim=-1).type_as(lengths) + mask = mask <= einops.rearrange(lengths, 'B -> B 1 1') + elif len(tensor.shape) == 4: + mask = torch.ones(batch_size, 1, 1, max_lengths).cumsum(dim=-1).type_as(lengths) + mask = mask <= einops.rearrange(lengths, 'B -> B 1 1 1') + else: + raise ValueError('Can only mask tensors of shape B x L, B x D x L and B x D1 x D2 x L') + + return tensor * mask diff --git a/nemo/collections/common/prompts/canary.py b/nemo/collections/common/prompts/canary.py index e511368a1edf..f2b1e58c3bb2 100644 --- a/nemo/collections/common/prompts/canary.py +++ b/nemo/collections/common/prompts/canary.py @@ -16,9 +16,13 @@ class CanaryPromptFormatter(PromptFormatter): "template": f"{CANARY_BOS}|source_lang||task||target_lang||pnc|", "slots": { "source_lang": Modality.Text, - "task": Modality.TextLiteral("asr", "ast", "s2t_translation", "<|transcribe|>", "<|translate|>"), + "task": Modality.TextLiteral( + "asr", "ast", "translate", "transcribe", "s2t_translation", "<|transcribe|>", "<|translate|>" + ), "target_lang": Modality.Text, - "pnc": Modality.TextLiteral("yes", "no", "<|pnc|>", "<|nopnc|>"), + "pnc": Modality.TextLiteral( + "yes", "no", "true", "True", "false", "False", "1", "0", "pnc", "nopnc", "<|pnc|>", "<|nopnc|>" + ), }, }, OUTPUT_ROLE: { @@ -54,13 +58,18 @@ def map_manifest_values_to_special_tokens(slot_values: dict[str, str]) -> dict[s k = "pnc" if k in slot_values and slot_values[k] not in (CANARY_PNC, CANARY_NOPNC): - slot_values[k] = CANARY_PNC if slot_values[k] in ("yes", "1", "True", "true") else CANARY_NOPNC + slot_values[k] = CANARY_PNC if slot_values[k] in ("yes", "1", "True", "true", "pnc") else CANARY_NOPNC any_special_token_present = True # Note: we re-map 'taskname' to 'task' for compatibility with earlier versions of Canary training. 
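    # For example (illustrative): {"taskname": "asr"} or {"task": "transcribe"} end up as task "<|transcribe|>",
    # while "ast", "translate" and "s2t_translation" map to "<|translate|>"; likewise "yes"/"true"/"1"/"pnc"
    # above map to <|pnc|> and everything else to <|nopnc|>.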
for k in ("task", "taskname"): if k in slot_values and slot_values[k] not in ("<|transcribe|>", "<|translate|>"): - slot_values["task"] = "<|transcribe|>" if slot_values[k] == "asr" else "<|translate|>" + if slot_values[k] in {"translate", "ast", "s2t_translation"}: + slot_values["task"] = "<|translate|>" + elif slot_values[k] in {"transcribe", "asr"}: + slot_values["task"] = "<|transcribe|>" + else: + assert False, f"Task {slot_values[k]} invalid task for slot {k}" any_special_token_present = True # Auto-inject which tokenizer to look up in CanaryTokenizer if not provided, diff --git a/nemo/collections/common/prompts/formatter.py b/nemo/collections/common/prompts/formatter.py index 8a82563ebbaa..6d2c67f5311d 100644 --- a/nemo/collections/common/prompts/formatter.py +++ b/nemo/collections/common/prompts/formatter.py @@ -25,6 +25,9 @@ class BaseModalityType: def matches(value: Any) -> bool: raise NotImplementedError + def __repr__(self): + return f"Modality.{self.__class__.__name__}()" + class Text(BaseModalityType): """Modality for text values.""" @@ -42,7 +45,7 @@ def matches(self, value: str) -> bool: return isinstance(value, str) and value in self.allowed_values def __repr__(self): - return f"{self.__class__.__name__}({self.allowed_values})" + return f"Modality.{self.__class__.__name__}(allowed_values={self.allowed_values})" class Modality: diff --git a/nemo/collections/common/tokenizers/__init__.py b/nemo/collections/common/tokenizers/__init__.py index 6a71920bf6d4..4ba946cf9f76 100644 --- a/nemo/collections/common/tokenizers/__init__.py +++ b/nemo/collections/common/tokenizers/__init__.py @@ -19,6 +19,7 @@ from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer from nemo.collections.common.tokenizers.regex_tokenizer import RegExTokenizer from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer +from nemo.collections.common.tokenizers.tiktoken_tokenizer import TiktokenTokenizer from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec from nemo.collections.common.tokenizers.word_tokenizer import WordTokenizer diff --git a/nemo/collections/common/tokenizers/text_to_speech/ipa_lexicon.py b/nemo/collections/common/tokenizers/text_to_speech/ipa_lexicon.py index f4081735eb71..907c308e1ddc 100644 --- a/nemo/collections/common/tokenizers/text_to_speech/ipa_lexicon.py +++ b/nemo/collections/common/tokenizers/text_to_speech/ipa_lexicon.py @@ -15,7 +15,7 @@ # fmt: off -SUPPORTED_LOCALES = ["en-US", "de-DE", "es-ES", "it-IT", "fr-FR"] +SUPPORTED_LOCALES = ["en-US", "de-DE", "es-ES", "it-IT", "fr-FR", "vi-VN", "ja-JP"] DEFAULT_PUNCTUATION = ( ',', '.', '!', '?', '-', @@ -48,6 +48,19 @@ 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'Ä', 'Ö', 'Ü', 'ẞ', ), + # ref: https://en.wikipedia.org/wiki/Vietnamese_alphabet + "vi-VN": ( + 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', + 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', + 'U', 'V', 'W', 'X', 'Y', 'Z', 'Đ', 'Á', 'À', 'Ã', + 'Ả', 'Ạ', 'Ă', 'Ắ', 'Ằ', 'Ẵ', 'Ẳ', 'Ặ', 'Â', 'Ấ', + 'Ầ', 'Ẫ', 'Ẩ', 'Ậ', 'Ó', 'Ò', 'Õ', 'Ỏ', 'Ọ', 'Ô', + 'Ố', 'Ồ', 'Ỗ', 'Ổ', 'Ộ', 'Ơ', 'Ớ', 'Ờ', 'Ỡ', 'Ở', + 'Ợ', 'É', 'È', 'Ẽ', 'Ẻ', 'Ẹ', 'Ê', 'Ế', 'Ề', 'Ễ', + 'Ể', 'Ệ', 'Ú', 'Ù', 'Ũ', 'Ủ', 'Ụ', 'Ư', 'Ứ', 'Ừ', + 'Ữ', 'Ử', 'Ự', 'Í', 'Ì', 'Ĩ', 'Ỉ', 'Ị', 'Ý', 'Ỳ', + 'Ỹ', 'Ỷ', 'Ỵ', + ), "fr-FR": ( 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', @@ -104,6 +117,29 @@ 
'ɽ','ʂ','ʈ','ʧ','ʉ','ʋ','ⱱ','ɤ','ʍ','χ','ʏ','ʑ','ʐ', 'ʔ','ʡ','ʕ','ʢ','ǀ','ǁ','ǂ','ᵻ', 'ʃ','ː', ), + "vi-VN": ( + 'a', 'ə', 'ɛ', 'e', 'i', 'o', 'ɔ', 'u', 'ɨ', + 'b', 'c', 'z', 'j', 'd', 'g', 'h', 'x', 'l', + 'm', 'n', 'ŋ', 'ɲ', 'p', 'f', 'w', 'r', 's', + 'ʃ', 't', 'ʈ', 'ʂ', 'v', 'ʔ', 'ɓ', 'ɗ', 'ɣ', + 'k', 'ʰ', 'ʷ', 'ɕ', 'ʑ', 'ʝ', '̚', '̟', 't͡', + '˧', 'ː', 'ɯ', '̀', '̄', '̌', '̂', 'ˀ', '͡', '˥', + '˩', '̤', '˨', 'ɹ', 'ʲ', '̯', 'ă', 'ə̆', 'ǐ', + '˦', 'æ', 'ɐ', + 'ɜ', 'ɡ', 'ɪ', 'ɬ' 'ɾ', 'ʊ', 'ʌ', 'ʒ', '̃', + '̩', 'θ', 'ᵻ', + ), + "ja-JP": ( + 'a', 'i', 'u', 'e', 'o', 'ɯ', 'I', 'ɑ' , 'ɨ ', 'ɒ', + 'ɔ', 'iᵑ', 'eᵑ', 'a', 'ʊ', 'ə', 'eᵝ', 'ɐ', 'ɛ', + 'w', 'k', 'ɾ', 's', 't', 'ʃ', 'r', 'h', 'n', 'nʲ', + 'ɲ', 'ç', 'b', 'm', 'j', 'ɸ', 'z', 'p', 'd', 'N', + 'ʒ', 'ŋ', 'g', 'f', 'ʔ', 'y', 'ɟ', 'v', 'ɥ', 'ɰ', + 'ɰᵝ', 'ɣ', 'ʄ', 'ʑ', 'c', 'ɕ', 'ɠ', 'x', 'l', 'β', + 'ð', 'ø', 'ʁ', 'ts', 'tʃ', 'dʒ', 'y', 'dʑ', 't͡s', + 'ɑ̃', 'ĩ', 'ũ', 'ẽ', 'õ', 'ɑ̃', 'ĩ', 'ũ', 'w̃', + 'ẽ', 'õ', 'hʲ', 'ɪ', 'ː', 'o̞', 'e̞', + ), } GRAPHEME_CHARACTER_CASES = ["upper", "lower", "mixed"] @@ -157,7 +193,7 @@ def get_ipa_punctuation_list(locale): punct_set = set(DEFAULT_PUNCTUATION) # TODO @xueyang: verify potential mismatches with locale-specific punctuation sets used # in nemo_text_processing.text_normalization.en.taggers.punctuation.py - if locale in ["de-DE", "es-ES", "it-IT", "fr-FR"]: + if locale in ["de-DE", "es-ES", "it-IT", "fr-FR", "ja-JP"]: # ref: https://en.wikipedia.org/wiki/Guillemet#Uses punct_set.update(['«', '»', '‹', '›']) if locale == "de-DE": @@ -218,6 +254,48 @@ def get_ipa_punctuation_list(locale): '̧', # combining cedilla, U+0327, decimal 807 ] ) - + elif locale == "ja-JP": + # ref: https://en.wikipedia.org/wiki/List_of_Japanese_typographic_symbols + punct_set.update( + [ + '【', + '】', + '…', + '‥', + '「', + '」', + '『', + '』', + '〜', + '。', + '、', + 'ー', + '・・・', + '〃', + '〔', + '〕', + '⦅', + '⦆', + '〈', + '〉', + '《', + '》', + '〖', + '〗', + '〘', + '〙', + '〚', + '〛', + '•', + '◦', + '﹅', + '﹆', + '※', + '*', + '〽', + '〓', + '〒', + ] + ) punct_list = sorted(list(punct_set)) return punct_list diff --git a/nemo/collections/common/tokenizers/text_to_speech/tokenizer_utils.py b/nemo/collections/common/tokenizers/text_to_speech/tokenizer_utils.py index 542b18186846..b92210b20288 100644 --- a/nemo/collections/common/tokenizers/text_to_speech/tokenizer_utils.py +++ b/nemo/collections/common/tokenizers/text_to_speech/tokenizer_utils.py @@ -24,11 +24,13 @@ "english_text_preprocessing", "any_locale_text_preprocessing", "spanish_text_preprocessing", + "vietnamese_text_preprocessing", "italian_text_preprocessing", "any_locale_word_tokenize", "english_word_tokenize", "LATIN_CHARS_ALL", "normalize_unicode_text", + "japanese_text_preprocessing", ] # Derived from LJSpeech @@ -201,3 +203,11 @@ def chinese_text_preprocessing(text: str) -> str: def french_text_preprocessing(text: str) -> str: return text.lower() + + +def vietnamese_text_preprocessing(text: str) -> str: + return text.lower() + + +def japanese_text_preprocessing(text: str) -> str: + return text.lower() diff --git a/nemo/collections/common/tokenizers/text_to_speech/tts_tokenizers.py b/nemo/collections/common/tokenizers/text_to_speech/tts_tokenizers.py index 1aefc6f1b4bb..943ad78a342a 100644 --- a/nemo/collections/common/tokenizers/text_to_speech/tts_tokenizers.py +++ b/nemo/collections/common/tokenizers/text_to_speech/tts_tokenizers.py @@ -30,7 +30,9 @@ english_text_preprocessing, french_text_preprocessing, italian_text_preprocessing, + japanese_text_preprocessing, 
spanish_text_preprocessing, + vietnamese_text_preprocessing, ) from nemo.utils import logging from nemo.utils.decorators import experimental @@ -202,6 +204,43 @@ def __init__( ) +class VietnameseCharsTokenizer(BaseCharsTokenizer): + + _LOCALE = "vi-VN" + _CHARSET_STR = get_grapheme_character_set(locale=_LOCALE, case="mixed") + + def __init__( + self, + chars=_CHARSET_STR, + punct=True, + apostrophe=True, + add_blank_at=None, + pad_with_space=False, + non_default_punct_list=None, + text_preprocessing_func=vietnamese_text_preprocessing, + ): + """Vietnamese grapheme tokenizer. + Args: + punct: Whether to reserve grapheme for basic punctuation or not. + apostrophe: Whether to use apostrophe or not. + add_blank_at: Add blank to labels in the specified order ("last") or after tokens (any non None), + if None then no blank in labels. + pad_with_space: Whether to pad text with spaces at the beginning and at the end or not. + non_default_punct_list: List of punctuation marks which will be used instead default. + text_preprocessing_func: Text preprocessing function for correct execution of the tokenizer. By default, it + would keep any word lowercase. + """ + super().__init__( + chars=chars, + punct=punct, + apostrophe=apostrophe, + add_blank_at=add_blank_at, + pad_with_space=pad_with_space, + non_default_punct_list=non_default_punct_list, + text_preprocessing_func=vietnamese_text_preprocessing, + ) + + class GermanCharsTokenizer(BaseCharsTokenizer): _LOCALE = "de-DE" @@ -245,7 +284,12 @@ class SpanishCharsTokenizer(BaseCharsTokenizer): PUNCT_LIST = get_ipa_punctuation_list("es-ES") def __init__( - self, punct=True, apostrophe=True, add_blank_at=None, pad_with_space=False, non_default_punct_list=None, + self, + punct=True, + apostrophe=True, + add_blank_at=None, + pad_with_space=False, + non_default_punct_list=None, ): """Spanish grapheme tokenizer. Args: @@ -274,7 +318,12 @@ class FrenchCharsTokenizer(BaseCharsTokenizer): PUNCT_LIST = get_ipa_punctuation_list("fr-FR") def __init__( - self, punct=True, apostrophe=True, add_blank_at=None, pad_with_space=False, non_default_punct_list=None, + self, + punct=True, + apostrophe=True, + add_blank_at=None, + pad_with_space=False, + non_default_punct_list=None, ): """French grapheme tokenizer. Args: @@ -916,3 +965,112 @@ def encode_from_g2p(self, g2p_text: List[str], raw_text: Optional[str] = None): ps = [space] + ps + [space] return [self._token2id[p] for p in ps] + + +class JapanesePhonemeTokenizer(BaseTokenizer): + + JA_PUNCT_LIST = get_ipa_punctuation_list("ja-JP") + + def __init__( + self, + g2p, + punct=True, + non_default_punct_list=None, + *, + space=' ', + silence=None, + apostrophe=True, + sep='|', # To be able to distinguish between 2/3 letters codes. + add_blank_at=None, + pad_with_space=False, + text_preprocessing_func=japanese_text_preprocessing, + ): + """Japanese phoneme-based tokenizer. + Note: This tokenizer for now covers Japanese phonemes + Args: + g2p: Grapheme to phoneme module. + punct: Whether to reserve grapheme for basic punctuation or not. + non_default_punct_list: List of punctuation marks which will be used instead default. + space: Space token as string. + silence: Silence token as string (will be disabled if it is None). + apostrophe: Whether to use apostrophe or not. + sep: Separation token as string. + add_blank_at: Add blank to labels in the specified order ("last") or after tokens (any non None), + if None then no blank in labels. + pad_with_space: Whether to pad text with spaces at the beginning and at the end or not. 
+ text_preprocessing_func: Text preprocessing function for correct execution of the tokenizer. + Basically, it replaces all non-unicode characters with unicode ones. + Note that lower() function shouldn't be applied here, in case the text contains phonemes (it will be handled by g2p). + """ + tokens = [] + self.space, tokens = len(tokens), tokens + [space] # Space + + if silence is not None: + self.silence, tokens = len(tokens), tokens + [silence] # Silence + + self.phoneme_list = g2p.phoneme_list + self.ascii_letter_list = g2p.ascii_letter_list + + tokens.extend(self.phoneme_list) + tokens.extend(self.ascii_letter_list) + + self.text_preprocessing_func = text_preprocessing_func + + if apostrophe: + tokens.append("'") # Apostrophe + + if punct: + if non_default_punct_list is not None: + self.PUNCT_LIST = non_default_punct_list + else: + self.PUNCT_LIST = list(self.JA_PUNCT_LIST) + tokens.extend(self.PUNCT_LIST) + + super().__init__(tokens, sep=sep, add_blank_at=add_blank_at) + + self.punct = punct + self.pad_with_space = pad_with_space + self.g2p = g2p + + def encode(self, text: str) -> List[int]: + """See base class for more information.""" + text = self.text_preprocessing_func(text) + g2p_text = self.g2p(text) + return self.encode_from_g2p(g2p_text, text) + + def encode_from_g2p(self, g2p_text: List[str], raw_text: Optional[str] = None): + """ + Encodes text that has already been run through G2P. + Called for encoding to tokens after text preprocessing and G2P. + + Args: + g2p_text: G2P's output, could be a mixture of Chinese phonemes and English letters. + raw_text: original raw input + """ + ps, space, tokens = [], self.tokens[self.space], set(self.tokens) + for p in g2p_text: # noqa + # Add space if last one isn't one + if p == space and len(ps) > 0 and ps[-1] != space: + ps.append(p) + # Add next phoneme or tone or ascii letter or apostrophe. + elif (p.isalnum() or p == "'" or p in self.phoneme_list + self.ascii_letter_list) and p in tokens: + ps.append(p) + # Add punctuation + elif (p in self.PUNCT_LIST) and self.punct: + ps.append(p) + # Warn about unknown char/phoneme + elif p != space: + message = f"Text: [{' '.join(g2p_text)}] contains unknown char/phoneme: [{p}]." + if raw_text is not None: + message += f"Original text: [{raw_text}]. Symbol will be skipped." + logging.warning(message) + + # Remove trailing spaces + if ps: + while ps[-1] == space: + ps.pop() + + if self.pad_with_space: + ps = [space] + ps + [space] + + return [self._token2id[p] for p in ps] diff --git a/nemo/collections/common/tokenizers/tiktoken_tokenizer.py b/nemo/collections/common/tokenizers/tiktoken_tokenizer.py new file mode 100644 index 000000000000..4b1847051cdc --- /dev/null +++ b/nemo/collections/common/tokenizers/tiktoken_tokenizer.py @@ -0,0 +1,200 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
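As a reading aid (editor's note, not part of the PR): the vocabulary JSON consumed by reload_mergeable_ranks below is expected to be a rank-ordered list of entries with exactly the keys "rank", "token_bytes" (base64-encoded token) and "token_str", for example:

# vocab.json (illustrative structure only):
# [
#   {"rank": 0, "token_bytes": "AA==", "token_str": "\u0000"},
#   {"rank": 1, "token_bytes": "AQ==", "token_str": "\u0001"},
#   ...
# ]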
+ +import base64 +import json +import os +from pathlib import Path +from typing import Dict, List, Optional + +try: + import tiktoken +except ImportError: + pass + +from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec + +__all__ = ['TiktokenTokenizer'] + + +def reload_mergeable_ranks( + path: str, + max_vocab: Optional[int] = None, +) -> Dict[bytes, int]: + """ + Reload the tokenizer JSON file and convert it to Tiktoken format. + """ + assert path.endswith(".json") + + # reload vocab + with open(path, "r") as f: + vocab = json.load(f) + assert isinstance(vocab, list) + print(f"Vocab size: {len(vocab)}") + if max_vocab is not None: + vocab = vocab[:max_vocab] + print(f"Cutting vocab to first {len(vocab)} tokens.") + + # build ranks + ranks: Dict[bytes, int] = {} + for i, x in enumerate(vocab): + assert x.keys() == {"rank", "token_bytes", "token_str"} + assert x["rank"] == i + merge = base64.b64decode(x["token_bytes"]) + assert i >= 256 or merge == bytes([i]) + ranks[merge] = x["rank"] + + # sanity check + assert len(ranks) == len(vocab) + assert set(ranks.values()) == set(range(len(ranks))) + + return ranks + + +PATTERN_TIKTOKEN = "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" +DEFAULT_TIKTOKEN_MAX_VOCAB = 2**17 # 131072 +SPECIAL_TOKENS = ["", "", ""] +SPECIAL_TOKEN_TEMPLATE = "" + + +class TiktokenTokenizer(TokenizerSpec): + """ + TiktokenTokenizer https://github.com/openai/tiktoken. + + Args: + model_path: path to tokenizer vocabulary + num_special_tokens: number of special tokens to generate + special_tokens: template for user-defined special tokens + pattern: Regex pattern to split the text + """ + + def __init__( + self, + vocab_file: str, + pattern: str = PATTERN_TIKTOKEN, + vocab_size: int = DEFAULT_TIKTOKEN_MAX_VOCAB, # 131072 + num_special_tokens: int = 1000, + special_tokens: Optional[List[str]] = None, + ): + if not vocab_file or not os.path.exists(vocab_file): + raise ValueError(f"vocab_file: {vocab_file} is invalid") + + if special_tokens is None: + special_tokens = SPECIAL_TOKENS.copy() + + assert len(special_tokens) == len(set(special_tokens)), f"Special tokens should be unique: {special_tokens}" + assert len(special_tokens) <= num_special_tokens < vocab_size + assert set(SPECIAL_TOKENS) <= set(special_tokens), f"Custom special tokens should include {SPECIAL_TOKENS}" + + self._unk_id = special_tokens.index("") + self._bos_id = special_tokens.index("") + self._eos_id = special_tokens.index("") + + self._vocab_size = vocab_size + print(f'{self._vocab_size = }') + self.num_special_tokens = num_special_tokens + special_filler = [SPECIAL_TOKEN_TEMPLATE.format(id=i) for i in range(len(special_tokens), num_special_tokens)] + if special_filler: + print(f"Adding special tokens {special_filler[0]}, ..., {special_filler[-1]}") + self.special_tokens = special_tokens + special_filler + assert len(set(self.special_tokens)) == len(self.special_tokens) == num_special_tokens, self.special_tokens + self.inner_vocab_size = vocab_size - num_special_tokens + + # reload vocab + self.token2id = reload_mergeable_ranks(vocab_file, max_vocab=self.inner_vocab_size) + self.id2token = {v: k for k, v in self.token2id.items()} + assert set(range(self.inner_vocab_size)) == set(self.id2token.keys()) + + self.shifted_id2token = {i: tok for i, tok in enumerate(self.special_tokens)} + 
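        # Ids 0..num_special_tokens-1 are reserved for the special tokens above; the inner tiktoken vocab is
        # shifted up by num_special_tokens (e.g. with the default 1000 special tokens, inner rank 0 becomes
        # id 1000), mirroring the offsets applied in text_to_ids() and ids_to_text() below.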
for key, value in self.id2token.items(): + self.shifted_id2token[key + self.num_special_tokens] = value + + self.tokenizer = tiktoken.Encoding( + name=Path(vocab_file).parent.name, + pat_str=pattern, + mergeable_ranks=self.token2id, + special_tokens={}, # special tokens are handled manually + ) + + def text_to_tokens(self, text: str): + token_ids = self.tokenizer.encode(text) + return [self.tokenizer.decode_single_token_bytes(token) for token in token_ids] + + def tokens_to_text(self, tokens: List[int]): + token_ids = [self.tokenizer.encode_single_token(tokens) for tokens in tokens] + return self.tokenizer.decode(token_ids) + + def token_to_id(self, token): + return self.tokenizer.encode_single_token(token) + + def tokens_to_ids(self, tokens): + return [self.tokenizer.encode_single_token(token) for token in tokens] + + def ids_to_tokens(self, token_ids): + tokens = [] + for token_id in token_ids: + if token_id < self.num_special_tokens: + tokens.append(self.special_tokens[token_id]) + else: + token_id -= self.num_special_tokens + token_bytes = self.tokenizer.decode_single_token_bytes(token_id) + tokens.append(token_bytes.decode('utf-8', errors='replace')) + return tokens + + def text_to_ids(self, text: str): + tokens = self.tokenizer.encode(text) + tokens = [t + self.num_special_tokens for t in tokens] + return tokens + + def ids_to_text(self, tokens: List[int]): + # Filter out special tokens and adjust the remaining tokens + adjusted_tokens = [ + t - self.num_special_tokens + for t in tokens + if t not in {self.bos, self.eos} and t >= self.num_special_tokens + ] + + # Decode only if there are tokens left after filtering + if adjusted_tokens: + return self.tokenizer.decode(adjusted_tokens) + else: + return "" # Return an empty string if all tokens were filtered out + + @property + def bos_id(self): + return self._bos_id + + @property + def eos_id(self): + return self._eos_id + + @property + def unk_id(self): + return self._unk_id + + @property + def vocab(self): + return self.token2id + + @property + def decoder(self): + return self.shifted_id2token + + @property + def encoder(self): + return self.vocab + + @property + def vocab_size(self) -> int: + return self._vocab_size diff --git a/nemo/collections/llm/__init__.py b/nemo/collections/llm/__init__.py index 83c0a3af48c0..3ef8f6dd7fe4 100644 --- a/nemo/collections/llm/__init__.py +++ b/nemo/collections/llm/__init__.py @@ -38,28 +38,12 @@ MistralConfig7B, MistralModel, MixtralConfig8x7B, + MixtralConfig8x22B, MixtralModel, gpt_data_step, gpt_forward_step, ) -from nemo.collections.llm.gpt.model.api import ( - code_gemma_2b, - code_gemma_7b, - code_llama_7b, - code_llama_13b, - code_llama_34b, - code_llama_70b, - gemma, - gemma_2b, - gemma_7b, - llama2_7b, - llama2_13b, - llama2_70b, - llama3_8b, - llama3_70b, - mistral, - mixtral, -) +from nemo.collections.llm.recipes import * # noqa __all__ = [ "MockDataModule", @@ -71,6 +55,7 @@ "MistralConfig7B", "MistralModel", "MixtralConfig8x7B", + "MixtralConfig8x22B", "MixtralModel", "LlamaConfig", "Llama2Config7B", @@ -103,21 +88,5 @@ "mock", "squad", "dolly", - "mistral", - "mixtral", - "llama2_7b", - "llama3_8b", - "llama2_13b", - "llama2_70b", - "llama3_70b", - "code_llama_7b", - "code_llama_13b", - "code_llama_34b", - "code_llama_70b", - "gemma", - "gemma_2b", - "gemma_7b", - "code_gemma_2b", - "code_gemma_7b", "peft", ] diff --git a/nemo/collections/llm/api.py b/nemo/collections/llm/api.py index 5c9703497597..56da9e5496b2 100644 --- a/nemo/collections/llm/api.py +++ 
b/nemo/collections/llm/api.py @@ -279,7 +279,7 @@ def _setup( model_transform: Optional[Union[PEFT, ModelTransform, Callable]], ) -> Any: # Return type is Any because app_state's type is not specified _log = log or NeMoLogger() - if resume and resume.adapter_path and _log.ckpt: + if resume and isinstance(model_transform, PEFT) and _log.ckpt: logging.info("Disabling try_restore_best_ckpt restoration for adapters") _log.ckpt.try_restore_best_ckpt = False @@ -289,7 +289,7 @@ def _setup( task_config=getattr(train, "__io__", None), ) if resume is not None: - resume.setup(model, trainer) + resume.setup(trainer, model) if optim: optim.connect(model) diff --git a/nemo/collections/llm/fn/mixin.py b/nemo/collections/llm/fn/mixin.py index b32f66366bfb..c566c6e9d392 100644 --- a/nemo/collections/llm/fn/mixin.py +++ b/nemo/collections/llm/fn/mixin.py @@ -2,6 +2,7 @@ from typing_extensions import Self from nemo.collections.llm.fn import base as fn +from nemo.utils import logging class FNMixin: @@ -114,8 +115,12 @@ def freeze(self) -> None: """ assert isinstance(self, nn.Module), "self is not a nn.Module" - for param in self.parameters(): - param.requires_grad = False + params = list(self.parameters()) + if not params: + logging.info(f"No parameters found in module {self.__class__.__name__}") + else: + for param in params: + param.requires_grad = False def unfreeze(self) -> None: """ @@ -124,5 +129,9 @@ def unfreeze(self) -> None: """ assert isinstance(self, nn.Module), "self is not a nn.Module" - for param in self.parameters(): - param.requires_grad = True + params = list(self.parameters()) + if not params: + logging.info(f"No parameters found in module {self.__class__.__name__}") + else: + for param in params: + param.requires_grad = True diff --git a/nemo/collections/llm/gpt/data/core.py b/nemo/collections/llm/gpt/data/core.py index 8d99583016a4..6f8fe237e10a 100644 --- a/nemo/collections/llm/gpt/data/core.py +++ b/nemo/collections/llm/gpt/data/core.py @@ -32,6 +32,7 @@ def create_sft_dataset( truncation_method: str = 'right', memmap_workers: int = 2, hf_dataset: bool = False, + global_sample_mapping: bool = False, **kwargs, ) -> "GPTSFTDataset": from nemo.collections.nlp.data.language_modeling.megatron.gpt_sft_dataset import GPTSFTDataset @@ -42,6 +43,7 @@ def create_sft_dataset( max_seq_length=seq_length, memmap_workers=memmap_workers, hf_dataset=hf_dataset, + global_sample_mapping=global_sample_mapping, add_bos=add_bos, add_eos=add_eos, add_sep=add_sep, diff --git a/nemo/collections/llm/gpt/data/dolly.py b/nemo/collections/llm/gpt/data/dolly.py index 9632a142eb35..7ed17e460e0f 100644 --- a/nemo/collections/llm/gpt/data/dolly.py +++ b/nemo/collections/llm/gpt/data/dolly.py @@ -7,13 +7,14 @@ from nemo.collections.llm.gpt.data.core import get_dataset_root from nemo.collections.llm.gpt.data.fine_tuning import FineTuningDataModule +from nemo.lightning.io.mixin import IOMixin from nemo.utils import logging if TYPE_CHECKING: from nemo.collections.common.tokenizers import TokenizerSpec -class DollyDataModule(FineTuningDataModule): +class DollyDataModule(FineTuningDataModule, IOMixin): """A data module for fine-tuning on the Dolly dataset. 
 This class inherits from the `FineTuningDataModule` class and is specifically designed for fine-tuning models on the
diff --git a/nemo/collections/llm/gpt/data/pre_training.py b/nemo/collections/llm/gpt/data/pre_training.py
index 46b407410d31..6622054c4f98 100644
--- a/nemo/collections/llm/gpt/data/pre_training.py
+++ b/nemo/collections/llm/gpt/data/pre_training.py
@@ -1,11 +1,14 @@
+import logging
+import warnings
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Dict, List, Optional
 
 import pytorch_lightning as pl
 from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
 from torch.utils import data
-from torch.utils.data import DataLoader
 
+from nemo.lightning.data import WrappedDataLoader
+from nemo.lightning.io.mixin import IOMixin
 from nemo.lightning.pytorch.plugins import MegatronDataSampler
 
 if TYPE_CHECKING:
@@ -14,19 +17,53 @@
     from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
 
 
-class PreTrainingDataModule(pl.LightningDataModule):
+class PreTrainingDataModule(pl.LightningDataModule, IOMixin):
+    """PyTorch Lightning-compatible data module for pre-training GPT-style models.
+
+    Args:
+        paths (Path | List | Dict[str, List]): Paths of the data distributions. Can be either a
+            single path, a list of paths, or a dictionary. If a single path or a list of paths,
+            the given paths will be used to generate the train, validation and test datasets. If
+            providing a list of paths, the format can be either (1) a list of paths, e.g.
+                ["path/to/dataset_1_prefix", "path/to/dataset_2_prefix"],
+            or (2) a flattened, zipped list of weights and paths, e.g.
+                ["30", "path/to/dataset_1_prefix", "70", "path/to/dataset_2_prefix"]
+            If a dictionary is provided, it is expected to have the following form:
+                {
+                    'train': <train paths>,
+                    'validation': <validation paths>,
+                    'test': <test paths>
+                }
+            where each value is either a path or a list of paths as described above.
+            In this case, each split will be generated using the given paths.
+            Note that if limit_val_batches <= 1, we generate the entire validation dataset, so
+            weights should not be provided for the validation split.
+        seq_length (int): Sequence length.
+        tokenizer (Optional["TokenizerSpec"]): An instance of a TokenizerSpec object.
+        micro_batch_size (int): Batch size per GPU.
+        global_batch_size (int): Global batch size.
+        rampup_batch_size (Optional[List[int]]): Rampup batch size, should be in the format of
+            [start_global_batch_size, batch_size_increment, rampup_samples].
+        num_workers (int): See ``torch.utils.data.DataLoader`` documentation.
+        pin_memory (bool): See ``torch.utils.data.DataLoader`` documentation.
+        persistent_workers (bool): See ``torch.utils.data.DataLoader`` documentation.
+        reset_position_ids (bool): Option to reset the position IDs in the dataset at an interval.
+        reset_attention_mask (bool): Option to reset the attention mask from the dataset.
+        eod_mask_loss (int): Option to enable the EOD mask loss.
+        seed (int): Seed for generating the GPT dataset.
+        split (str): A string of 3 comma-separated integers denoting how much of the distribution
+            to allocate to train, validation, and test sets, respectively. Unused if ``paths`` is a dict.
+        index_mapping_dir (Optional[str]): Path to a directory to write index mapping files.
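An illustrative sketch (editor's addition, not part of the PR) of the three accepted forms of ``paths``; the dataset prefixes below are placeholders for Megatron .bin/.idx prefixes.

# (1) Single distribution, split into train/validation/test via `split`:
data = PreTrainingDataModule(paths="/data/my_corpus_text_document", split="99,1,0", seq_length=2048)

# (2) Blended distributions with explicit weights (flattened [weight, prefix, ...] list):
data = PreTrainingDataModule(paths=["30", "/data/wiki_text_document", "70", "/data/web_text_document"])

# (3) Separate distributions per split (`split` is ignored with a warning in this case):
data = PreTrainingDataModule(
    paths={
        "train": ["/data/train_text_document"],
        "validation": ["/data/val_text_document"],
        "test": ["/data/test_text_document"],
    }
)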
+ """ + def __init__( self, - paths: Path | List[Path], - weights: Optional[List[float]] = None, + paths: Path | List | Dict[str, List], seq_length: int = 2048, tokenizer: Optional["TokenizerSpec"] = None, micro_batch_size: int = 4, global_batch_size: int = 8, rampup_batch_size: Optional[List[int]] = None, - num_train_samples: int = 10_000, - num_val_samples: int = 10_000, - num_test_samples: int = 10_000, num_workers: int = 8, pin_memory: bool = True, persistent_workers: bool = False, @@ -38,21 +75,32 @@ def __init__( index_mapping_dir: Optional[str] = None, ) -> None: super().__init__() - if not isinstance(paths, (list, tuple)): + if not isinstance(paths, (list, tuple, dict)): paths = [paths] - if weights is not None: - assert len(weights) == len(paths) - if len(weights) == 1: - # weights must be None if there is only one dataset + + from megatron.core.datasets.utils import get_blend_from_list + + build_kwargs = {} + if isinstance(paths, dict): + if split is not None: + warnings.warn( + f"{split=} will be ignored since datasets are being created " f"from 3 separate distributions." + ) + build_kwargs["blend_per_split"] = [ + get_blend_from_list(paths["train"]), + get_blend_from_list(paths["validation"]), + get_blend_from_list(paths["test"]), + ] + else: + paths, weights = get_blend_from_list(paths) + if len(paths) == 1: weights = None + build_kwargs["blend"] = [paths, weights] + build_kwargs["split"] = split - self.paths = paths - self.weights = weights + self.build_kwargs = build_kwargs self.seq_length = seq_length self.tokenizer = tokenizer - self.num_train_samples = num_train_samples - self.num_val_samples = num_val_samples - self.num_test_samples = num_test_samples self.num_workers = num_workers self.pin_memory = pin_memory self.persistent_workers = persistent_workers @@ -92,8 +140,19 @@ def setup(self, stage: str = "") -> None: num_test_samples = int(test_iters * self.data_sampler.global_batch_size) if self.trainer.limit_val_batches <= 1.0 and isinstance(self.trainer.limit_val_batches, float): + assert "blend" not in self.build_kwargs, ( + "When using a single data distribution, limit_val_batches <= 1.0 is not supported. 
If you'd " + "like to run with a fractional value of limit_val_batches, please pass in separate datasets for " + "the train, validation, and test datasets by providing a dictionary of paths, e.g.: \n" + " paths={ \n " + " 'train': [PATHS FOR TRAIN], \n " + " 'validation': [PATHS FOR VALIDATION], \n " + " 'test' :[PATHS FOR TEST], \n" + " }" + ) + # This is to make sure we only have one epoch on every validation iteration - num_val_samples = None if self.weights is None else 1 + num_val_samples = None train_valid_test_num_samples = [num_train_samples, num_val_samples, num_test_samples] self._train_ds, self._validation_ds, self._test_ds = BlendedMegatronDatasetBuilder( @@ -121,39 +180,40 @@ def setup(self, stage: str = "") -> None: # ).build() def train_dataloader(self) -> TRAIN_DATALOADERS: - return self._create_dataloader(self._train_ds) + return self._create_dataloader(self._train_ds, mode='train') def val_dataloader(self) -> EVAL_DATALOADERS: - return self._create_dataloader(self._validation_ds) + return self._create_dataloader(self._validation_ds, mode='validation') def test_dataloader(self) -> EVAL_DATALOADERS: - return self._create_dataloader(self._test_ds) + return self._create_dataloader(self._test_ds, mode='test') - def _create_dataloader(self, dataset, **kwargs) -> DataLoader: + def _create_dataloader(self, dataset, mode, **kwargs) -> WrappedDataLoader: self.init_global_step = self.trainer.global_step - return DataLoader( - dataset, + dataloader = WrappedDataLoader( + mode=mode, + dataset=dataset, num_workers=self.num_workers, pin_memory=self.pin_memory, persistent_workers=self.persistent_workers, collate_fn=getattr(dataset, 'collate_fn', data.dataloader.default_collate), **kwargs, ) + return dataloader @property def gpt_dataset_config(self) -> "GPTDatasetConfig": from megatron.core.datasets.gpt_dataset import GPTDatasetConfig return GPTDatasetConfig( - blend=[[str(path) for path in self.paths], self.weights], random_seed=self.seed, sequence_length=self.seq_length, tokenizer=self.tokenizer, - split=self.split, path_to_cache=self.index_mapping_dir, reset_position_ids=self.reset_position_ids, reset_attention_mask=self.reset_attention_mask, eod_mask_loss=self.eod_mask_loss, + **self.build_kwargs, ) def state_dict(self) -> Dict[str, Any]: @@ -174,24 +234,72 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: """ try: - from apex.transformer.pipeline_parallel.utils import _GLOBAL_NUM_MICROBATCHES_CALCULATOR - except ModuleNotFoundError: - from nemo.lightning.apex_utils import _GLOBAL_NUM_MICROBATCHES_CALCULATOR + from megatron.core.num_microbatches_calculator import update_num_microbatches + + except (ImportError, ModuleNotFoundError): + logging.warning("Megatron num_microbatches_calculator not found, using Apex version.") + from apex.transformer.pipeline_parallel.utils import update_num_microbatches + consumed_samples = state_dict['consumed_samples'] self.data_sampler.init_consumed_samples = consumed_samples self.data_sampler.prev_consumed_samples = consumed_samples - num_microbatch_calculator = _GLOBAL_NUM_MICROBATCHES_CALCULATOR # noqa: SLF001 - num_microbatch_calculator.update( + update_num_microbatches( consumed_samples=consumed_samples, consistency_check=False, ) - current_global_batch_size = num_microbatch_calculator.current_global_batch_size - '''pl_module.log( - "global_batch_size", - current_global_batch_size, - prog_bar=True, - rank_zero_only=True, - batch_size=1, - )''' - self.if_first_step = 1 + self.data_sampler.if_first_step = 1 + + def 
reconfigure_limit_batches(self): + # Override limit_train_batches in terms of num of microbatches + self._reconfigure_limit_batches(self.trainer.limit_train_batches, self._train_ds, 'train') + # Override limit_val_batches to be a multiple of num microbatches to prevent val_step from exiting partway through a global batch step + self._reconfigure_limit_batches(self.trainer.limit_val_batches, self._validation_ds, 'val') + + def _reconfigure_limit_batches(self, limit_batches, dataloader, mode): + """ + Reconfigure trainer.limit_train_batches or trainer.limit_val_batches in terms of micro-batches for pretraining + """ + # Override limit_batches in terms of num of microbatches; the trainer counts microbatches, so the resulting limit corresponds to limit_batches // num_micro_batches global batches + try: + from megatron.core.num_microbatches_calculator import get_num_microbatches + + except (ImportError, ModuleNotFoundError): + logging.warning("Megatron num_microbatches_calculator not found, using Apex version.") + from apex.transformer.pipeline_parallel.utils import get_num_microbatches + + if isinstance(limit_batches, int): + limit_batches *= get_num_microbatches() + else: + assert isinstance(limit_batches, float) + # Don't reconfigure if limit_batches is 0.0 or if there's no dataloader + if limit_batches == 0.0 or dataloader is None: + return + # len(dataloader) returns the length in number of microbatches + dl_len_in_micro_batches = len(dataloader) + if len(dataloader) != float("inf"): + if limit_batches == 1.0: + limit_batches = dl_len_in_micro_batches + else: + limit_micro_batches = int(dl_len_in_micro_batches * limit_batches) + if limit_micro_batches == 0 and limit_batches > 0.0: + min_percentage = 1.0 / len(dataloader) + raise MisconfigurationException( + f"You requested to check {limit_batches} of the {mode}_dataloader but" + f" {limit_batches} * {len(dataloader)} < 1. Please increase the" + f" `limit_{mode}_batches` argument. Try at least" + f" `limit_{mode}_batches={min_percentage}`" + ) + # Make sure the resulting limit is a multiple of the number of microbatches + if limit_micro_batches < get_num_microbatches(): + limit_batches = get_num_microbatches() + else: + limit_batches = limit_micro_batches - limit_micro_batches % get_num_microbatches() + + if mode == 'train': + self.trainer.limit_train_batches = limit_batches + else: + self.trainer.limit_val_batches = limit_batches + + # Override num sanity steps to be a multiple of num of microbatches + self.trainer.num_sanity_val_steps *= get_num_microbatches() diff --git a/nemo/collections/llm/gpt/data/squad.py b/nemo/collections/llm/gpt/data/squad.py index 77d48da98a0e..11104fe3cab2 100644 --- a/nemo/collections/llm/gpt/data/squad.py +++ b/nemo/collections/llm/gpt/data/squad.py @@ -6,13 +6,14 @@ from nemo.collections.llm.gpt.data.core import get_dataset_root from nemo.collections.llm.gpt.data.fine_tuning import FineTuningDataModule +from nemo.lightning.io.mixin import IOMixin from nemo.utils import logging if TYPE_CHECKING: from nemo.collections.common.tokenizers import TokenizerSpec -class SquadDataModule(FineTuningDataModule): +class SquadDataModule(FineTuningDataModule, IOMixin): """A data module for fine-tuning on the Squad dataset.
This class inherits from the `FineTuningDataModule` class and is specifically designed for fine-tuning models on the @@ -124,3 +125,6 @@ def _preprocess_and_split_data( shutil.rmtree(p) elif '.jsonl' not in str(p.name): p.unlink() + + def reconfigure_limit_batches(self): + return diff --git a/nemo/collections/llm/gpt/model/__init__.py b/nemo/collections/llm/gpt/model/__init__.py index 4391a41293ee..e63c45ca99cd 100644 --- a/nemo/collections/llm/gpt/model/__init__.py +++ b/nemo/collections/llm/gpt/model/__init__.py @@ -29,7 +29,7 @@ LlamaModel, ) from nemo.collections.llm.gpt.model.mistral import MistralConfig7B, MistralModel -from nemo.collections.llm.gpt.model.mixtral import MixtralConfig8x7B, MixtralModel +from nemo.collections.llm.gpt.model.mixtral import MixtralConfig8x7B, MixtralConfig8x22B, MixtralModel __all__ = [ "GPTConfig", diff --git a/nemo/collections/llm/gpt/model/api.py b/nemo/collections/llm/gpt/model/api.py deleted file mode 100644 index 7c8cbf4d02e6..000000000000 --- a/nemo/collections/llm/gpt/model/api.py +++ /dev/null @@ -1,125 +0,0 @@ -import pytorch_lightning as pl - -from nemo.collections.llm.gpt.model.gemma import ( - CodeGemmaConfig2B, - CodeGemmaConfig7B, - GemmaConfig, - GemmaConfig2B, - GemmaConfig7B, - GemmaModel, -) -from nemo.collections.llm.gpt.model.llama import ( - CodeLlamaConfig7B, - CodeLlamaConfig13B, - CodeLlamaConfig34B, - CodeLlamaConfig70B, - Llama2Config7B, - Llama2Config13B, - Llama2Config70B, - Llama3Config8B, - Llama3Config70B, - LlamaModel, -) -from nemo.collections.llm.gpt.model.mistral import MistralConfig7B, MistralModel -from nemo.collections.llm.gpt.model.mixtral import MixtralConfig8x7B, MixtralModel -from nemo.collections.llm.utils import factory - - -@factory -def mistral() -> pl.LightningModule: - return MistralModel(MistralConfig7B()) - - -@factory -def mixtral() -> pl.LightningModule: - return MixtralModel(MixtralConfig8x7B()) - - -@factory -def llama2_7b() -> pl.LightningModule: - return LlamaModel(Llama2Config7B()) - - -@factory -def llama3_8b() -> pl.LightningModule: - return LlamaModel(Llama3Config8B()) - - -@factory -def llama2_13b() -> pl.LightningModule: - return LlamaModel(Llama2Config13B()) - - -@factory -def llama2_70b() -> pl.LightningModule: - return LlamaModel(Llama2Config70B()) - - -@factory -def llama3_70b() -> pl.LightningModule: - return LlamaModel(Llama3Config70B()) - - -@factory -def code_llama_7b() -> pl.LightningModule: - return LlamaModel(CodeLlamaConfig7B()) - - -@factory -def code_llama_13b() -> pl.LightningModule: - return LlamaModel(CodeLlamaConfig13B()) - - -@factory -def code_llama_34b() -> pl.LightningModule: - return LlamaModel(CodeLlamaConfig34B()) - - -@factory -def code_llama_70b() -> pl.LightningModule: - return LlamaModel(CodeLlamaConfig70B()) - - -@factory -def gemma() -> pl.LightningModule: - return GemmaModel(GemmaConfig()) - - -@factory -def gemma_2b() -> pl.LightningModule: - return GemmaModel(GemmaConfig2B()) - - -@factory -def gemma_7b() -> pl.LightningModule: - return GemmaModel(GemmaConfig7B()) - - -@factory -def code_gemma_2b() -> pl.LightningModule: - return GemmaModel(CodeGemmaConfig2B()) - - -@factory -def code_gemma_7b() -> pl.LightningModule: - return GemmaModel(CodeGemmaConfig7B()) - - -__all__ = [ - "mistral", - "mixtral", - "llama2_7b", - "llama3_8b", - "llama2_13b", - "llama2_70b", - "llama3_70b", - "code_llama_7b", - "code_llama_13b", - "code_llama_34b", - "code_llama_70b", - "gemma", - "gemma_2b", - "gemma_7b", - "code_gemma_2b", - "code_gemma_7b", -] diff --git 
a/nemo/collections/llm/gpt/model/base.py b/nemo/collections/llm/gpt/model/base.py index 4c1f425d7f99..a8339e124564 100644 --- a/nemo/collections/llm/gpt/model/base.py +++ b/nemo/collections/llm/gpt/model/base.py @@ -4,7 +4,6 @@ import pytorch_lightning as L import torch import torch.distributed -from megatron.core.models.gpt import gpt_layer_specs from megatron.core.optimizer import OptimizerConfig from megatron.core.transformer.spec_utils import ModuleSpec from megatron.core.transformer.transformer_config import TransformerConfig @@ -15,6 +14,12 @@ from nemo.lightning.megatron_parallel import MaskedTokenLossReduction from nemo.lightning.pytorch.optim import MegatronOptimizerModule, OptimizerModule +HAVE_TE = True +try: + import transformer_engine +except (ImportError, ModuleNotFoundError): + HAVE_TE = False + if TYPE_CHECKING: from megatron.core.models.gpt.gpt_model import GPTModel as MCoreGPTModel @@ -66,17 +71,28 @@ def gpt_forward_step(model, batch) -> torch.Tensor: def transformer_engine_layer_spec(config: "GPTConfig") -> ModuleSpec: + from megatron.core.models.gpt import gpt_layer_specs + return gpt_layer_specs.get_gpt_layer_with_transformer_engine_spec( num_experts=config.num_moe_experts, moe_grouped_gemm=config.moe_grouped_gemm, qk_layernorm=config.qk_layernorm ) def local_layer_spec(config: "GPTConfig") -> ModuleSpec: + from megatron.core.models.gpt import gpt_layer_specs + return gpt_layer_specs.get_gpt_layer_local_spec( num_experts=config.num_moe_experts, moe_grouped_gemm=config.moe_grouped_gemm, qk_layernorm=config.qk_layernorm ) +def default_layer_spec(config: "GPTConfig") -> ModuleSpec: + if HAVE_TE: + return transformer_engine_layer_spec(config) + else: + return local_layer_spec(config) + + @dataclass class GPTConfig(TransformerConfig, io.IOMixin): # From megatron.core.models.gpt.gpt_model.GPTModel @@ -93,7 +109,7 @@ class GPTConfig(TransformerConfig, io.IOMixin): # TODO: Move this to better places? 
get_attention_mask_from_fusion: bool = False - transformer_layer_spec: Union[ModuleSpec, Callable[["GPTConfig"], ModuleSpec]] = transformer_engine_layer_spec + transformer_layer_spec: Union[ModuleSpec, Callable[["GPTConfig"], ModuleSpec]] = default_layer_spec forward_step_fn: Callable = gpt_forward_step data_step_fn: Callable = gpt_data_step @@ -144,6 +160,8 @@ def __init__( self.optim = optim or MegatronOptimizerModule(config=OptimizerConfig(lr=1e-4, use_distributed_optimizer=True)) self.optim.connect(self) # This will bind the `configure_optimizers` method self.model_transform = model_transform + self._training_loss_reduction = None + self._validation_loss_reduction = None def configure_model(self) -> None: if not hasattr(self, "module"): @@ -184,11 +202,19 @@ def validation_step(self, batch, batch_idx=None) -> torch.Tensor: return self.forward_step(batch) + @property def training_loss_reduction(self) -> MaskedTokenLossReduction: - return MaskedTokenLossReduction() + if not self._training_loss_reduction: + self._training_loss_reduction = MaskedTokenLossReduction() + return self._training_loss_reduction + + @property def validation_loss_reduction(self) -> MaskedTokenLossReduction: - return MaskedTokenLossReduction(validation_step=True) + if not self._validation_loss_reduction: + self._validation_loss_reduction = MaskedTokenLossReduction(validation_step=True) + + return self._validation_loss_reduction def get_batch_on_this_context_parallel_rank(batch): diff --git a/nemo/collections/llm/gpt/model/mistral.py b/nemo/collections/llm/gpt/model/mistral.py index d1049cfe77ce..7e4cf8b6c74e 100644 --- a/nemo/collections/llm/gpt/model/mistral.py +++ b/nemo/collections/llm/gpt/model/mistral.py @@ -111,6 +111,7 @@ def make_vocab_size_divisible_by(mistral_vocab_size): num_layers=source.num_hidden_layers, hidden_size=source.hidden_size, ffn_hidden_size=source.intermediate_size, + kv_channels=source.get('head_dim', source.hidden_size // source.num_attention_heads), num_attention_heads=source.num_attention_heads, # max_position_embeddings=source.max_position_embeddings, init_method_std=source.initializer_range, @@ -183,6 +184,7 @@ def config(self) -> "MistralConfig": num_key_value_heads=source.num_query_groups, rope_theta=source.rotary_base, vocab_size=self.tokenizer.vocab_size, + head_dim=source.kv_channels, ) @@ -202,7 +204,7 @@ def _import_qkv(ctx: io.TransformCTX, q, k, v): heads_per_group = head_num // num_query_groups hidden_size = megatron_config.hidden_size head_num = megatron_config.num_attention_heads - head_size = hidden_size // head_num + head_size = megatron_config.kv_channels old_tensor_shape = q.size() new_q_tensor_shape = (head_num, head_size) + old_tensor_shape[1:] @@ -244,7 +246,7 @@ def _export_qkv(ctx: io.TransformCTX, linear_qkv): heads_per_group = head_num // num_query_groups hidden_size = megatron_config.hidden_size head_num = megatron_config.num_attention_heads - head_size = hidden_size // head_num + head_size = megatron_config.kv_channels qkv_total_dim = head_num + 2 * num_query_groups linear_qkv = linear_qkv.reshape([qkv_total_dim, head_size, hidden_size]) diff --git a/nemo/collections/llm/gpt/model/mixtral.py b/nemo/collections/llm/gpt/model/mixtral.py index 6256b67515ee..868ad5b332b5 100644 --- a/nemo/collections/llm/gpt/model/mixtral.py +++ b/nemo/collections/llm/gpt/model/mixtral.py @@ -1,6 +1,6 @@ from dataclasses import dataclass from pathlib import Path -from typing import TYPE_CHECKING, Callable, Optional +from typing import TYPE_CHECKING, Callable, Optional, Union import 
torch import torch.nn.functional as F @@ -41,18 +41,55 @@ class MixtralConfig8x7B(GPTConfig): # MoE num_moe_experts: int = 8 moe_router_topk: int = 1 + moe_router_pre_softmax: bool = True init_method_std: float = 0.02 layernorm_epsilon: float = 1e-5 # rotary rotary_percent: float = 0.5 rotary_base: float = 10000 + bf16: bool = True + params_dtype: torch.dtype = torch.bfloat16 + + +@dataclass +class MixtralConfig8x22B(GPTConfig): + """ + Config for Mixtral-8x22B model + Official announcement: https://mistral.ai/news/mixtral-8x22b/ + """ + + normalization: str = "RMSNorm" + activation_func: Callable = F.silu + position_embedding_type: str = "rope" + add_bias_linear: bool = False + gated_linear_unit: bool = True + apply_query_key_layer_scaling: bool = False # TODO: Should this be True? + + num_layers: int = 56 + hidden_size: int = 6144 + num_attention_heads: int = 48 + num_query_groups: int = 8 + ffn_hidden_size: int = 16384 + max_position_embeddings: int = 65536 + seq_length: int = 4096 # 65536 + # MoE + num_moe_experts: int = 8 + moe_router_topk: int = 2 + + init_method_std: float = 0.02 + layernorm_epsilon: float = 1e-5 + # rotary + rotary_percent: float = 0 # TODO: @akoumparouli: is this correct? + rotary_base: float = 1000000 + bf16: bool = True + params_dtype: torch.dtype = torch.bfloat16 class MixtralModel(GPTModel): def __init__( self, - config: Optional[MixtralConfig8x7B] = None, + config: Optional[Union[MixtralConfig8x7B, MixtralConfig8x22B]] = None, optim: Optional[OptimizerModule] = None, tokenizer: Optional["TokenizerSpec"] = None, model_transform: Optional[Callable[[nn.Module], nn.Module]] = None, @@ -70,7 +107,7 @@ def init(self) -> MixtralModel: def apply(self, output_path: Path) -> Path: from transformers import MixtralForCausalLM - source = MixtralForCausalLM.from_pretrained(str(self)) + source = MixtralForCausalLM.from_pretrained(str(self), torch_dtype='auto', use_safetensors=True) target = self.init() trainer = self.nemo_setup(target) self.convert_state(source, target) @@ -104,16 +141,21 @@ def tokenizer(self) -> "AutoTokenizer": return AutoTokenizer(str(self)) @property - def config(self) -> MixtralConfig8x7B: + def config(self) -> MixtralConfig8x7B | MixtralConfig8x22B: from transformers import MixtralConfig as HfMixtralConfig config = HfMixtralConfig.from_pretrained(str(self)) - return MixtralConfig8x7B( + config_cls = MixtralConfig8x7B + if '8x22b' in str(self).lower(): + config_cls = MixtralConfig8x22B + return config_cls( + bf16=getattr(config, "torch_dtype", None) == torch.bfloat16, activation_func=F.silu, # network num_layers=config.num_hidden_layers, hidden_size=config.hidden_size, ffn_hidden_size=config.intermediate_size, + kv_channels=getattr(config, 'head_dim', config.hidden_size // config.num_attention_heads), max_position_embeddings=config.max_position_embeddings, # TODO seq_length=config.max_position_embeddings, # RoPE @@ -124,6 +166,7 @@ def config(self) -> MixtralConfig8x7B: num_query_groups=config.num_key_value_heads, num_moe_experts=config.num_local_experts, moe_router_topk=config.num_experts_per_tok, + moe_router_pre_softmax=True, # norm normalization='RMSNorm', layernorm_epsilon=config.rms_norm_eps, @@ -132,6 +175,10 @@ def config(self) -> MixtralConfig8x7B: gated_linear_unit=True, # Vocab make_vocab_size_divisible_by=128, + # CPU init + use_cpu_initialization=True, + perform_initialization=False, + params_dtype=getattr(config, "torch_dtype", torch.bfloat16), ) @@ -151,7 +198,7 @@ def _import_qkv(ctx: io.TransformCTX, q, k, v): heads_per_group = head_num //
num_query_groups hidden_size = megatron_config.hidden_size head_num = megatron_config.num_attention_heads - head_size = hidden_size // head_num + head_size = megatron_config.kv_channels old_tensor_shape = q.size() new_q_tensor_shape = (head_num, head_size) + old_tensor_shape[1:] @@ -232,7 +279,8 @@ def tokenizer(self): @property def config(self) -> "MixtralConfig": - source: MixtralConfig7B = io.load_ckpt(str(self)).model.config + # Either MixtralConfig8x7B or MixtralConfig8x22B + source: MixtralConfig8x7B = io.load_ckpt(str(self)).model.config from transformers import MixtralConfig as HfMixtralConfig @@ -255,6 +303,7 @@ def config(self) -> "MixtralConfig": initializer_range=source.init_method_std, # vocab vocab_size=self.tokenizer.vocab_size, + head_dim=source.kv_channels, ) @@ -274,7 +323,7 @@ def _export_qkv(ctx: io.TransformCTX, linear_qkv): heads_per_group = head_num // num_query_groups hidden_size = megatron_config.hidden_size head_num = megatron_config.num_attention_heads - head_size = hidden_size // head_num + head_size = megatron_config.kv_channels qkv_total_dim = head_num + 2 * num_query_groups linear_qkv = linear_qkv.reshape([qkv_total_dim, head_size, hidden_size]) diff --git a/nemo/collections/llm/recipes/__init__.py b/nemo/collections/llm/recipes/__init__.py new file mode 100644 index 000000000000..8d4d874362a9 --- /dev/null +++ b/nemo/collections/llm/recipes/__init__.py @@ -0,0 +1,13 @@ +from nemo.collections.llm.recipes import llama2_7b, llama3_8b, llama3_8b_16k, llama3_8b_64k, mistral +from nemo.collections.llm.recipes.log.default import default_log +from nemo.collections.llm.recipes.optim import adam + +__all__ = [ + "llama3_8b", + "llama3_8b_16k", + "llama3_8b_64k", + "llama2_7b", + "mistral", + "adam", + "default_log", +] diff --git a/nemo/collections/llm/recipes/llama2_7b.py b/nemo/collections/llm/recipes/llama2_7b.py new file mode 100644 index 000000000000..1767dc4690c8 --- /dev/null +++ b/nemo/collections/llm/recipes/llama2_7b.py @@ -0,0 +1,61 @@ +import pytorch_lightning as pl + +from nemo import lightning as nl +from nemo.collections.llm.api import finetune, pretrain +from nemo.collections.llm.gpt.data.api import squad +from nemo.collections.llm.gpt.model.llama import Llama2Config7B, LlamaModel +from nemo.collections.llm.peft.api import gpt_lora +from nemo.collections.llm.recipes.log.default import default_log +from nemo.collections.llm.recipes.optim.adam import adam_with_cosine_annealing +from nemo.collections.llm.utils import Partial, factory + +NAME = "llama2_7b" + + +@factory(name=NAME) +def model() -> pl.LightningModule: + return LlamaModel(Llama2Config7B()) + + +@factory(name=NAME) +def trainer(devices=8) -> nl.Trainer: + strategy = nl.MegatronStrategy(tensor_model_parallel_size=2) + + return nl.Trainer( + devices=devices, + max_steps=100, + accelerator="gpu", + strategy=strategy, + plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"), + ) + + +@factory(name=NAME + "_hf") +def hf_resume() -> nl.AutoResume: + return nl.AutoResume(import_path="hf://meta-llama/Llama-2-7b-hf") + + +@factory(name=NAME, for_task="llm.pretrain") +def pretrain_recipe() -> Partial: + return Partial( + pretrain, + model=model, + trainer=trainer, + data=squad, + log=default_log, + optim=adam_with_cosine_annealing, + ) + + +@factory(name=NAME, for_task="llm.finetune") +def finetune_recipe() -> Partial: + return Partial( + finetune, + model=model, + trainer=trainer, + data=squad, + log=default_log, + optim=adam_with_cosine_annealing, + peft=gpt_lora, + resume=hf_resume, + ) diff --git 
a/nemo/collections/llm/recipes/llama3_8b.py b/nemo/collections/llm/recipes/llama3_8b.py new file mode 100644 index 000000000000..34ce418a0701 --- /dev/null +++ b/nemo/collections/llm/recipes/llama3_8b.py @@ -0,0 +1,61 @@ +import pytorch_lightning as pl + +from nemo import lightning as nl +from nemo.collections.llm.api import finetune, pretrain +from nemo.collections.llm.gpt.data.api import squad +from nemo.collections.llm.gpt.model.llama import Llama3Config8B, LlamaModel +from nemo.collections.llm.peft.api import gpt_lora +from nemo.collections.llm.recipes.log.default import default_log +from nemo.collections.llm.recipes.optim.adam import adam_with_cosine_annealing +from nemo.collections.llm.utils import Partial, factory + +NAME = "llama3_8b" + + +@factory(name=NAME) +def model() -> pl.LightningModule: + return LlamaModel(Llama3Config8B(seq_length=16384)) + + +@factory(name=NAME) +def trainer(devices=8) -> nl.Trainer: + strategy = nl.MegatronStrategy(tensor_model_parallel_size=2) + + return nl.Trainer( + devices=devices, + max_steps=100, + accelerator="gpu", + strategy=strategy, + plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"), + ) + + +@factory(name=NAME + "_hf") +def hf_resume() -> nl.AutoResume: + return nl.AutoResume(import_path="hf://meta-llama/Meta-Llama-3-8B") + + +@factory(name=NAME, for_task="llm.pretrain") +def pretrain_recipe() -> Partial: + return Partial( + pretrain, + model=model, + trainer=trainer, + data=squad, + log=default_log, + optim=adam_with_cosine_annealing, + ) + + +@factory(name=NAME, for_task="llm.finetune") +def finetune_recipe() -> Partial: + return Partial( + finetune, + model=model, + trainer=trainer, + data=squad, + log=default_log, + optim=adam_with_cosine_annealing, + peft=gpt_lora, + resume=hf_resume, + ) diff --git a/nemo/collections/llm/recipes/llama3_8b_16k.py b/nemo/collections/llm/recipes/llama3_8b_16k.py new file mode 100644 index 000000000000..3a590f26894e --- /dev/null +++ b/nemo/collections/llm/recipes/llama3_8b_16k.py @@ -0,0 +1,45 @@ +import pytorch_lightning as pl + +from nemo import lightning as nl +from nemo.collections.llm.api import pretrain +from nemo.collections.llm.gpt.data.api import squad +from nemo.collections.llm.gpt.model.llama import Llama3Config8B, LlamaModel +from nemo.collections.llm.recipes.log.default import default_log +from nemo.collections.llm.recipes.optim.adam import adam_with_cosine_annealing +from nemo.collections.llm.utils import Partial, factory + +NAME = "llama3_8b_16k" + + +@factory(name=NAME) +def model() -> pl.LightningModule: + return LlamaModel(Llama3Config8B(seq_length=16384)) + + +@factory(name=NAME) +def trainer(devices=8) -> nl.Trainer: + strategy = nl.MegatronStrategy( + tensor_model_parallel_size=4, + context_parallel_size=2, + sequence_parallel=True, + ) + + return nl.Trainer( + devices=devices, + max_steps=100, + accelerator="gpu", + strategy=strategy, + plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"), + ) + + +@factory(name=NAME, for_task="llm.pretrain") +def pretrain_recipe() -> Partial: + return Partial( + pretrain, + model=model, + trainer=trainer, + data=squad, + log=default_log, + optim=adam_with_cosine_annealing, + ) diff --git a/nemo/collections/llm/recipes/llama3_8b_64k.py b/nemo/collections/llm/recipes/llama3_8b_64k.py new file mode 100644 index 000000000000..c826feb28901 --- /dev/null +++ b/nemo/collections/llm/recipes/llama3_8b_64k.py @@ -0,0 +1,45 @@ +import pytorch_lightning as pl + +from nemo import lightning as nl +from nemo.collections.llm.api import pretrain +from 
nemo.collections.llm.gpt.data.api import squad +from nemo.collections.llm.gpt.model.llama import Llama3Config8B, LlamaModel +from nemo.collections.llm.recipes.log.default import default_log +from nemo.collections.llm.recipes.optim.adam import adam_with_cosine_annealing +from nemo.collections.llm.utils import Partial, factory + +NAME = "llama3_8b_64k" + + +@factory(name=NAME) +def model() -> pl.LightningModule: + return LlamaModel(Llama3Config8B(seq_length=65536)) + + +@factory(name=NAME) +def trainer(devices=8) -> nl.Trainer: + strategy = nl.MegatronStrategy( + tensor_model_parallel_size=8, + context_parallel_size=4, + sequence_parallel=True, + ) + + return nl.Trainer( + devices=devices, + max_steps=100, + accelerator="gpu", + strategy=strategy, + plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"), + ) + + +@factory(name=NAME, for_task="llm.pretrain") +def pretrain_recipe() -> Partial: + return Partial( + pretrain, + model=model, + trainer=trainer, + data=squad, + log=default_log, + optim=adam_with_cosine_annealing, + ) diff --git a/nemo/collections/llm/recipes/log/__init__.py b/nemo/collections/llm/recipes/log/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/nemo/collections/llm/recipes/log/default.py b/nemo/collections/llm/recipes/log/default.py new file mode 100644 index 000000000000..a40e141bfa95 --- /dev/null +++ b/nemo/collections/llm/recipes/log/default.py @@ -0,0 +1,15 @@ +from nemo import lightning as nl +from nemo.collections.llm.utils import factory + + +@factory +def default_log() -> nl.NeMoLogger: + ckpt = nl.ModelCheckpoint( + save_best_model=True, + save_last=True, + monitor="reduced_train_loss", + save_top_k=2, + save_on_train_epoch_end=True, + ) + + return nl.NeMoLogger(ckpt=ckpt) diff --git a/nemo/collections/llm/recipes/mistral.py b/nemo/collections/llm/recipes/mistral.py new file mode 100644 index 000000000000..12af8d5d18ff --- /dev/null +++ b/nemo/collections/llm/recipes/mistral.py @@ -0,0 +1,61 @@ +import pytorch_lightning as pl + +from nemo import lightning as nl +from nemo.collections.llm.api import finetune, pretrain +from nemo.collections.llm.gpt.data.api import squad +from nemo.collections.llm.gpt.model.mistral import MistralConfig7B, MistralModel +from nemo.collections.llm.peft.api import gpt_lora +from nemo.collections.llm.recipes.log.default import default_log +from nemo.collections.llm.recipes.optim.adam import adam_with_cosine_annealing +from nemo.collections.llm.utils import Partial, factory + +NAME = "mistral" + + +@factory(name=NAME) +def model() -> pl.LightningModule: + return MistralModel(MistralConfig7B()) + + +@factory(name=NAME) +def trainer(devices=8) -> nl.Trainer: + strategy = nl.MegatronStrategy(tensor_model_parallel_size=2) + + return nl.Trainer( + devices=devices, + max_steps=100, + accelerator="gpu", + strategy=strategy, + plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"), + ) + + +@factory(name=NAME + "_hf") +def hf_resume() -> nl.AutoResume: + return nl.AutoResume(import_path="hf://mistralai/Mistral-7B-v0.3") + + +@factory(name=NAME, for_task="llm.pretrain") +def pretrain_recipe() -> Partial: + return Partial( + pretrain, + model=model, + trainer=trainer, + data=squad, + log=default_log, + optim=adam_with_cosine_annealing, + ) + + +@factory(name=NAME, for_task="llm.finetune") +def finetune_recipe() -> Partial: + return Partial( + finetune, + model=model, + trainer=trainer, + data=squad, + log=default_log, + optim=adam_with_cosine_annealing, + peft=gpt_lora, + resume=hf_resume, + ) diff --git 
a/nemo/collections/llm/recipes/mixtral_8x22b_4k.py b/nemo/collections/llm/recipes/mixtral_8x22b_4k.py new file mode 100644 index 000000000000..4385e5a54827 --- /dev/null +++ b/nemo/collections/llm/recipes/mixtral_8x22b_4k.py @@ -0,0 +1,64 @@ +import pytorch_lightning as pl + +from nemo import lightning as nl +from nemo.collections.llm.api import finetune, pretrain +from nemo.collections.llm.gpt.data.api import squad +from nemo.collections.llm.gpt.model.mixtral import MixtralConfig8x22B, MixtralModel +from nemo.collections.llm.peft.api import gpt_lora +from nemo.collections.llm.recipes.log.default import default_log +from nemo.collections.llm.recipes.optim.adam import adam_with_cosine_annealing +from nemo.collections.llm.utils import Partial, factory + +NAME = "mixtral_8x22b_4k" + + +@factory(name=NAME) +def model() -> pl.LightningModule: + return MixtralModel(MixtralConfig8x22B(seq_length=4096)) + + +@factory(name=NAME) +def trainer(devices=8) -> nl.Trainer: + strategy = nl.MegatronStrategy( + tensor_model_parallel_size=8, + sequence_parallel=True, + ) + + return nl.Trainer( + devices=devices, + max_steps=100, + accelerator="gpu", + strategy=strategy, + plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"), + ) + + +@factory(name=NAME + "_hf") +def hf_resume() -> nl.AutoResume: + return nl.AutoResume(import_path="hf://mistralai/Mixtral-8x22B-v0.1") + + +@factory(name=NAME, for_task="llm.pretrain") +def pretrain_recipe() -> Partial: + return Partial( + pretrain, + model=model, + trainer=trainer, + data=squad, + log=default_log, + optim=adam_with_cosine_annealing, + ) + + +@factory(name=NAME, for_task="llm.finetune") +def finetune_recipe() -> Partial: + return Partial( + finetune, + model=model, + trainer=trainer, + data=squad, + log=default_log, + optim=adam_with_cosine_annealing, + peft=gpt_lora, + resume=hf_resume, + ) diff --git a/nemo/collections/llm/recipes/mixtral_8x7b_4k.py b/nemo/collections/llm/recipes/mixtral_8x7b_4k.py new file mode 100644 index 000000000000..d7543e51812e --- /dev/null +++ b/nemo/collections/llm/recipes/mixtral_8x7b_4k.py @@ -0,0 +1,64 @@ +import pytorch_lightning as pl + +from nemo import lightning as nl +from nemo.collections.llm.api import finetune, pretrain +from nemo.collections.llm.gpt.data.api import squad +from nemo.collections.llm.gpt.model.mixtral import MixtralConfig8x7B, MixtralModel +from nemo.collections.llm.peft.api import gpt_lora +from nemo.collections.llm.recipes.log.default import default_log +from nemo.collections.llm.recipes.optim.adam import adam_with_cosine_annealing +from nemo.collections.llm.utils import Partial, factory + +NAME = "mixtral_8x7b_4k" + + +@factory(name=NAME) +def model() -> pl.LightningModule: + return MixtralModel(MixtralConfig8x7B(seq_length=4096)) + + +@factory(name=NAME) +def trainer(devices=8) -> nl.Trainer: + strategy = nl.MegatronStrategy( + tensor_model_parallel_size=8, + sequence_parallel=True, + ) + + return nl.Trainer( + devices=devices, + max_steps=100, + accelerator="gpu", + strategy=strategy, + plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"), + ) + + +@factory(name=NAME + "_hf") +def hf_resume() -> nl.AutoResume: + return nl.AutoResume(import_path="hf://mistralai/Mixtral-8x7B-v0.1") + + +@factory(name=NAME, for_task="llm.pretrain") +def pretrain_recipe() -> Partial: + return Partial( + pretrain, + model=model, + trainer=trainer, + data=squad, + log=default_log, + optim=adam_with_cosine_annealing, + ) + + +@factory(name=NAME, for_task="llm.finetune") +def finetune_recipe() -> Partial: + return
Partial( + finetune, + model=model, + trainer=trainer, + data=squad, + log=default_log, + optim=adam_with_cosine_annealing, + peft=gpt_lora, + resume=hf_resume, + ) diff --git a/nemo/collections/llm/recipes/optim/__init__.py b/nemo/collections/llm/recipes/optim/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/nemo/collections/llm/recipes/optim/adam.py b/nemo/collections/llm/recipes/optim/adam.py new file mode 100644 index 000000000000..4229001b2130 --- /dev/null +++ b/nemo/collections/llm/recipes/optim/adam.py @@ -0,0 +1,16 @@ +from megatron.core.optimizer import OptimizerConfig + +from nemo import lightning as nl +from nemo.collections.llm.utils import factory + + +@factory +def adam_with_cosine_annealing() -> nl.OptimizerModule: + return nl.MegatronOptimizerModule( + config=OptimizerConfig(optimizer="adam", lr=0.001, use_distributed_optimizer=True), + lr_scheduler=nl.lr_scheduler.CosineAnnealingScheduler(), + ) + + +# TODO: Fix the name-arg inside the factory-function so we don't need to do this +with_cosine_annealing = adam_with_cosine_annealing diff --git a/nemo/collections/llm/utils.py b/nemo/collections/llm/utils.py index b4382d0afd5f..26b511fcb26d 100644 --- a/nemo/collections/llm/utils.py +++ b/nemo/collections/llm/utils.py @@ -3,10 +3,10 @@ T = TypeVar('T', bound=Callable[..., Any]) try: - import nemo_sdk as sdk + import nemo_run as run - Config = sdk.Config - Partial = sdk.Partial + Config = run.Config + Partial = run.Partial except ImportError: _T = TypeVar('_T') @@ -19,9 +19,9 @@ class Partial(Generic[_T]): def task(*args: Any, **kwargs: Any) -> Callable[[T], T]: try: - import nemo_sdk as sdk + import nemo_run as run - return sdk.task(*args, **kwargs) + return run.task(*args, **kwargs) except ImportError: # Return a no-op function def noop_decorator(func: T) -> T: @@ -40,14 +40,13 @@ def factory(*args: Any, **kwargs: Any) -> Callable[[T], T]: ... 
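For orientation, the recipe modules added above are meant to be consumed through these factory helpers in nemo/collections/llm/utils.py. A minimal usage sketch, assuming nemo_run is installed and the factory names stay as defined above (the exact resolution API may differ):

from nemo.collections.llm.recipes import llama3_8b

# Each call returns a configuration object assembled from the factories registered
# above (model, trainer, data, logger, optimizer); nothing is executed until a
# runner launches the resulting object.
recipe = llama3_8b.pretrain_recipe()
trainer = llama3_8b.trainer(devices=8)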
def factory(*args: Any, **kwargs: Any) -> Union[Callable[[T], T], T]: try: - import nemo_sdk as sdk + import nemo_run as run - if not args and not kwargs: - # Used as @factory without arguments - return sdk.factory() + if not args: + return run.factory(**kwargs) else: # Used as @factory(*args, **kwargs) - return sdk.factory(*args, **kwargs) + return run.factory(*args, **kwargs) except ImportError: # Return a no-op function def noop_decorator(func: T) -> T: diff --git a/nemo/collections/multimodal/data/clip/clip_dataset.py b/nemo/collections/multimodal/data/clip/clip_dataset.py index 6b63d546194a..448efba4b8ba 100644 --- a/nemo/collections/multimodal/data/clip/clip_dataset.py +++ b/nemo/collections/multimodal/data/clip/clip_dataset.py @@ -57,8 +57,9 @@ def tokenize(texts: Union[str, List[str]], tokenizer: Any, context_length: int = bos_id = tokenizer.bos_id eos_id = tokenizer.eos_id - all_tokens = [[bos_id] + tokenizer.text_to_ids(text) + [eos_id] for text in texts] - result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + pad_id = tokenizer.pad_id + all_tokens = [([bos_id] if bos_id is not None else []) + tokenizer.text_to_ids(text) + [eos_id] for text in texts] + result = torch.ones(len(all_tokens), context_length, dtype=torch.long) * pad_id for i, tokens in enumerate(all_tokens): if len(tokens) > context_length: diff --git a/nemo/collections/multimodal/data/neva/conversation.py b/nemo/collections/multimodal/data/neva/conversation.py index 10a6c9e7283d..2e110eebe9e6 100644 --- a/nemo/collections/multimodal/data/neva/conversation.py +++ b/nemo/collections/multimodal/data/neva/conversation.py @@ -34,6 +34,10 @@ DEFAULT_IM_START_TOKEN["llama_3"] = "<|reserved_special_token_4|>" DEFAULT_IM_END_TOKEN["llama_3"] = "<|reserved_special_token_5|>" +DEFAULT_VID_START_TOKEN = "" +DEFAULT_VID_END_TOKEN = "" +TIME_TOKEN_TEMPLATE = "" + class SeparatorStyle(Enum): """Different separator style.""" diff --git a/nemo/collections/multimodal/data/neva/neva_dataset.py b/nemo/collections/multimodal/data/neva/neva_dataset.py index 7eef677e13a8..17cb6e6cf644 100644 --- a/nemo/collections/multimodal/data/neva/neva_dataset.py +++ b/nemo/collections/multimodal/data/neva/neva_dataset.py @@ -20,7 +20,6 @@ from dataclasses import dataclass from typing import Any, Dict, List, Sequence, Tuple, Union -import decord import numpy as np import torch import torch.nn.functional as F @@ -34,11 +33,15 @@ import nemo.collections.multimodal.data.neva.conversation as conversation_lib from nemo.collections.multimodal.data.clip.augmentations.augmentations import image_transform from nemo.collections.multimodal.data.neva.conversation import ( + DEFAULT_BOS_TOKEN, + DEFAULT_EOS_TOKEN, DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IMAGE_TOKEN, DEFAULT_LABELS_TOKEN, + DEFAULT_VID_END_TOKEN, + DEFAULT_VID_START_TOKEN, DEFAULT_VIDEO_TOKEN, ) from nemo.collections.nlp.modules.common.megatron.utils import get_ltor_masks_and_position_ids @@ -46,6 +49,11 @@ MAX_NUM_IMAGES = 1 IGNORE_INDEX = -1 +try: + import decord +except Exception: + logging.warning("The package `decord` was not installed in this environment.") + try: from megatron.core.datasets.indexed_dataset import IndexedDataset @@ -145,7 +153,7 @@ def open_video(self, file_name): cap = decord.VideoReader(f) return self.flatten_frames(cap) else: - decord.bridge.set_bridge("torch") + # decord.bridge.set_bridge("torch") cap = decord.VideoReader(os.path.join(self.video_folder, file_name)) return self.flatten_frames(cap) return None @@ 
-171,9 +179,7 @@ def flatten_frames(self, cap): else: num_frames = min(len(cap), self.data_cfg['num_frames']) indices = np.linspace(0, len(cap) - 1, num_frames, dtype=int) - frames = [] - frames = cap.get_batch(indices) - + frames = [Image.fromarray(cap[i].asnumpy()).convert('RGB') for i in indices] while len(frames) < self.data_cfg['num_frames']: frames.append(frames[-1]) return frames @@ -226,6 +232,25 @@ def tokenize( return result +def get_tokens_ids(tokenizer, tokens): + """ + Returns the token id for a given token. + + Parameters + ---------- + tokenizer : nemo tokenizer + A tokenizer to be used for tokenization. + tokens : list + A list of tokens to get the token id for. + + Returns + ------- + List + The token ids. + """ + return [tokenizer.token_to_id(token) for token in tokens] + + def preprocess_multimodal(sources: dict, multimodal_cfg: dict, cur_token_len: int, use_plain: bool = False) -> Dict: """ Preprocesses multimodal sources based on the provided configuration. @@ -259,13 +284,15 @@ def preprocess_multimodal(sources: dict, multimodal_cfg: dict, cur_token_len: in if not is_multimodal: return sources - num_patches = image_token_len + num_frames = multimodal_cfg['num_frames'] + # vila + if multimodal_cfg['mm_mlp_adapter_type'] == 'mlp_downsample': + image_token_len //= 4 + num_patches = image_token_len + # TO DO: to support multiple images if media_type == 'video': - num_patches *= multimodal_cfg['num_frames'] - - if multimodal_cfg['mm_mlp_adapter_type'] == 'mlp_downsample': - num_patches //= 4 + num_patches *= num_frames if multimodal_cfg['use_im_start_end']: replace_token = DEFAULT_IMAGE_PATCH_TOKEN[model_type] * num_patches @@ -273,6 +300,44 @@ def preprocess_multimodal(sources: dict, multimodal_cfg: dict, cur_token_len: in replace_token = DEFAULT_IMAGE_PATCH_TOKEN[model_type] * (num_patches - 2) replace_token = DEFAULT_IM_START_TOKEN[model_type] + replace_token + DEFAULT_IM_END_TOKEN[model_type] + if media_type == 'video' and multimodal_cfg.get("use_lita", False): + if not multimodal_cfg.get('lita', None): + raise ValueError("LITA config is missing") + lita_video_arch = multimodal_cfg['lita']['lita_video_arch'] + num_temporal_tokens, num_spatial_tokens = num_frames, 0 + if lita_video_arch == 'temporal_all_resolution': + sample_frames = min(multimodal_cfg['lita']['sample_frames'], num_frames) + # num_frames for temporal tokens, sample_frames * num_patches for spatial tokens + num_spatial_tokens = sample_frames * image_token_len + else: + # num_frames for temporal tokens and num_patches for spatial tokens + num_spatial_tokens = image_token_len + num_tokens = num_temporal_tokens + num_spatial_tokens + + visual_token_format = multimodal_cfg['lita'].get('visual_token_format', 'v1') + media_start = DEFAULT_IM_START_TOKEN[model_type] + media_end = DEFAULT_IM_END_TOKEN[model_type] + image_patch = DEFAULT_IMAGE_PATCH_TOKEN[model_type] + if visual_token_format == 'im_vid_start_end': + image_start, image_end = DEFAULT_IM_START_TOKEN[model_type], DEFAULT_IM_END_TOKEN[model_type] + vid_start, vid_end = DEFAULT_VID_START_TOKEN, DEFAULT_VID_END_TOKEN + if multimodal_cfg['use_im_start_end']: + replace_token_list = [image_start + image_patch * image_token_len + image_end] * sample_frames + replace_token_list += [vid_start + image_patch * num_temporal_tokens + vid_end] + replace_token = "".join(replace_token_list) + else: + replace_token_list = [image_start + image_patch * (image_token_len - 1) + image_end] + replace_token_list += [image_start + image_patch * image_token_len + image_end] * 
(sample_frames - 1) + replace_token_list += [vid_start + image_patch * (num_temporal_tokens - 1) + vid_end] + replace_token = "".join(replace_token_list) + replace_token = media_start + replace_token + media_end + else: + if multimodal_cfg['use_im_start_end']: + replace_token = image_patch * num_tokens + else: + replace_token = image_patch * (num_tokens - 2) + replace_token = media_start + replace_token + media_end + for source in sources: conversation = source['conversations'] if multimodal_cfg['sep_image_conv_front']: @@ -290,7 +355,6 @@ def preprocess_multimodal(sources: dict, multimodal_cfg: dict, cur_token_len: in conversation[0]['value'] = default_token for turn in conversation: turn["value"] = turn["value"].replace(default_token, replace_token) - return sources @@ -475,9 +539,13 @@ def preprocess_llama_2( ) # llama tricks - tokens[tokens == 32003] = 0 # DEFAULT_IMAGE_PATCH_TOKEN - tokens[tokens == 32006] = 1 # - tokens[tokens == 32007] = 2 # + # 32003, 32006, 32007 + image_patch_token = DEFAULT_IMAGE_PATCH_TOKEN["llama_2"] + DEFAULT_TOKENS = [image_patch_token, DEFAULT_BOS_TOKEN, DEFAULT_EOS_TOKEN] + img_patch_id, bos_id, eos_id = get_tokens_ids(tokenizer, DEFAULT_TOKENS) + tokens[tokens == img_patch_id] = 0 # DEFAULT_IMAGE_PATCH_TOKEN + tokens[tokens == bos_id] = 1 # + tokens[tokens == eos_id] = 2 # labels = tokens.clone().detach() # Mask labels @@ -577,9 +645,14 @@ def preprocess_v1( ) # llama tricks - tokens[tokens == 32003] = 0 # DEFAULT_IMAGE_PATCH_TOKEN - tokens[tokens == 32006] = 1 # - tokens[tokens == 32007] = 2 # + # 32003, 32006, 32007 + image_patch_token = DEFAULT_IMAGE_PATCH_TOKEN["llama_2"] + DEFAULT_TOKENS = [image_patch_token, DEFAULT_BOS_TOKEN, DEFAULT_EOS_TOKEN] + img_patch_id, bos_id, eos_id = get_tokens_ids(tokenizer, DEFAULT_TOKENS) + tokens[tokens == img_patch_id] = 0 # DEFAULT_IMAGE_PATCH_TOKEN + tokens[tokens == bos_id] = 1 # + tokens[tokens == eos_id] = 2 # + # tokens = torch.concat((torch.tensor([[1]]), tokens), axis=1) #lita 1.5 legacy labels = tokens.clone().detach() # Mask labels @@ -977,7 +1050,7 @@ def __getitem__(self, i) -> Dict[str, torch.Tensor]: frames = self.video_loader.open_video(video_file) if frames is None: logging.warning(f"Video {video_file} could not be found!") - if isinstance(self.processor, CLIPImageProcessor): + if isinstance(self.processor, CLIPImageProcessor) or isinstance(self.processor, SiglipImageProcessor): # image processor from HF if self.multimodal_cfg['image_aspect_ratio'] == 'keep': max_hw, min_hw = max(frames.size), min(frames.size) @@ -1268,6 +1341,8 @@ def make_supervised_data_module(tokenizer, image_processor, model_cfg) -> Dict: context_length=model_cfg.encoder_seq_length, media_type=data_cfg.get('media_type', 'image'), num_frames=data_cfg.get('num_frames', -1), + use_lita=getattr(model_cfg.mm_cfg, 'use_lita', False), + lita=getattr(model_cfg.mm_cfg, 'lita', {}), mm_mlp_adapter_type=model_cfg.mm_cfg.get('mm_mlp_adapter_type', 'linear'), ), data_cfg=dict( diff --git a/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py b/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py index 376237e89ecc..40b1b4ed9a02 100644 --- a/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py +++ b/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py @@ -17,9 +17,10 @@ from itertools import chain from typing import Any, Optional +import numpy as np import torch import torch.nn.functional as F -from einops import rearrange, repeat +from einops import rearrange, reduce, repeat from 
omegaconf.dictconfig import DictConfig from pkg_resources import packaging from pytorch_lightning.trainer.trainer import Trainer @@ -64,18 +65,10 @@ from nemo.core.classes.common import PretrainedModelInfo from nemo.utils import logging -try: - import apex.transformer.pipeline_parallel.utils - from apex.transformer.pipeline_parallel.utils import get_num_microbatches - - HAVE_APEX = True - -except (ImportError, ModuleNotFoundError): - - HAVE_APEX = False - try: from megatron.core import InferenceParams, dist_checkpointing, parallel_state, tensor_parallel + from megatron.core.dist_checkpointing.dict_utils import dict_list_map_inplace + from megatron.core.dist_checkpointing.mapping import LocalNonpersitentObject, ShardedObject from megatron.core.models.gpt import GPTModel as MCoreGPTModel from megatron.core.pipeline_parallel.schedules import get_forward_backward_func from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint @@ -86,6 +79,19 @@ HAVE_MEGATRON_CORE = False +try: + from megatron.core.num_microbatches_calculator import get_num_microbatches + +except (ImportError, ModuleNotFoundError): + logging.warning("Megatron num_microbatches_calculator not found, using Apex version.") + from apex.transformer.pipeline_parallel.utils import get_num_microbatches + + +def skip_fp8_load(x): + if isinstance(x, ShardedObject) and 'fused_attention' in x.key and '_extra_state' in x.key: + x = LocalNonpersitentObject(x.data) # use the FP8 state from initialization, not from ckpt + return x + class FrozenCLIPVisionTransformer(CLIPVisionTransformer): """Frozen version of CLIPVisionTransformer""" @@ -137,6 +143,7 @@ def init_vision( media_start_id, media_end_id, vision_select_layer=-1, + vision_select_feature="patch", class_token_length=1, use_im_start_end=False, ): @@ -147,6 +154,7 @@ def init_vision( self.class_token_length = class_token_length self.use_im_start_end = use_im_start_end self.vision_select_layer = vision_select_layer + self.vision_select_feature = vision_select_feature self.media = None self.set_accepted_adapter_types([MultimodalProjectorAdapterConfig._target_]) @@ -208,7 +216,10 @@ def encode_vision_x(self, vision_x: torch.Tensor): self.vision_encoder.backbone.transformer.return_select_layer = self.vision_select_layer vision_x = self.vision_encoder(vision_x) vision_x = rearrange(vision_x, "(b T F) v d -> b T F v d", b=b, T=T, F=F) - vision_x = vision_x[:, :, :, self.class_token_length :] + if self.vision_select_feature == "patch": + vision_x = vision_x[:, :, :, self.class_token_length :] + elif self.vision_select_feature != "cls_patch": + raise ValueError(f"Unsupported vision_select_feature {self.vision_select_feature}") assert self.is_adapter_available(), "Cannot find multimodal vision adapter!" 
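# A minimal sketch of what the new vision_select_feature switch changes, assuming a
# ViT that prepends a single class token (the tensor below is illustrative only):
import torch

feats = torch.randn(2, 1, 1, 257, 1024)  # (b, T, F, v, d): 1 CLS token + 256 patch tokens
patch_only = feats[:, :, :, 1:]          # "patch": class token dropped -> 256 tokens per frame
cls_patch = feats                        # "cls_patch": all 257 tokens kept for the projector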
vision_connector = self.get_adapter_module(AdapterName.MULTIMODAL_PROJECTOR_ADAPTER) vision_x = vision_connector(vision_x) @@ -273,6 +284,147 @@ def sharded_state_dict(self, prefix: str = '', sharded_offsets: tuple = (), **kw return sharded_state_dict +class LitaWordEmbeddingMixin(NevaWordEmbeddingMixin): + def init_lita( + self, + lita_video_arch: str, + visual_token_format: str = "v1", + use_media_start_end: bool = False, + sample_frames: int = 4, + ): + """Initialize LITA-specific settings for the word-embedding mixin. + + Args: + lita_video_arch (str): one of ['temporal_spatial_pool', 'temporal_spatial', 'temporal_all_resolution'] + visual_token_format (str, optional): defaults to 'v1'; options are ["v1", "im_vid_start_end"] + v1: no video_start_id and video_end_id, video tokens are inserted between fast/slow (temporal/spatial) tokens + im_vid_start_end: video start and end tokens are inserted before and after temporal tokens, + image start and end tokens are inserted before and after spatial tokens + use_media_start_end (bool, optional): + whether the media start and media end tokens are used in input_ids. Defaults to False. + Note that when it is False, media_start_id and media_end_id act as placeholders: + input_ids = [..., media_start_id, t1, t2, t3...., media_end_id, ...] + use_media_start_end = False + we will replace the tokens including and between: [media_start_id, ... media_end_id] + use_media_start_end = True + we will replace the tokens between: (media_start_id, ... media_end_id) + sample_frames (int, optional): number of frames to sample from the video, defaults to 4 + """ + self.lita_video_arch = lita_video_arch + self.visual_token_format = visual_token_format + self.use_media_start_end = use_media_start_end + self.sample_frames = sample_frames + + def add_lita_layer(self, media_features): + """Pool the encoded media features into temporal and spatial tokens. + + Args: + media_features (torch.Tensor): + features encoded by the vision encoder, + shape: Batch, T (number of images), S (num patches), H (hidden size) + Returns: + t_tokens, s_tokens (torch.Tensor): + temporal and spatial tokens; (Batch, T, D) and (Batch, M, D) for 'temporal_spatial_pool' + and 'temporal_spatial', while 'temporal_all_resolution' returns per-frame image tokens + and video tokens as noted in the comments below + """ + + b, T, S, H = media_features.shape + tokens = media_features + if self.lita_video_arch == 'temporal_spatial_pool': + pool_size = 2 + h = w = int(np.sqrt(S)) + selected_frames = np.round(np.linspace(0, tokens.shape[1] - 1, pool_size * pool_size)).astype(int) + s_tokens = tokens[:, selected_frames, ...] + s_tokens = rearrange(s_tokens, 'b t (h w) d -> (b t) d h w', h=h, w=w) + s_tokens = F.avg_pool2d(s_tokens, kernel_size=pool_size) + s_tokens = rearrange(s_tokens, '(b t) d h w -> b (t h w) d', b=b) # B, M, D + t_tokens = reduce(tokens, 'b t s d -> b t d', 'mean') + # tokens = torch.cat([t_tokens, s_tokens], dim=1) # B, T + M, D + return t_tokens, s_tokens + elif self.lita_video_arch == 'temporal_spatial': + t_tokens = reduce(tokens, 'b t s d -> b t d', 'mean') + s_tokens = reduce(tokens, 'b t s d -> b s d', 'mean') + # tokens = torch.cat([t_tokens, s_tokens], dim=1) # B, T + M, D + return t_tokens, s_tokens + elif self.lita_video_arch == 'temporal_all_resolution': + idx = np.round(np.linspace(0, tokens.shape[1] - 1, self.sample_frames)).astype(int) + im_features = tokens[:, idx, ...]
# B, num_frames, S, D + # im_tokens = im_features.view(b, -1, H) # flatten the B, num_frames * S, D + im_tokens = im_features + vid_tokens = reduce(tokens, 'b t s d -> b t d', 'mean') + # s and t tokens have been changed position + return im_tokens, vid_tokens + else: + raise ValueError(f"Unknown video architecture: {self.lita_video_arch}") + + def replace_media_embeddings(self, input_ids, inputs_embeds, media): + """_summary_ + + Args: + input_ids (torch.tensor): The input token ids [B, T] + words_embeddings (torch.tensor): The input embeddings [B, T, D] + media (torch.Tensor): Vision input + shape (B, T_img, F, C, H, W) + """ + if input_ids.shape[1] == 1: + return inputs_embeds + + if media is None: + return inputs_embeds + if type(media) is list: + raise NotImplementedError("dynamic length of videos not supported yet, only fixed length of videos now") + # 1, 1, num_frames, 3, 244, 244 + media_features = self.encode_vision_x(media) # B T F S(eq) H(idden) + B, T, F, S, H = media_features.shape + assert T == 1, "multiple videos per sample not supported yet" + media_features = media_features.squeeze(1) + t_tokens, s_tokens = self.add_lita_layer(media_features) # B, T, D & B, M, D + T = t_tokens.shape[1] + M = s_tokens.shape[1] + inputs_embeds = inputs_embeds.clone() + for idx, input_id in enumerate(input_ids): + media_start_position = torch.where(input_id == self.media_start_id)[0] + media_end_position = torch.where(input_id == self.media_end_id)[0] + if self.visual_token_format != 'im_vid_start_end': + assert len(media_start_position) == 1, "Only 1 video per sample supported" + assert len(media_end_position) == 1, "Only 1 video per sample supported" + + media_start_position = media_start_position[0] + media_end_position = media_end_position[-1] + if self.use_media_start_end: + # replace the tokens between media_start_id and media_end_id + start, end = media_start_position + 1, media_end_position - 1 + else: + # replace the tokens including and between media_start_id and media_end_id + start, end = media_start_position, media_end_position + + if self.visual_token_format == 'v1': + t_token_start, t_token_end = start, start + T + s_token_start, s_token_end = start + T, start + T + M + assert s_token_end == end + 1, "Token replacement error" + inputs_embeds[idx, t_token_start:t_token_end] = t_tokens[idx] + inputs_embeds[idx, s_token_start:s_token_end] = s_tokens[idx] + elif self.visual_token_format == 'im_vid_start_end': # v1.5 lita + if not self.use_media_start_end: + # replace the media start and media end embedding with + # img_start and vid_end token embedding + inputs_embeds[idx, start] = inputs_embeds[idx, start + 1] + inputs_embeds[idx, end] = inputs_embeds[idx, end - 1] + # TO DO: To optimize the below codes + im_features, vid_features = t_tokens[idx], s_tokens[idx] + # im_feature: num_frames * S, D + emb_start = start + 1 # skip the img_start token + num_frames, S, D = im_features.shape + for i in range(num_frames): + inputs_embeds[idx, emb_start : emb_start + S] = im_features[i] + emb_start = emb_start + S + 2 # skip the img_end token and img_start token + T = vid_features.shape[0] + inputs_embeds[idx, emb_start : emb_start + T] = vid_features + assert emb_start + T == end + else: + raise ValueError(f"Unsupported visual_token_format {self.visual_token_format}") + return inputs_embeds + + class NevaBaseModel: """ Base class for a multimedia model integrating vision and language models. 
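The index bookkeeping in replace_media_embeddings above is easiest to check on a toy layout. A sketch for visual_token_format='v1' with use_media_start_end=False, using made-up sizes:

# One video contributing T temporal and M spatial tokens must be backed by a span of
# exactly T + M placeholder ids running from media_start_id through media_end_id
# inclusive, because both boundary positions are overwritten as well.
T, M = 4, 8
start = 10                                  # index of media_start_id in input_ids
end = start + T + M - 1                     # implied index of media_end_id
t_rows = slice(start, start + T)            # rows that receive t_tokens
s_rows = slice(start + T, start + T + M)    # rows that receive s_tokens
assert s_rows.stop == end + 1               # mirrors the "Token replacement error" check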
@@ -307,12 +459,24 @@ def __init__( # Monkey patch embedding if kwargs.get("pre_process", True): - extend_instance(self.embedding.word_embeddings, NevaWordEmbeddingMixin) + if not mm_cfg.get("use_lita", False): + extend_instance(self.embedding.word_embeddings, NevaWordEmbeddingMixin) + else: + extend_instance(self.embedding.word_embeddings, LitaWordEmbeddingMixin) + lita_conf = mm_cfg.get('lita', {}) + self.embedding.word_embeddings.init_lita( + lita_video_arch=lita_conf.get('lita_video_arch', 'temporal_spatial_pool'), + visual_token_format=lita_conf.get('visual_token_format', 'v1'), + use_media_start_end=mm_cfg.get('use_im_start_end', False), # we need to make this clear + sample_frames=lita_conf.get('sample_frames', 4), + ) + self.embedding.word_embeddings.init_vision( vision_encoder, media_start_id, media_end_id, vision_select_layer=mm_cfg.vision_encoder.get("vision_select_layer", -2), + vision_select_feature=mm_cfg.vision_encoder.get("vision_select_feature", "patch"), class_token_length=mm_cfg.vision_encoder.get("class_token_length", 1), use_im_start_end=mm_cfg.get("use_im_start_end", False), ) @@ -320,7 +484,10 @@ def __init__( def create_vision_encoder_and_processor(self, mm_cfg): # Initialize vision encoder and freeze it if mm_cfg.vision_encoder.get("from_hf", False): - if "clip" in mm_cfg.vision_encoder.from_pretrained: + from transformers import AutoConfig + + config = AutoConfig.from_pretrained(mm_cfg.vision_encoder.from_pretrained) + if config.architectures[0] == "CLIPVisionModel" or config.architectures[0] == "CLIPModel": vision_encoder = CLIPVisionModel.from_pretrained( mm_cfg.vision_encoder.from_pretrained, torch_dtype=torch.bfloat16, @@ -330,7 +497,7 @@ def create_vision_encoder_and_processor(self, mm_cfg): for param in vision_encoder.parameters(): param.requires_grad = False vision_encoder = vision_encoder.eval() - elif "siglip" in mm_cfg.vision_encoder.from_pretrained: + elif config.architectures[0] == "SiglipVisionModel" or config.architectures[0] == "SiglipModel": vision_encoder = SiglipVisionModel.from_pretrained( mm_cfg.vision_encoder.from_pretrained, torch_dtype=torch.bfloat16, @@ -365,6 +532,9 @@ def _load_model_weights(self, nemo_path): sharded_state_dict = None if getattr(self, "sharded_state_dict", None) is not None: sharded_state_dict = self.sharded_state_dict(prefix="model.") + # WAR: This is a temporary fix to skip loading FP8 parameters for Dot Product Attention + # TODO(yuya): Check if this skip affecting fp8 native checkpoints loading + dict_list_map_inplace(skip_fp8_load, sharded_state_dict) state_dict, self.is_dist_ckpt = load_nemo_model_weights(nemo_path, sharded_state_dict) return state_dict @@ -575,8 +745,7 @@ def dummy(): config=self.transformer_config, transformer_layer_spec=get_specs( self.spec_name, - self.transformer_config.num_moe_experts, - self.transformer_config.moe_grouped_gemm, + self.transformer_config, self.transformer_engine, ), vocab_size=self.cfg.get('override_vocab_size', self.padded_vocab_size), @@ -913,9 +1082,10 @@ def fwd_output_only_func(dataloader_iter, model): inference_max_sequence_len, ) = batch tokens = tokens.cuda() - attention_mask = attention_mask.cuda() position_ids = position_ids.cuda() - attention_mask = attention_mask[0:1] + if attention_mask != None: + attention_mask = attention_mask.cuda() + attention_mask = attention_mask[0:1] if media is not None: media = media.cuda() labels = None diff --git a/nemo/collections/multimodal/models/text_to_image/controlnet/controlnet.py 
b/nemo/collections/multimodal/models/text_to_image/controlnet/controlnet.py index 3f59eb66c81a..3b795aa7618c 100644 --- a/nemo/collections/multimodal/models/text_to_image/controlnet/controlnet.py +++ b/nemo/collections/multimodal/models/text_to_image/controlnet/controlnet.py @@ -46,13 +46,11 @@ from nemo.utils import logging try: - from apex import amp - from apex.transformer.enums import AttnMaskType - from apex.transformer.pipeline_parallel.utils import get_num_microbatches + from megatron.core.num_microbatches_calculator import get_num_microbatches - HAVE_APEX = True except (ImportError, ModuleNotFoundError): - HAVE_APEX = False + logging.warning("Megatron num_microbatches_calculator not found, using Apex version.") + from apex.transformer.pipeline_parallel.utils import get_num_microbatches try: from megatron.core import parallel_state @@ -380,7 +378,9 @@ def __init__( time_embed_dim = model_channels * 4 self.time_embed = nn.Sequential( - linear(model_channels, time_embed_dim), nn.SiLU(), linear(time_embed_dim, time_embed_dim), + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), ) self.input_blocks = nn.ModuleList( @@ -505,24 +505,26 @@ def __init__( use_checkpoint=use_checkpoint, use_scale_shift_norm=use_scale_shift_norm, ), - AttentionBlock( - ch, - use_checkpoint=use_checkpoint, - num_heads=num_heads, - num_head_channels=dim_head, - use_new_attention_order=use_new_attention_order, - ) - if not use_spatial_transformer - else SpatialTransformer( # always uses a self-attn - ch, - num_heads, - dim_head, - depth=transformer_depth, - context_dim=context_dim, - disable_self_attn=disable_middle_self_attn, - use_linear=use_linear_in_transformer, - use_checkpoint=use_checkpoint, - use_flash_attention=use_flash_attention, + ( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=dim_head, + use_new_attention_order=use_new_attention_order, + ) + if not use_spatial_transformer + else SpatialTransformer( # always uses a self-attn + ch, + num_heads, + dim_head, + depth=transformer_depth, + context_dim=context_dim, + disable_self_attn=disable_middle_self_attn, + use_linear=use_linear_in_transformer, + use_checkpoint=use_checkpoint, + use_flash_attention=use_flash_attention, + ) ), ResBlock( ch, @@ -546,11 +548,18 @@ def load_from_unet(self, from_pretrained_unet, from_NeMo=True): print("Loading unet blocks from sd") state_dict = torch.load(from_pretrained_unet, map_location='cpu') - state_dict = state_dict['state_dict'] + if 'state_dict' in state_dict.keys(): + state_dict = state_dict['state_dict'] model_state_dict = self.state_dict() + model_state_keys = model_state_dict.keys() re_state_dict = {} for key_, value_ in state_dict.items(): + # check if key is a raw parameter + if key_ in model_state_keys: + re_state_dict[key_] = value_ + continue + # prune from model prefix if key_.startswith('model.model.diffusion_model'): re_state_dict[key_.replace('model.model.diffusion_model.', '')] = value_ if key_.startswith('model.diffusion_model'): @@ -621,11 +630,6 @@ def forward(self, x, hint, timesteps, context, **kwargs): class MegatronControlNet(MegatronBaseModel): def __init__(self, cfg: DictConfig, trainer: Trainer): - if not HAVE_APEX: - raise ImportError( - "Apex was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt." - ) - if not HAVE_MEGATRON_CORE: raise ImportError( "megatron-core was not found. 
Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt." @@ -684,7 +688,10 @@ def fwd_bwd_step(self, dataloader_iter, forward_only): # handle asynchronous grad reduction no_sync_func = None if not forward_only and self.with_distributed_adam: - no_sync_func = partial(self._optimizer.no_sync, greedy_grad_copy=self.megatron_amp_O2,) + no_sync_func = partial( + self._optimizer.no_sync, + greedy_grad_copy=self.megatron_amp_O2, + ) # pipeline schedules will get these from self.model.config for module in self.get_module_list(): @@ -728,12 +735,12 @@ def fwd_bwd_step(self, dataloader_iter, forward_only): def training_step(self, dataloader_iter): """ - Our dataloaders produce a micro-batch and then we fetch - a number of microbatches depending on the global batch size and model parallel size - from the dataloader to produce a list of microbatches. - Batch should be a list of microbatches and those microbatches should on CPU. - Microbatches are then moved to GPU during the pipeline. - The list of microbatches is then piped through the pipeline using Apex fwd/bwd functions. + Our dataloaders produce a micro-batch and then we fetch + a number of microbatches depending on the global batch size and model parallel size + from the dataloader to produce a list of microbatches. + Batch should be a list of microbatches and those microbatches should on CPU. + Microbatches are then moved to GPU during the pipeline. + The list of microbatches is then piped through the pipeline using Apex fwd/bwd functions. """ # we zero grads here because we also call backward in the apex fwd/bwd functions self._optimizer.zero_grad() @@ -777,20 +784,20 @@ def training_step(self, dataloader_iter): return loss_mean def backward(self, *args, **kwargs): - """ LightningModule hook to do backward. - We want this to do nothing since we run backward in the fwd/bwd functions from apex. - No need to call it here. + """LightningModule hook to do backward. + We want this to do nothing since we run backward in the fwd/bwd functions from apex. + No need to call it here. """ pass def optimizer_zero_grad(self, *args, **kwargs): - """ LightningModule hook to zero grad. - We want this to do nothing as we are zeroing grads during the training_step. + """LightningModule hook to zero grad. + We want this to do nothing as we are zeroing grads during the training_step. """ pass def _append_sequence_parallel_module_grads(self, module, grads): - """ Helper method for allreduce_sequence_parallel_gradients""" + """Helper method for allreduce_sequence_parallel_gradients""" for param in module.parameters(): sequence_parallel_param = getattr(param, 'sequence_parallel', False) @@ -803,8 +810,8 @@ def _append_sequence_parallel_module_grads(self, module, grads): def get_forward_output_and_loss_func(self): def process_batch(batch): - """ Prepares the global batch for apex fwd/bwd functions. - Global batch is a list of micro batches. + """Prepares the global batch for apex fwd/bwd functions. + Global batch is a list of micro batches. 
""" # noise_map, condition batch[self.cfg.first_stage_key] = batch[self.cfg.first_stage_key].cuda(non_blocking=True) @@ -814,7 +821,8 @@ def process_batch(batch): # SD has more dedicated structure for encoding, so we enable autocasting here as well with torch.cuda.amp.autocast( - self.autocast_dtype in (torch.half, torch.bfloat16), dtype=self.autocast_dtype, + self.autocast_dtype in (torch.half, torch.bfloat16), + dtype=self.autocast_dtype, ): x, c = self.model.get_input(batch, self.cfg.first_stage_key) @@ -881,7 +889,7 @@ def validation_step(self, batch, batch_idx): self.log_dict(val_loss_dict, prog_bar=False, logger=True, on_step=False, on_epoch=True) def setup(self, stage=None): - """ PTL hook that is executed after DDP spawns. + """PTL hook that is executed after DDP spawns. We setup datasets here as megatron datasets require DDP to instantiate. See https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#setup for more information. Args: @@ -935,7 +943,8 @@ def build_train_valid_test_datasets(self): if self.cfg.first_stage_key.endswith("encoded"): self._train_ds, self._validation_ds = build_train_valid_precached_datasets( - model_cfg=self.cfg, consumed_samples=self.compute_consumed_samples(0), + model_cfg=self.cfg, + consumed_samples=self.compute_consumed_samples(0), ) else: self._train_ds, self._validation_ds = build_train_valid_datasets( @@ -989,20 +998,23 @@ def setup_test_data(self, cfg): f'Setting up test dataloader with len(len(self._test_ds)): {len(self._test_ds)} and consumed samples: {consumed_samples}' ) self._test_dl = torch.utils.data.DataLoader( - self._test_ds, batch_size=self._micro_batch_size, num_workers=cfg.num_workers, pin_memory=True, + self._test_ds, + batch_size=self._micro_batch_size, + num_workers=cfg.num_workers, + pin_memory=True, ) def transfer_batch_to_device(self, batch: Any, device: torch.device, dataloader_idx: int) -> Any: - """ PTL hook: https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#transfer-batch-to-device - When using pipeline parallelism, we need the global batch to remain on the CPU, - since the memory overhead will be too high when using a large number of microbatches. - Microbatches are transferred from CPU to GPU inside the pipeline. + """PTL hook: https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#transfer-batch-to-device + When using pipeline parallelism, we need the global batch to remain on the CPU, + since the memory overhead will be too high when using a large number of microbatches. + Microbatches are transferred from CPU to GPU inside the pipeline. """ return batch def _validate_trainer(self): - """ Certain trainer configurations can break training. - Here we try to catch them and raise an error. + """Certain trainer configurations can break training. + Here we try to catch them and raise an error. 
""" if self.trainer.accumulate_grad_batches > 1: raise ValueError( diff --git a/nemo/collections/multimodal/models/text_to_image/dreambooth/dreambooth.py b/nemo/collections/multimodal/models/text_to_image/dreambooth/dreambooth.py index 0b830ac7319b..24712ed30021 100644 --- a/nemo/collections/multimodal/models/text_to_image/dreambooth/dreambooth.py +++ b/nemo/collections/multimodal/models/text_to_image/dreambooth/dreambooth.py @@ -37,7 +37,6 @@ try: from apex import amp from apex.transformer.enums import AttnMaskType - from apex.transformer.pipeline_parallel.utils import get_num_microbatches HAVE_APEX = True except (ImportError, ModuleNotFoundError): @@ -53,6 +52,13 @@ HAVE_MEGATRON_CORE = False +try: + from megatron.core.num_microbatches_calculator import get_num_microbatches + +except (ImportError, ModuleNotFoundError): + logging.warning("Megatron num_microbatches_calculator not found, using Apex version.") + from apex.transformer.pipeline_parallel.utils import get_num_microbatches + def disabled_train(self, mode=True): """Overwrite model.train with this function to make sure train/eval mode @@ -99,7 +105,9 @@ def __init__(self, cfg, model_parallel_config): self.get_noise_scheduler(self.cfg.noise_scheduler) self.model_type = None - self.rng = torch.Generator(device=torch.cuda.current_device(),) + self.rng = torch.Generator( + device=torch.cuda.current_device(), + ) self.use_cached_latents = self.cfg.use_cached_latents @@ -246,7 +254,10 @@ def fwd_bwd_step(self, dataloader_iter, forward_only): # handle asynchronous grad reduction no_sync_func = None if not forward_only and self.with_distributed_adam: - no_sync_func = partial(self._optimizer.no_sync, greedy_grad_copy=self.megatron_amp_O2,) + no_sync_func = partial( + self._optimizer.no_sync, + greedy_grad_copy=self.megatron_amp_O2, + ) # pipeline schedules will get these from self.model.config for module in self.get_module_list(): @@ -291,12 +302,12 @@ def fwd_bwd_step(self, dataloader_iter, forward_only): def training_step(self, dataloader_iter): """ - Our dataloaders produce a micro-batch and then we fetch - a number of microbatches depending on the global batch size and model parallel size - from the dataloader to produce a list of microbatches. - Batch should be a list of microbatches and those microbatches should on CPU. - Microbatches are then moved to GPU during the pipeline. - The list of microbatches is then piped through the pipeline using Apex fwd/bwd functions. + Our dataloaders produce a micro-batch and then we fetch + a number of microbatches depending on the global batch size and model parallel size + from the dataloader to produce a list of microbatches. + Batch should be a list of microbatches and those microbatches should on CPU. + Microbatches are then moved to GPU during the pipeline. + The list of microbatches is then piped through the pipeline using Apex fwd/bwd functions. """ # we zero grads here because we also call backward in the apex fwd/bwd functions @@ -351,20 +362,20 @@ def validation_step(self, dataloader_iter): return loss def backward(self, *args, **kwargs): - """ LightningModule hook to do backward. - We want this to do nothing since we run backward in the fwd/bwd functions from apex. - No need to call it here. + """LightningModule hook to do backward. + We want this to do nothing since we run backward in the fwd/bwd functions from apex. + No need to call it here. """ pass def optimizer_zero_grad(self, *args, **kwargs): - """ LightningModule hook to zero grad. 
- We want this to do nothing as we are zeroing grads during the training_step. + """LightningModule hook to zero grad. + We want this to do nothing as we are zeroing grads during the training_step. """ pass def _append_sequence_parallel_module_grads(self, module, grads): - """ Helper method for allreduce_sequence_parallel_gradients""" + """Helper method for allreduce_sequence_parallel_gradients""" for param in module.parameters(): sequence_parallel_param = getattr(param, 'sequence_parallel', False) @@ -381,7 +392,8 @@ def process_batch(batch): prompts, images = batch # DB has more dedicated structure for encoding, so we enable autocasting here as well with torch.cuda.amp.autocast( - self.autocast_dtype in (torch.half, torch.bfloat16), dtype=self.autocast_dtype, + self.autocast_dtype in (torch.half, torch.bfloat16), + dtype=self.autocast_dtype, ): images = images.cuda(non_blocking=True) @@ -412,7 +424,7 @@ def fwd_output_only_func(batch, model): return fwd_output_only_func def setup(self, stage=None): - """ PTL hook that is executed after DDP spawns. + """PTL hook that is executed after DDP spawns. We setup datasets here as megatron datasets require DDP to instantiate. See https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#setup for more information. Args: @@ -472,9 +484,9 @@ def setup_training_data(self, cfg): center_crop=cfg.center_crop, load_cache_latents=self.model.use_cached_latents, cached_instance_data_root=self.cfg.data.get("cached_instance_dir", None), - cached_reg_data_root=self.cfg.data.get("cached_reg_dir", None) - if self.cfg.with_prior_preservation - else None, + cached_reg_data_root=( + self.cfg.data.get("cached_reg_dir", None) if self.cfg.with_prior_preservation else None + ), vae=self.model.vae, text_encoder=self.model.text_encoder, ) @@ -505,16 +517,16 @@ def setup_test_data(self, cfg): pass def transfer_batch_to_device(self, batch: Any, device: torch.device, dataloader_idx: int) -> Any: - """ PTL hook: https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#transfer-batch-to-device - When using pipeline parallelism, we need the global batch to remain on the CPU, - since the memory overhead will be too high when using a large number of microbatches. - Microbatches are transferred from CPU to GPU inside the pipeline. + """PTL hook: https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#transfer-batch-to-device + When using pipeline parallelism, we need the global batch to remain on the CPU, + since the memory overhead will be too high when using a large number of microbatches. + Microbatches are transferred from CPU to GPU inside the pipeline. """ return batch def _validate_trainer(self): - """ Certain trainer configurations can break training. - Here we try to catch them and raise an error. + """Certain trainer configurations can break training. + Here we try to catch them and raise an error. 
""" if self.trainer.accumulate_grad_batches > 1: raise ValueError( diff --git a/nemo/collections/multimodal/models/text_to_image/imagen/imagen.py b/nemo/collections/multimodal/models/text_to_image/imagen/imagen.py index 4fa6cd230e03..b7cf6d629d65 100644 --- a/nemo/collections/multimodal/models/text_to_image/imagen/imagen.py +++ b/nemo/collections/multimodal/models/text_to_image/imagen/imagen.py @@ -34,7 +34,6 @@ try: from apex import amp - from apex.transformer.pipeline_parallel.utils import get_num_microbatches HAVE_APEX = True except (ImportError, ModuleNotFoundError): @@ -49,6 +48,13 @@ except (ImportError, ModuleNotFoundError): HAVE_MEGATRON_CORE = False +try: + from megatron.core.num_microbatches_calculator import get_num_microbatches + +except (ImportError, ModuleNotFoundError): + logging.warning("Megatron num_microbatches_calculator not found, using Apex version.") + from apex.transformer.pipeline_parallel.utils import get_num_microbatches + try: from apex.contrib.group_norm import GroupNorm @@ -218,8 +224,8 @@ def model_provider_func(self, pre_process=True, post_process=True): def get_forward_output_and_loss_func(self): def process_batch(batch): - """ Prepares the batch for megatron fwd/bwd functions. - Global batch is a list of micro batches. + """Prepares the batch for megatron fwd/bwd functions. + Global batch is a list of micro batches. """ # Base model and SR models have slightly different batch input: # Base model would only require images (64x64), @@ -323,7 +329,10 @@ def setup_test_data(self, cfg): f'Setting up test dataloader with len(len(self._test_ds)): {len(self._test_ds)} and consumed samples: {consumed_samples}' ) self._test_dl = torch.utils.data.DataLoader( - self._test_ds, batch_size=self._micro_batch_size, num_workers=cfg.num_workers, pin_memory=True, + self._test_ds, + batch_size=self._micro_batch_size, + num_workers=cfg.num_workers, + pin_memory=True, ) def fwd_bwd_step(self, dataloader_iter, forward_only): @@ -332,7 +341,10 @@ def fwd_bwd_step(self, dataloader_iter, forward_only): # handle asynchronous grad reduction no_sync_func = None if not forward_only and self.with_distributed_adam: - no_sync_func = partial(self._optimizer.no_sync, greedy_grad_copy=self.megatron_amp_O2,) + no_sync_func = partial( + self._optimizer.no_sync, + greedy_grad_copy=self.megatron_amp_O2, + ) # pipeline schedules will get these from self.model.config for module in self.get_module_list(): @@ -379,12 +391,12 @@ def fwd_bwd_step(self, dataloader_iter, forward_only): def training_step(self, dataloader_iter): """ - Our dataloaders produce a micro-batch and then we fetch - a number of microbatches depending on the global batch size and model parallel size - from the dataloader to produce a list of microbatches. - Batch should be a list of microbatches and those microbatches should on CPU. - Microbatches are then moved to GPU during the pipeline. - The list of microbatches is then piped through the pipeline using Apex fwd/bwd functions. + Our dataloaders produce a micro-batch and then we fetch + a number of microbatches depending on the global batch size and model parallel size + from the dataloader to produce a list of microbatches. + Batch should be a list of microbatches and those microbatches should on CPU. + Microbatches are then moved to GPU during the pipeline. + The list of microbatches is then piped through the pipeline using Apex fwd/bwd functions. 
""" # we zero grads here because we also call backward in the megatron-core fwd/bwd functions @@ -434,20 +446,20 @@ def training_step(self, dataloader_iter): return loss_mean def backward(self, *args, **kwargs): - """ LightningModule hook to do backward. - We want this to do nothing since we run backward in the fwd/bwd functions from apex. - No need to call it here. + """LightningModule hook to do backward. + We want this to do nothing since we run backward in the fwd/bwd functions from apex. + No need to call it here. """ pass def optimizer_zero_grad(self, *args, **kwargs): - """ LightningModule hook to zero grad. - We want this to do nothing as we are zeroing grads during the training_step. + """LightningModule hook to zero grad. + We want this to do nothing as we are zeroing grads during the training_step. """ pass def _append_sequence_parallel_module_grads(self, module, grads): - """ Helper method for allreduce_sequence_parallel_gradients""" + """Helper method for allreduce_sequence_parallel_gradients""" for param in module.parameters(): sequence_parallel_param = getattr(param, 'sequence_parallel', False) @@ -460,10 +472,10 @@ def _append_sequence_parallel_module_grads(self, module, grads): def validation_step(self, dataloader_iter): """ - Our dataloaders produce a micro-batch and then we fetch - a number of microbatches depending on the global batch size and model parallel size - from the dataloader to produce a list of microbatches. - The list of microbatches is then piped through the pipeline using megatron-core fwd/bwd functions. """ + Our dataloaders produce a micro-batch and then we fetch + a number of microbatches depending on the global batch size and model parallel size + from the dataloader to produce a list of microbatches. + The list of microbatches is then piped through the pipeline using megatron-core fwd/bwd functions.""" loss, val_loss_dict = self.fwd_bwd_step(dataloader_iter, True) @@ -471,7 +483,7 @@ def validation_step(self, dataloader_iter): return loss def setup(self, stage=None): - """ PTL hook that is executed after DDP spawns. + """PTL hook that is executed after DDP spawns. We setup datasets here as megatron datasets require DDP to instantiate. See https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#setup for more information. Args: @@ -520,16 +532,16 @@ def setup(self, stage=None): self.model.setup_rng() def transfer_batch_to_device(self, batch: Any, device: torch.device, dataloader_idx: int) -> Any: - """ PTL hook: https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#transfer-batch-to-device - When using pipeline parallelism, we need the global batch to remain on the CPU, - since the memory overhead will be too high when using a large number of microbatches. - Microbatches are transferred from CPU to GPU inside the pipeline. + """PTL hook: https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#transfer-batch-to-device + When using pipeline parallelism, we need the global batch to remain on the CPU, + since the memory overhead will be too high when using a large number of microbatches. + Microbatches are transferred from CPU to GPU inside the pipeline. """ return batch def _validate_trainer(self): - """ Certain trainer configurations can break training. - Here we try to catch them and raise an error. + """Certain trainer configurations can break training. + Here we try to catch them and raise an error. 
""" if self.trainer.accumulate_grad_batches > 1: raise ValueError( @@ -558,7 +570,10 @@ def on_load_checkpoint(self, checkpoint) -> None: inductor_enabled = self.cfg.get('inductor', False) state_dict = checkpoint['state_dict'] inductor_checkpoint = False - for k, v, in state_dict.items(): + for ( + k, + v, + ) in state_dict.items(): if '_orig_mod' in k: inductor_checkpoint = True break diff --git a/nemo/collections/multimodal/models/text_to_image/stable_diffusion/diffusion_engine.py b/nemo/collections/multimodal/models/text_to_image/stable_diffusion/diffusion_engine.py index efc1550113a0..77a8caa58b40 100644 --- a/nemo/collections/multimodal/models/text_to_image/stable_diffusion/diffusion_engine.py +++ b/nemo/collections/multimodal/models/text_to_image/stable_diffusion/diffusion_engine.py @@ -52,15 +52,6 @@ from nemo.core.classes.mixins.adapter_mixins import AdapterModuleMixin from nemo.utils import logging, model_utils -try: - from apex import amp - from apex.transformer.enums import AttnMaskType - from apex.transformer.pipeline_parallel.utils import get_num_microbatches - - HAVE_APEX = True -except (ImportError, ModuleNotFoundError): - HAVE_APEX = False - try: from megatron.core import parallel_state from megatron.core.pipeline_parallel.schedules import get_forward_backward_func @@ -71,6 +62,13 @@ HAVE_MEGATRON_CORE = False +try: + from megatron.core.num_microbatches_calculator import get_num_microbatches + +except (ImportError, ModuleNotFoundError): + logging.warning("Megatron num_microbatches_calculator not found, using Apex version.") + from apex.transformer.pipeline_parallel.utils import get_num_microbatches + UNCONDITIONAL_CONFIG = { "target": "sgm.modules.GeneralConditioner", "params": {"emb_models": []}, @@ -119,7 +117,9 @@ def __init__(self, cfg, model_parallel_config): self._init_first_stage(first_stage_config) self.model_type = None - self.rng = torch.Generator(device=torch.cuda.current_device(),) + self.rng = torch.Generator( + device=torch.cuda.current_device(), + ) self.use_ema = False # TODO use_ema need to switch to NeMo style if self.use_ema: @@ -158,6 +158,13 @@ def decode_first_stage(self, z): out = self.first_stage_model.decode(z) return out + # same as above but differentiable + def differentiable_decode_first_stage(self, z): + z = 1.0 / self.scale_factor * z + with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast): + out = self.first_stage_model.decode(z) + return out + @torch.no_grad() def encode_first_stage(self, x): with torch.autocast("cuda", enabled=not self.disable_first_stage_autocast): @@ -185,7 +192,12 @@ def training_step(self, batch, batch_idx): self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=False) self.log( - "global_step", self.global_step, prog_bar=True, logger=True, on_step=True, on_epoch=False, + "global_step", + self.global_step, + prog_bar=True, + logger=True, + on_step=True, + on_epoch=False, ) if self.scheduler_config is not None: @@ -231,7 +243,11 @@ def configure_optimizers(self): scheduler = DiffusionEngine.from_config_dict(self.scheduler_config) print("Setting up LambdaLR scheduler...") scheduler = [ - {"scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule), "interval": "step", "frequency": 1,} + { + "scheduler": LambdaLR(opt, lr_lambda=scheduler.schedule), + "interval": "step", + "frequency": 1, + } ] return [opt], scheduler return opt @@ -291,7 +307,14 @@ def set_input_tensor(self, input_tensor): pass @torch.no_grad() - def log_images(self, batch: Dict, N: int = 8, sample: bool = True, 
ucg_keys: List[str] = None, **kwargs,) -> Dict: + def log_images( + self, + batch: Dict, + N: int = 8, + sample: bool = True, + ucg_keys: List[str] = None, + **kwargs, + ) -> Dict: conditioner_input_keys = [e.input_key for e in self.conditioner.embedders] if ucg_keys: assert all(map(lambda x: x in conditioner_input_keys, ucg_keys)), ( @@ -305,7 +328,8 @@ def log_images(self, batch: Dict, N: int = 8, sample: bool = True, ucg_keys: Lis x = self.get_input(batch) c, uc = self.conditioner.get_unconditional_conditioning( - batch, force_uc_zero_embeddings=ucg_keys if len(self.conditioner.embedders) > 0 else [], + batch, + force_uc_zero_embeddings=ucg_keys if len(self.conditioner.embedders) > 0 else [], ) sampling_kwargs = {} @@ -333,10 +357,6 @@ class MegatronDiffusionEngine(NLPAdapterModelMixin, MegatronBaseModel): """Megatron DiffusionEngine Model.""" def __init__(self, cfg: DictConfig, trainer: Trainer): - if not HAVE_APEX: - raise ImportError( - "Apex was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt." - ) if not HAVE_MEGATRON_CORE: raise ImportError( "megatron-core was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt." @@ -400,7 +420,10 @@ def fwd_bwd_step(self, dataloader_iter, forward_only): # handle asynchronous grad reduction no_sync_func = None if not forward_only and self.with_distributed_adam: - no_sync_func = partial(self._optimizer.no_sync, greedy_grad_copy=self.megatron_amp_O2,) + no_sync_func = partial( + self._optimizer.no_sync, + greedy_grad_copy=self.megatron_amp_O2, + ) # pipeline schedules will get these from self.model.config for module in self.get_module_list(): @@ -438,12 +461,12 @@ def fwd_bwd_step(self, dataloader_iter, forward_only): def training_step(self, dataloader_iter): """ - Our dataloaders produce a micro-batch and then we fetch - a number of microbatches depending on the global batch size and model parallel size - from the dataloader to produce a list of microbatches. - Batch should be a list of microbatches and those microbatches should on CPU. - Microbatches are then moved to GPU during the pipeline. - The list of microbatches is then piped through the pipeline using Apex fwd/bwd functions. + Our dataloaders produce a micro-batch and then we fetch + a number of microbatches depending on the global batch size and model parallel size + from the dataloader to produce a list of microbatches. + Batch should be a list of microbatches and those microbatches should on CPU. + Microbatches are then moved to GPU during the pipeline. + The list of microbatches is then piped through the pipeline using Apex fwd/bwd functions. """ self._optimizer.zero_grad() @@ -491,20 +514,20 @@ def training_step(self, dataloader_iter): return loss_mean def backward(self, *args, **kwargs): - """ LightningModule hook to do backward. - We want this to do nothing since we run backward in the fwd/bwd functions from apex. - No need to call it here. + """LightningModule hook to do backward. + We want this to do nothing since we run backward in the fwd/bwd functions from apex. + No need to call it here. """ pass def optimizer_zero_grad(self, *args, **kwargs): - """ LightningModule hook to zero grad. - We want this to do nothing as we are zeroing grads during the training_step. + """LightningModule hook to zero grad. + We want this to do nothing as we are zeroing grads during the training_step. 
""" pass def _append_sequence_parallel_module_grads(self, module, grads): - """ Helper method for allreduce_sequence_parallel_gradients""" + """Helper method for allreduce_sequence_parallel_gradients""" for param in module.parameters(): sequence_parallel_param = getattr(param, 'sequence_parallel', False) @@ -517,12 +540,13 @@ def _append_sequence_parallel_module_grads(self, module, grads): def get_forward_output_and_loss_func(self): def process_batch(batch): - """ Prepares the global batch for apex fwd/bwd functions. - Global batch is a list of micro batches. + """Prepares the global batch for apex fwd/bwd functions. + Global batch is a list of micro batches. """ # SD has more dedicated structure for encoding, so we enable autocasting here as well with torch.cuda.amp.autocast( - self.autocast_dtype in (torch.half, torch.bfloat16), dtype=self.autocast_dtype, + self.autocast_dtype in (torch.half, torch.bfloat16), + dtype=self.autocast_dtype, ): if self.model.precache_mode == 'both': x = batch[self.model.input_key].to(torch.cuda.current_device()) @@ -565,7 +589,7 @@ def validation_step(self, dataloader_iter, batch_idx): return loss def setup(self, stage=None): - """ PTL hook that is executed after DDP spawns. + """PTL hook that is executed after DDP spawns. We setup datasets here as megatron datasets require DDP to instantiate. See https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#setup for more information. Args: @@ -678,20 +702,23 @@ def setup_test_data(self, cfg): f'Setting up test dataloader with len(len(self._test_ds)): {len(self._test_ds)} and consumed samples: {consumed_samples}' ) self._test_dl = torch.utils.data.DataLoader( - self._test_ds, batch_size=self._micro_batch_size, num_workers=cfg.num_workers, pin_memory=True, + self._test_ds, + batch_size=self._micro_batch_size, + num_workers=cfg.num_workers, + pin_memory=True, ) def transfer_batch_to_device(self, batch: Any, device: torch.device, dataloader_idx: int) -> Any: - """ PTL hook: https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#transfer-batch-to-device - When using pipeline parallelism, we need the global batch to remain on the CPU, - since the memory overhead will be too high when using a large number of microbatches. - Microbatches are transferred from CPU to GPU inside the pipeline. + """PTL hook: https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#transfer-batch-to-device + When using pipeline parallelism, we need the global batch to remain on the CPU, + since the memory overhead will be too high when using a large number of microbatches. + Microbatches are transferred from CPU to GPU inside the pipeline. """ return batch def _validate_trainer(self): - """ Certain trainer configurations can break training. - Here we try to catch them and raise an error. + """Certain trainer configurations can break training. + Here we try to catch them and raise an error. 
""" if self.trainer.accumulate_grad_batches > 1: raise ValueError( diff --git a/nemo/collections/multimodal/models/text_to_image/stable_diffusion/ldm/autoencoder.py b/nemo/collections/multimodal/models/text_to_image/stable_diffusion/ldm/autoencoder.py index 6bd47a78fbcf..d79d85c2e026 100644 --- a/nemo/collections/multimodal/models/text_to_image/stable_diffusion/ldm/autoencoder.py +++ b/nemo/collections/multimodal/models/text_to_image/stable_diffusion/ldm/autoencoder.py @@ -16,6 +16,7 @@ import pytorch_lightning as pl import torch import torch.nn.functional as F +from nemo.utils import logging try: from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer @@ -316,6 +317,7 @@ def __init__( ignore_keys=[], image_key="image", colorize_nlabels=None, + from_NeMo=False, monitor=None, from_pretrained: str = None, ): @@ -337,6 +339,7 @@ def __init__( self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) if from_pretrained is not None: + logging.info(f"Attempting to load vae weights from {from_pretrained}") if from_pretrained.endswith('safetensors'): from safetensors.torch import load_file as load_safetensors @@ -345,7 +348,7 @@ def __init__( state_dict = torch.load(from_pretrained) if 'state_dict' in state_dict: state_dict = state_dict['state_dict'] - missing_key, unexpected_key, _, _ = self._load_pretrained_model(state_dict) + missing_key, unexpected_key, _, _ = self._load_pretrained_model(state_dict, from_NeMo=from_NeMo) if len(missing_key) > 0: print( f'{self.__class__.__name__}: Following keys are missing during loading VAE weights, which may lead to compromised image quality for a resumed training. Please check the checkpoint you provided.' @@ -395,8 +398,9 @@ def _state_key_mapping(self, state_dict: dict): res_dict[key_] = val_ return res_dict - def _load_pretrained_model(self, state_dict, ignore_mismatched_sizes=False): - state_dict = self._state_key_mapping(state_dict) + def _load_pretrained_model(self, state_dict, ignore_mismatched_sizes=False, from_NeMo=False): + if not from_NeMo: + state_dict = self._state_key_mapping(state_dict) model_state_dict = self.state_dict() loaded_keys = [k for k in state_dict.keys()] expected_keys = list(model_state_dict.keys()) @@ -405,7 +409,10 @@ def _load_pretrained_model(self, state_dict, ignore_mismatched_sizes=False): unexpected_keys = list(set(loaded_keys) - set(expected_keys)) def _find_mismatched_keys( - state_dict, model_state_dict, loaded_keys, ignore_mismatched_sizes, + state_dict, + model_state_dict, + loaded_keys, + ignore_mismatched_sizes, ): mismatched_keys = [] if ignore_mismatched_sizes: @@ -440,7 +447,10 @@ def _find_mismatched_keys( if state_dict is not None: # Whole checkpoint mismatched_keys = _find_mismatched_keys( - state_dict, model_state_dict, original_loaded_keys, ignore_mismatched_sizes, + state_dict, + model_state_dict, + original_loaded_keys, + ignore_mismatched_sizes, ) error_msgs = self._load_state_dict_into_model(state_dict) return missing_keys, unexpected_keys, mismatched_keys, error_msgs diff --git a/nemo/collections/multimodal/models/text_to_image/stable_diffusion/ldm/ddpm.py b/nemo/collections/multimodal/models/text_to_image/stable_diffusion/ldm/ddpm.py index 6ea4314ab71f..89b1d88819b8 100644 --- a/nemo/collections/multimodal/models/text_to_image/stable_diffusion/ldm/ddpm.py +++ b/nemo/collections/multimodal/models/text_to_image/stable_diffusion/ldm/ddpm.py @@ -75,15 +75,6 @@ from nemo.core.classes.mixins.adapter_mixins import AdapterModuleMixin from nemo.utils import logging, model_utils -try: - from 
apex import amp - from apex.transformer.enums import AttnMaskType - from apex.transformer.pipeline_parallel.utils import get_num_microbatches - - HAVE_APEX = True -except (ImportError, ModuleNotFoundError): - HAVE_APEX = False - try: from megatron.core import parallel_state from megatron.core.pipeline_parallel.schedules import get_forward_backward_func @@ -94,6 +85,14 @@ HAVE_MEGATRON_CORE = False +try: + from megatron.core.num_microbatches_calculator import get_num_microbatches + +except (ImportError, ModuleNotFoundError): + logging.warning("Megatron num_microbatches_calculator not found, using Apex version.") + from apex.transformer.pipeline_parallel.utils import get_num_microbatches + + __conditioning_keys__ = {'concat': 'c_concat', 'crossattn': 'c_crossattn', 'adm': 'y'} @@ -163,7 +162,9 @@ def __init__(self, cfg): cuda_graph_enabled = cfg.get("capture_cudagraph_iters", -1) >= 0 if not cuda_graph_enabled: logging.info("Use custom random generator") - self.rng = torch.Generator(device=torch.cuda.current_device(),) + self.rng = torch.Generator( + device=torch.cuda.current_device(), + ) else: logging.info("Use system random generator since CUDA graph enabled") self.rng = None @@ -222,14 +223,12 @@ def register_schedule( ) if self.parameterization == "eps": - lvlb_weights = self.betas ** 2 / ( - 2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod) - ) + lvlb_weights = self.betas**2 / (2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)) elif self.parameterization == "x0": lvlb_weights = 0.5 * np.sqrt(torch.Tensor(alphas_cumprod)) / (2.0 * 1 - torch.Tensor(alphas_cumprod)) elif self.parameterization == "v": lvlb_weights = torch.ones_like( - self.betas ** 2 / (2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)) + self.betas**2 / (2 * self.posterior_variance * to_torch(alphas) * (1 - self.alphas_cumprod)) ) else: raise NotImplementedError("mu not supported") @@ -239,7 +238,13 @@ def register_schedule( assert not torch.isnan(self.lvlb_weights).all() def init_from_ckpt( - self, path, ignore_keys=list(), only_model=False, load_vae=True, load_unet=True, load_encoder=True, + self, + path, + ignore_keys=list(), + only_model=False, + load_vae=True, + load_unet=True, + load_encoder=True, ): pl_sd = torch.load(path, map_location="cpu") if "state_dict" in list(pl_sd.keys()): @@ -561,7 +566,11 @@ def __init__(self, cfg, model_parallel_config): load_encoder = True if cfg.get("load_encoder", None) is None else cfg.load_encoder self.init_from_ckpt( - ckpt_path, ignore_keys, load_vae=load_vae, load_unet=load_unet, load_encoder=load_encoder, + ckpt_path, + ignore_keys, + load_vae=load_vae, + load_unet=load_unet, + load_encoder=load_encoder, ) self.restarted_from_ckpt = True @@ -569,7 +578,9 @@ def __init__(self, cfg, model_parallel_config): self.first_stage_model = self.first_stage_model.to(memory_format=torch.channels_last) self.model = self.model.to(memory_format=torch.channels_last) - def make_cond_schedule(self,): + def make_cond_schedule( + self, + ): self.cond_ids = torch.full(size=(self.num_timesteps,), fill_value=self.num_timesteps - 1, dtype=torch.long) ids = torch.round(torch.linspace(0, self.num_timesteps - 1, self.num_timesteps_cond)).long() self.cond_ids[: self.num_timesteps_cond] = ids @@ -686,7 +697,9 @@ def delta_border(self, h, w): def get_weighting(self, h, w, Ly, Lx, device): weighting = self.delta_border(h, w) weighting = torch.clip( - weighting, self.split_input_params["clip_min_weight"], 
self.split_input_params["clip_max_weight"], + weighting, + self.split_input_params["clip_min_weight"], + self.split_input_params["clip_max_weight"], ) weighting = weighting.view(1, h * w, 1).repeat(1, 1, Ly * Lx).to(device) @@ -1322,9 +1335,11 @@ def progressive_denoising( if cond is not None: if isinstance(cond, dict): cond = { - key: cond[key][:batch_size] - if not isinstance(cond[key], list) - else list(map(lambda x: x[:batch_size], cond[key])) + key: ( + cond[key][:batch_size] + if not isinstance(cond[key], list) + else list(map(lambda x: x[:batch_size], cond[key])) + ) for key in cond } else: @@ -1458,9 +1473,11 @@ def sample( if cond is not None: if isinstance(cond, dict): cond = { - key: cond[key][:batch_size] - if not isinstance(cond[key], list) - else list(map(lambda x: x[:batch_size], cond[key])) + key: ( + cond[key][:batch_size] + if not isinstance(cond[key], list) + else list(map(lambda x: x[:batch_size], cond[key])) + ) for key in cond } else: @@ -1656,10 +1673,6 @@ class MegatronLatentDiffusion(NLPAdapterModelMixin, MegatronBaseModel): """Megatron LatentDiffusion Model.""" def __init__(self, cfg: DictConfig, trainer: Trainer): - if not HAVE_APEX: - raise ImportError( - "Apex was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt." - ) if not HAVE_MEGATRON_CORE: raise ImportError( "megatron-core was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt." @@ -1731,7 +1744,10 @@ def fwd_bwd_step(self, dataloader_iter, forward_only): # handle asynchronous grad reduction no_sync_func = None if not forward_only and self.with_distributed_adam: - no_sync_func = partial(self._optimizer.no_sync, greedy_grad_copy=self.megatron_amp_O2,) + no_sync_func = partial( + self._optimizer.no_sync, + greedy_grad_copy=self.megatron_amp_O2, + ) # pipeline schedules will get these from self.model.config for module in self.get_module_list(): @@ -1779,29 +1795,31 @@ def fwd_bwd_step(self, dataloader_iter, forward_only): if self.loss_broadcast_src_rank is None: self.loss_broadcast_src_rank = parallel_state.get_pipeline_model_parallel_last_rank() torch.distributed.broadcast( - loss_mean, self.loss_broadcast_src_rank, group=parallel_state.get_pipeline_model_parallel_group(), + loss_mean, + self.loss_broadcast_src_rank, + group=parallel_state.get_pipeline_model_parallel_group(), ) return loss_mean, loss_dict def training_step(self, batch): """ - Notice: `training_step` used to have the following signature to support pipeline - parallelism: - - def training_step(self, dataloader_iter, batch_idx): - - However, full iteration CUDA Graph callback is not compatible with this signature - right now, due to we need to wrap the dataloader to generate static tensor outside - the CUDA Graph. This signature moves `next(dataloader)` into the CUDA Graph - capturing region, thus we disabled it. - - Our dataloaders produce a micro-batch and then we fetch - a number of microbatches depending on the global batch size and model parallel size - from the dataloader to produce a list of microbatches. - Batch should be a list of microbatches and those microbatches should on CPU. - Microbatches are then moved to GPU during the pipeline. - The list of microbatches is then piped through the pipeline using Apex fwd/bwd functions. 
+ Notice: `training_step` used to have the following signature to support pipeline + parallelism: + + def training_step(self, dataloader_iter, batch_idx): + + However, full iteration CUDA Graph callback is not compatible with this signature + right now, due to we need to wrap the dataloader to generate static tensor outside + the CUDA Graph. This signature moves `next(dataloader)` into the CUDA Graph + capturing region, thus we disabled it. + + Our dataloaders produce a micro-batch and then we fetch + a number of microbatches depending on the global batch size and model parallel size + from the dataloader to produce a list of microbatches. + Batch should be a list of microbatches and those microbatches should on CPU. + Microbatches are then moved to GPU during the pipeline. + The list of microbatches is then piped through the pipeline using Apex fwd/bwd functions. """ # we zero grads here because we also call backward in the megatron-core fwd/bwd functions @@ -1875,20 +1893,20 @@ def non_cuda_graph_capturable(self): self.log("timestamp", ts, batch_size=1, rank_zero_only=True) def backward(self, *args, **kwargs): - """ LightningModule hook to do backward. - We want this to do nothing since we run backward in the fwd/bwd functions from apex. - No need to call it here. + """LightningModule hook to do backward. + We want this to do nothing since we run backward in the fwd/bwd functions from apex. + No need to call it here. """ pass def optimizer_zero_grad(self, *args, **kwargs): - """ LightningModule hook to zero grad. - We want this to do nothing as we are zeroing grads during the training_step. + """LightningModule hook to zero grad. + We want this to do nothing as we are zeroing grads during the training_step. """ pass def _append_sequence_parallel_module_grads(self, module, grads): - """ Helper method for allreduce_sequence_parallel_gradients""" + """Helper method for allreduce_sequence_parallel_gradients""" for param in module.parameters(): sequence_parallel_param = getattr(param, 'sequence_parallel', False) @@ -1901,8 +1919,8 @@ def _append_sequence_parallel_module_grads(self, module, grads): def get_forward_output_and_loss_func(self): def process_batch(batch): - """ Prepares the global batch for apex fwd/bwd functions. - Global batch is a list of micro batches. + """Prepares the global batch for apex fwd/bwd functions. + Global batch is a list of micro batches. """ # noise_map, condition batch[self.cfg.first_stage_key] = batch[self.cfg.first_stage_key].cuda(non_blocking=True) @@ -1912,7 +1930,8 @@ def process_batch(batch): # SD has more dedicated structure for encoding, so we enable autocasting here as well with torch.cuda.amp.autocast( - self.autocast_dtype in (torch.half, torch.bfloat16), dtype=self.autocast_dtype, + self.autocast_dtype in (torch.half, torch.bfloat16), + dtype=self.autocast_dtype, ): x, c = self.model.get_input(batch, self.cfg.first_stage_key) @@ -1959,7 +1978,7 @@ def validation_step(self, dataloader_iter): return loss def setup(self, stage=None): - """ PTL hook that is executed after DDP spawns. + """PTL hook that is executed after DDP spawns. We setup datasets here as megatron datasets require DDP to instantiate. See https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#setup for more information. 
Args: @@ -2016,11 +2035,13 @@ def build_train_valid_test_datasets(self): if self.cfg.first_stage_key.endswith("encoded") or self.cfg.first_stage_key.endswith("moments"): if self.cfg.cond_stage_key.endswith("clip_encoded"): self._train_ds, self._validation_ds = build_train_valid_precached_clip_datasets( - model_cfg=self.cfg, consumed_samples=self.compute_consumed_samples(0), + model_cfg=self.cfg, + consumed_samples=self.compute_consumed_samples(0), ) else: self._train_ds, self._validation_ds = build_train_valid_precached_datasets( - model_cfg=self.cfg, consumed_samples=self.compute_consumed_samples(0), + model_cfg=self.cfg, + consumed_samples=self.compute_consumed_samples(0), ) else: self._train_ds, self._validation_ds = build_train_valid_datasets( @@ -2045,7 +2066,8 @@ def setup_training_data(self, cfg): ) if self.cfg.cond_stage_key.endswith("clip_encoded"): collate_fn = get_collate_fn( - first_stage_key=self.cfg.first_stage_key, cond_stage_key=self.cfg.cond_stage_key, + first_stage_key=self.cfg.first_stage_key, + cond_stage_key=self.cfg.cond_stage_key, ) else: collate_fn = None @@ -2082,20 +2104,23 @@ def setup_test_data(self, cfg): f'Setting up test dataloader with len(len(self._test_ds)): {len(self._test_ds)} and consumed samples: {consumed_samples}' ) self._test_dl = torch.utils.data.DataLoader( - self._test_ds, batch_size=self._micro_batch_size, num_workers=cfg.num_workers, pin_memory=True, + self._test_ds, + batch_size=self._micro_batch_size, + num_workers=cfg.num_workers, + pin_memory=True, ) def transfer_batch_to_device(self, batch: Any, device: torch.device, dataloader_idx: int) -> Any: - """ PTL hook: https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#transfer-batch-to-device - When using pipeline parallelism, we need the global batch to remain on the CPU, - since the memory overhead will be too high when using a large number of microbatches. - Microbatches are transferred from CPU to GPU inside the pipeline. + """PTL hook: https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#transfer-batch-to-device + When using pipeline parallelism, we need the global batch to remain on the CPU, + since the memory overhead will be too high when using a large number of microbatches. + Microbatches are transferred from CPU to GPU inside the pipeline. """ return batch def _validate_trainer(self): - """ Certain trainer configurations can break training. - Here we try to catch them and raise an error. + """Certain trainer configurations can break training. + Here we try to catch them and raise an error. """ if self.trainer.accumulate_grad_batches > 1: raise ValueError( @@ -2253,23 +2278,26 @@ def _check_and_add_adapter(self, name, module, peft_name, peft_cfg, name_key_to_ ) def load_adapters( - self, filepath: str, peft_cfgs: Optional[Union[PEFTConfig, List[PEFTConfig]]] = None, map_location: str = None, + self, + filepath: str, + peft_cfgs: Optional[Union[PEFTConfig, List[PEFTConfig]]] = None, + map_location: str = None, ): """ - Utility method that restores only the adapter module(s), and not the entire model itself. - This allows the sharing of adapters which are often just a fraction of the size of the full model, - enabling easier deliver. + Utility method that restores only the adapter module(s), and not the entire model itself. + This allows the sharing of adapters which are often just a fraction of the size of the full model, + enabling easier deliver. - .. note:: + .. 
note:: - During restoration, assumes that the model does not currently already have one or more adapter modules. + During restoration, assumes that the model does not currently already have one or more adapter modules. - Args: - filepath: Filepath of the .ckpt or .nemo file. - peft_cfgs: One or more PEFTConfig objects that specify the PEFT method configuration. - If none, will infer from the .nemo checkpoint - map_location: Pytorch flag, where to place the adapter(s) state dict(s). - """ + Args: + filepath: Filepath of the .ckpt or .nemo file. + peft_cfgs: One or more PEFTConfig objects that specify the PEFT method configuration. + If none, will infer from the .nemo checkpoint + map_location: Pytorch flag, where to place the adapter(s) state dict(s). + """ def _modify_state_dict(state_dict): # Modify state key for Dreambooth inference @@ -2310,7 +2338,11 @@ def _modify_state_dict(state_dict): class DiffusionWrapper(pl.LightningModule, Serialization): def __init__( - self, diff_model_config, conditioning_key, inductor: bool = False, inductor_cudagraphs: bool = False, + self, + diff_model_config, + conditioning_key, + inductor: bool = False, + inductor_cudagraphs: bool = False, ): super().__init__() self.diffusion_model = DiffusionWrapper.from_config_dict(diff_model_config) diff --git a/nemo/collections/multimodal/models/vision_language_foundation/clip/megatron_clip_models.py b/nemo/collections/multimodal/models/vision_language_foundation/clip/megatron_clip_models.py index a83960307672..99bb8a23cf47 100644 --- a/nemo/collections/multimodal/models/vision_language_foundation/clip/megatron_clip_models.py +++ b/nemo/collections/multimodal/models/vision_language_foundation/clip/megatron_clip_models.py @@ -54,7 +54,6 @@ try: from apex.transformer.enums import AttnMaskType - from apex.transformer.pipeline_parallel.utils import get_num_microbatches HAVE_APEX = True except (ImportError, ModuleNotFoundError): @@ -96,6 +95,13 @@ HAVE_MEGATRON_CORE = False +try: + from megatron.core.num_microbatches_calculator import get_num_microbatches + +except (ImportError, ModuleNotFoundError): + logging.warning("Megatron num_microbatches_calculator not found, using Apex version.") + from apex.transformer.pipeline_parallel.utils import get_num_microbatches + try: import transformer_engine from transformer_engine.pytorch import module as te_module @@ -501,8 +507,7 @@ def __init__( add_class_token = True vision_layer_spec = get_specs( model_cfg.text.get('name', ''), - vision_transformer_config.num_moe_experts, - vision_transformer_config.moe_grouped_gemm, + vision_transformer_config, model_cfg.get('transformer_engine', True), ) vision_layer_spec.submodules.self_attention.params['attn_mask_type'] = MCoreAttnMaskType.no_mask @@ -527,8 +532,7 @@ def __init__( config=text_transformer_config, transformer_layer_spec=get_specs( model_cfg.text.get('name', ''), - text_transformer_config.num_moe_experts, - text_transformer_config.moe_grouped_gemm, + text_transformer_config, model_cfg.get('transformer_engine', True), ), vocab_size=model_cfg.text.get('override_vocab_size', padded_vocab_size), @@ -984,6 +988,7 @@ def training_step(self, dataloader_iter): for module in modules: if isinstance(module, (Float16Module, MCoreFloat16Module)): module = module.module + module = module.text_encoder if not self.mcore_gpt: module = module.language_model if hasattr(module, 'embedding'): diff --git a/nemo/collections/multimodal/models/vision_language_foundation/megatron_nsfw_clip_models.py 
b/nemo/collections/multimodal/models/vision_language_foundation/megatron_nsfw_clip_models.py index 24c2bfc58be7..79c0f3910be0 100644 --- a/nemo/collections/multimodal/models/vision_language_foundation/megatron_nsfw_clip_models.py +++ b/nemo/collections/multimodal/models/vision_language_foundation/megatron_nsfw_clip_models.py @@ -19,7 +19,6 @@ import torch import torch.nn as nn import torch.nn.functional as F -from apex.transformer.pipeline_parallel.utils import get_num_microbatches from megatron.core import parallel_state from megatron.core.pipeline_parallel.schedules import get_forward_backward_func from omegaconf.dictconfig import DictConfig @@ -40,6 +39,14 @@ from nemo.utils import logging +try: + from megatron.core.num_microbatches_calculator import get_num_microbatches + +except (ImportError, ModuleNotFoundError): + logging.warning("Megatron num_microbatches_calculator not found, using Apex version.") + from apex.transformer.pipeline_parallel.utils import get_num_microbatches + + class ContentFilteringModel(MegatronModule): """Clip based content filtering model for NSFW.""" diff --git a/nemo/collections/multimodal/modules/imagen/diffusionmodules/blocks.py b/nemo/collections/multimodal/modules/imagen/diffusionmodules/blocks.py index 1d6b8395a58f..9a7f0a572743 100644 --- a/nemo/collections/multimodal/modules/imagen/diffusionmodules/blocks.py +++ b/nemo/collections/multimodal/modules/imagen/diffusionmodules/blocks.py @@ -58,10 +58,9 @@ def check_cuda(): dprops = th.cuda.get_device_properties(cur_device) is_sm75 = dprops.major == 7 and dprops.minor == 5 - is_sm8x = dprops.major == 8 and dprops.minor >= 0 - is_sm90 = dprops.major == 9 and dprops.minor >= 0 + is_sm8x_or_later = dprops.major >= 8 - return is_sm8x or is_sm75 or is_sm90 + return is_sm75 or is_sm8x_or_later try: @@ -154,7 +153,9 @@ def __init__( self.use_scale_shift_norm = use_scale_shift_norm self.in_layers = nn.Sequential( - normalization(channels), nn.SiLU(), conv_nd(dims, channels, self.out_channels, 3, padding=1), + normalization(channels), + nn.SiLU(), + conv_nd(dims, channels, self.out_channels, 3, padding=1), ) self.updown = up or down @@ -173,7 +174,11 @@ def __init__( self.h_upd = self.x_upd = nn.Identity() self.emb_layers = nn.Sequential( - nn.SiLU(), linear(emb_channels, 2 * self.out_channels if use_scale_shift_norm else self.out_channels,), + nn.SiLU(), + linear( + emb_channels, + 2 * self.out_channels if use_scale_shift_norm else self.out_channels, + ), ) self.out_layers = nn.Sequential( normalization(self.out_channels), @@ -263,7 +268,11 @@ def __init__( ) self.emb_layers = nn.Sequential( - nn.SiLU(), nn.Linear(emb_channels, 2 * out_channels if use_scale_shift_norm else out_channels,), + nn.SiLU(), + nn.Linear( + emb_channels, + 2 * out_channels if use_scale_shift_norm else out_channels, + ), ) self.out_layers = nn.Sequential( diff --git a/nemo/collections/multimodal/modules/stable_diffusion/attention.py b/nemo/collections/multimodal/modules/stable_diffusion/attention.py index 2eeed97db781..492f68af032e 100644 --- a/nemo/collections/multimodal/modules/stable_diffusion/attention.py +++ b/nemo/collections/multimodal/modules/stable_diffusion/attention.py @@ -56,10 +56,9 @@ def check_cuda(): dprops = torch.cuda.get_device_properties(cur_device) is_sm75 = dprops.major == 7 and dprops.minor == 5 - is_sm8x = dprops.major == 8 and dprops.minor >= 0 - is_sm90 = dprops.major == 9 and dprops.minor >= 0 + is_sm8x_or_later = dprops.major >= 8 - return is_sm8x or is_sm75 or is_sm90 + return is_sm75 or is_sm8x_or_later try: @@ 
-227,6 +226,10 @@ def __init__(self, in_features, out_features, bias=True, lora_network_alpha=None def forward(self, x): mixed_x = super().forward(x) if self.is_adapter_available(): + # return this output if lora is not enabled + cfg = self.get_adapter_cfg(AdapterName.PARALLEL_LINEAR_ADAPTER) + if not cfg['enabled']: + return mixed_x lora_linear_adapter = self.get_adapter_module(AdapterName.PARALLEL_LINEAR_ADAPTER) lora_mixed_x = lora_linear_adapter(x) # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script. diff --git a/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/denoiser.py b/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/denoiser.py index df1f27449bd1..a358bb08f92d 100644 --- a/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/denoiser.py +++ b/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/denoiser.py @@ -33,13 +33,18 @@ def possibly_quantize_c_noise(self, c_noise): def w(self, sigma): return self.weighting(sigma) - def __call__(self, network, input, sigma, cond): + def __call__(self, network, input, sigma, cond, return_noise=False): sigma = self.possibly_quantize_sigma(sigma) sigma_shape = sigma.shape sigma = append_dims(sigma, input.ndim) c_skip, c_out, c_in, c_noise = self.scaling(sigma) c_noise = self.possibly_quantize_c_noise(c_noise.reshape(sigma_shape)) - return network(input * c_in, c_noise, cond) * c_out + input * c_skip + # predict noise from network + noise_pred = network(input * c_in, c_noise, cond) + denoised = noise_pred * c_out + input * c_skip + if return_noise: + return denoised, noise_pred + return denoised class DiscreteDenoiser(Denoiser): diff --git a/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/openaimodel.py b/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/openaimodel.py index 7f8b2fb20bff..b94624b33ba2 100644 --- a/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/openaimodel.py +++ b/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/openaimodel.py @@ -789,6 +789,7 @@ def __init__( self.input_blocks.append(TimestepEmbedSequential(*layers)) self._feature_size += ch input_block_chans.append(ch) + if level != len(channel_mult) - 1: out_ch = ch self.input_blocks.append( @@ -954,6 +955,7 @@ def __init__( ) if from_pretrained is not None: + logging.info(f"Attempting to load pretrained unet from {from_pretrained}") if from_pretrained.endswith('safetensors'): from safetensors.torch import load_file as load_safetensors @@ -969,6 +971,8 @@ def __init__( ) logging.info(f"Missing keys: {missing_key}") logging.info(f"Unexpected keys: {unexpected_keys}") + else: + logging.info(f"There are no missing keys, model loaded properly!") if unet_precision == "fp16-mixed": # AMP O2 self.convert_to_fp16() @@ -1021,6 +1025,16 @@ def _input_blocks_mapping(self, input_dict): .replace('conv2', 'out_layers.3') .replace('conv_shortcut', 'skip_connection') ) + ## Rohit: I've changed this to make sure it is compatible + # post_fix = ( + # key_[25:] + # .replace('time_emb_proj', 'emb_layers.1') + # .replace('norm1', 'in_layers.0') + # .replace('norm2', 'out_layers.0') + # .replace('conv1', 'in_layers.1') + # .replace('conv2', 'out_layers.2') + # .replace('conv_shortcut', 'skip_connection') + # ) res_dict["input_blocks." + str(target_id) + '.0.' 
+ post_fix] = value_ elif "attentions" in key_: id_1 = int(key_[26]) @@ -1168,7 +1182,7 @@ def te_fp8_key_mapping(self, unet_dict): return new_state_dict def _state_key_mapping(self, state_dict: dict): - + # state_dict is a HF model res_dict = {} input_dict = {} mid_dict = {} @@ -1205,6 +1219,7 @@ def _state_key_mapping(self, state_dict: dict): def _load_pretrained_model(self, state_dict, ignore_mismatched_sizes=False, from_NeMo=False): state_dict = self._strip_unet_key_prefix(state_dict) if not from_NeMo: + logging.info("creating state key mapping from HF") state_dict = self._state_key_mapping(state_dict) state_dict = self._legacy_unet_ckpt_mapping(state_dict) diff --git a/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/sampling.py b/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/sampling.py index c636ffec345d..bfae8790eeb2 100644 --- a/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/sampling.py +++ b/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/sampling.py @@ -47,7 +47,12 @@ def __init__( ): self.num_steps = num_steps self.discretization = instantiate_from_config(discretization_config) - self.guider = instantiate_from_config(default(guider_config, DEFAULT_GUIDER,)) + self.guider = instantiate_from_config( + default( + guider_config, + DEFAULT_GUIDER, + ) + ) self.verbose = verbose self.device = device @@ -93,35 +98,50 @@ def euler_step(self, x, d, dt): class EDMSampler(SingleStepDiffusionSampler): def __init__(self, s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0, *args, **kwargs): super().__init__(*args, **kwargs) - self.s_churn = s_churn self.s_tmin = s_tmin self.s_tmax = s_tmax self.s_noise = s_noise - def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, gamma=0.0): + def sampler_step(self, sigma, next_sigma, denoiser, x, cond, uc=None, gamma=0.0, return_noise=False): + # x is actually \bar{x} as in the DDIM paper sigma_hat = sigma * (gamma + 1.0) if gamma > 0: eps = torch.randn_like(x) * self.s_noise - x = x + eps * append_dims(sigma_hat ** 2 - sigma ** 2, x.ndim) ** 0.5 + x = x + eps * append_dims(sigma_hat**2 - sigma**2, x.ndim) ** 0.5 denoised = self.denoise(x, denoiser, sigma_hat, cond, uc) + # this is the noise (e_t) d = to_d(x, sigma_hat, denoised) dt = append_dims(next_sigma - sigma_hat, x.ndim) - euler_step = self.euler_step(x, d, dt) + euler_step = self.euler_step(x, d, dt) # this is x_{t-\delta{t}} x = self.possible_correction_step(euler_step, x, d, dt, next_sigma, denoiser, cond, uc) + if return_noise: + return x, d return x + def get_gamma(self, sigmas, num_sigmas, index): + gamma = ( + min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1) if self.s_tmin <= sigmas[index] <= self.s_tmax else 0.0 + ) + return gamma + def __call__(self, denoiser, x, cond, uc=None, num_steps=None): + # prepare_sampling_loop converts x into \bar{x} = x / \sqrt{\tilde{\alpha_t}} x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps) for i in self.get_sigma_gen(num_sigmas): - gamma = ( - min(self.s_churn / (num_sigmas - 1), 2 ** 0.5 - 1) if self.s_tmin <= sigmas[i] <= self.s_tmax else 0.0 + gamma = self.get_gamma(sigmas, num_sigmas, i) + x = self.sampler_step( + s_in * sigmas[i], + s_in * sigmas[i + 1], + denoiser, + x, + cond, + uc, + gamma, ) - x = self.sampler_step(s_in * sigmas[i], s_in * sigmas[i + 1], denoiser, x, cond, uc, gamma,) - return x @@ -151,14 +171,24 @@ def __call__(self, denoiser, x, cond, uc=None, num_steps=None): x, s_in, 
sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(x, cond, uc, num_steps) for i in self.get_sigma_gen(num_sigmas): - x = self.sampler_step(s_in * sigmas[i], s_in * sigmas[i + 1], denoiser, x, cond, uc,) + x = self.sampler_step( + s_in * sigmas[i], + s_in * sigmas[i + 1], + denoiser, + x, + cond, + uc, + ) return x class LinearMultistepSampler(BaseDiffusionSampler): def __init__( - self, order=4, *args, **kwargs, + self, + order=4, + *args, + **kwargs, ): super().__init__(*args, **kwargs) @@ -276,7 +306,15 @@ def get_mult(self, h, r, t, t_next, previous_sigma): return mult1, mult2 def sampler_step( - self, old_denoised, previous_sigma, sigma, next_sigma, denoiser, x, cond, uc=None, + self, + old_denoised, + previous_sigma, + sigma, + next_sigma, + denoiser, + x, + cond, + uc=None, ): denoised = self.denoise(x, denoiser, sigma, cond, uc) diff --git a/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/wrappers.py b/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/wrappers.py index 0d465c1275c6..24e2124e6f83 100644 --- a/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/wrappers.py +++ b/nemo/collections/multimodal/modules/stable_diffusion/diffusionmodules/wrappers.py @@ -37,6 +37,11 @@ class OpenAIWrapper(IdentityWrapper): def forward(self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs) -> torch.Tensor: if c.get("concat", None): x = torch.cat((x, c.get("concat")), dim=1) + return self.diffusion_model( - x, timesteps=t, context=c.get("crossattn", None), y=c.get("vector", None), **kwargs, + x, + timesteps=t, + context=c.get("crossattn", None), + y=c.get("vector", None), + **kwargs, ) diff --git a/nemo/collections/multimodal/modules/stable_diffusion/encoders/modules.py b/nemo/collections/multimodal/modules/stable_diffusion/encoders/modules.py index ab33532c3c1f..0443d75a61e8 100644 --- a/nemo/collections/multimodal/modules/stable_diffusion/encoders/modules.py +++ b/nemo/collections/multimodal/modules/stable_diffusion/encoders/modules.py @@ -689,6 +689,8 @@ def load_model(self, cfg, state_dict): model_cfg=cfg, model_parallel_config=ModelParallelConfig(), padded_vocab_size=padded_vocab_size, + vision_transformer_config=None, # assumed mcore to be false + text_transformer_config=None, pre_process=cfg.text.pre_process, post_process=cfg.text.post_process, ) diff --git a/nemo/collections/multimodal/parts/utils.py b/nemo/collections/multimodal/parts/utils.py index 7eb72b38d0f0..ea8053398a88 100644 --- a/nemo/collections/multimodal/parts/utils.py +++ b/nemo/collections/multimodal/parts/utils.py @@ -15,7 +15,6 @@ import tempfile from typing import Any, Callable, Tuple -import decord import numpy as np import torch from omegaconf import DictConfig, OmegaConf, open_dict @@ -23,16 +22,21 @@ from pytorch_lightning import Trainer from pytorch_lightning.plugins.environments import TorchElasticEnvironment from transformers import CLIPImageProcessor, SiglipImageProcessor -from nemo.collections.multimodal.data.clip.augmentations.augmentations import image_transform +from nemo.collections.multimodal.data.clip.augmentations.augmentations import image_transform from nemo.collections.multimodal.data.neva.neva_dataset import process_image from nemo.collections.nlp.modules.common.megatron.megatron_init import fake_initialize_model_parallel -from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy, NLPSaveRestoreConnector +from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy, NLPFSDPStrategy, NLPSaveRestoreConnector from 
nemo.collections.nlp.parts.peft_config import PEFT_CONFIG_MAP from nemo.collections.nlp.parts.utils_funcs import torch_dtype_from_precision from nemo.utils import AppState, logging from nemo.utils.model_utils import inject_model_parallel_rank +try: + import decord +except Exception: + logging.warning("The package `decord` was not installed in this environment.") + try: from megatron.core import dist_checkpointing @@ -276,54 +280,69 @@ def setup_trainer_and_model_for_inference( # Use the NLPDDPStrategy for the distributed data parallel strategy. # We don't use DDP for async grad allreduce and don't find unused parameters. - strategy = NLPDDPStrategy( - no_ddp_communication_hook=True, - find_unused_parameters=False, - ) + if not cfg.model.get('fsdp', False): + logging.info("FSDP is False, using DDP strategy.") + strategy = NLPDDPStrategy( + no_ddp_communication_hook=True, + find_unused_parameters=False, + ) + else: + logging.info("Using FSDP strategy.") + strategy = NLPFSDPStrategy( + limit_all_gathers=cfg.model.get('fsdp_limit_all_gathers', True), + sharding_strategy=cfg.model.get('fsdp_sharding_strategy', 'full'), + cpu_offload=cfg.model.get('fsdp_cpu_offload', True), + grad_reduce_dtype=cfg.model.get('fsdp_grad_reduce_dtype', 32), + precision=cfg.trainer.precision, + # use_orig_params=cfg.model.inductor, + set_buffer_dtype=cfg.get('fsdp_set_buffer_dtype', None), + ) # Set up the trainer with the specified plugins and strategy. trainer = Trainer(plugins=plugins, strategy=strategy, **cfg.trainer) # Create the NLPSaveRestoreConnector object for model saving and restoring. save_restore_connector = NLPSaveRestoreConnector() + if cfg.model.restore_from_path is not None: + if cfg.model.restore_from_path.endswith(".nemo") or os.path.isdir(cfg.model.restore_from_path): + # Set the model_extracted_dir attribute if the restore path is a directory. + if os.path.isdir(cfg.model.restore_from_path): + save_restore_connector.model_extracted_dir = cfg.model.restore_from_path - if cfg.model.restore_from_path.endswith(".nemo") or os.path.isdir(cfg.model.restore_from_path): - # Set the model_extracted_dir attribute if the restore path is a directory. - if os.path.isdir(cfg.model.restore_from_path): - save_restore_connector.model_extracted_dir = cfg.model.restore_from_path - - # Restore the model configuration from the specified path and modify it for inference. - model_cfg = model_provider.restore_from( - restore_path=cfg.model.restore_from_path, - trainer=trainer, - save_restore_connector=save_restore_connector, - return_config=True, - ) - with open_dict(model_cfg): - model_cfg_modifier(model_cfg) # modify the configuration for inference + # Restore the model configuration from the specified path and modify it for inference. + model_cfg = model_provider.restore_from( + restore_path=cfg.model.restore_from_path, + trainer=trainer, + save_restore_connector=save_restore_connector, + return_config=True, + ) + with open_dict(model_cfg): + model_cfg_modifier(model_cfg) # modify the configuration for inference - # Restore the model from the specified path and configuration, and set it up for inference. - model = model_provider.restore_from( - restore_path=cfg.model.restore_from_path, - trainer=trainer, - override_config_path=model_cfg, - save_restore_connector=save_restore_connector, - strict=True, - ) + # Restore the model from the specified path and configuration, and set it up for inference. 
+ model = model_provider.restore_from( + restore_path=cfg.model.restore_from_path, + trainer=trainer, + override_config_path=model_cfg, + save_restore_connector=save_restore_connector, + strict=True, + ) - elif cfg.model.restore_from_path.endswith(".ckpt"): - logging.warning( - "Loading from .ckpt checkpoint for inference is experimental! It doesn't support models with model parallelism!" - ) + elif cfg.model.restore_from_path.endswith(".ckpt"): + logging.warning( + "Loading from .ckpt checkpoint for inference is experimental! It doesn't support models with model parallelism!" + ) - model = model_provider.load_from_checkpoint( - cfg.model.restore_from_path, - hparams_file=cfg.model.get("hparams_file"), - trainer=trainer, - ) + model = model_provider.load_from_checkpoint( + cfg.model.restore_from_path, + hparams_file=cfg.model.get("hparams_file"), + trainer=trainer, + ) else: - raise ValueError(f"Unrecognized checkpoint type: {cfg.model.restore_from_path}") + # load a model from scratch + logging.warning("Loading a model from scratch for inference. Tread carefully.") + model = model_provider(cfg=cfg.model, trainer=trainer) # initialize apex DDP strategy def dummy(): @@ -451,7 +470,6 @@ def image_processor(maybe_image_path): def video_processor(maybe_video_path): if isinstance(maybe_video_path, str): - decord.bridge.set_bridge("torch") vr = decord.VideoReader(maybe_video_path) if neva_cfg.data.splice_single_frame == 'first': frames = [Image.fromarray(vr[0].asnumpy()).convert('RGB')] @@ -465,20 +483,15 @@ def video_processor(maybe_video_path): else: num_frames = min(len(vr), neva_cfg.data.num_frames) indices = np.linspace(0, len(vr) - 1, num_frames, dtype=int) - frames = vr.get_batch(indices) - + frames = [Image.fromarray(vr[i].asnumpy()).convert('RGB') for i in indices] while len(frames) < neva_cfg.data.num_frames: frames.append(frames[-1]) else: frames = maybe_video_path - if neva_cfg.mm_cfg.vision_encoder.from_hf: - processor = CLIPImageProcessor.from_pretrained( - neva_cfg.mm_cfg.vision_encoder.from_pretrained, torch_dtype=torch.bfloat16 - ) - else: - processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-large-patch14", torch_dtype=torch.bfloat16) - + processor = ( + model.model.module.image_processor if hasattr(model.model, "module") else model.model.image_processor + ) # support single video inference if neva_cfg.data.image_aspect_ratio == 'keep': max_hw, min_hw = max(frames.size), min(frames.size) @@ -503,7 +516,7 @@ def expand2square(pil_img, background_color): result.paste(pil_img, ((height - width) // 2, 0)) return result - frames = [expand2square(frame, tuple(int(x * 255) for x in self.processor.image_mean)) for frame in frames] + frames = [expand2square(frame, tuple(int(x * 255) for x in processor.image_mean)) for frame in frames] frames = processor.preprocess(frames, return_tensors='pt')['pixel_values'] else: frames = processor.preprocess(frames, return_tensors='pt')['pixel_values'] @@ -516,11 +529,14 @@ def expand2square(pil_img, background_color): def create_image_processor(mm_cfg): if mm_cfg.vision_encoder.get("from_hf", False): - if "clip" in mm_cfg.vision_encoder.from_pretrained: + from transformers import AutoConfig + + config = AutoConfig.from_pretrained(mm_cfg.vision_encoder.from_pretrained) + if config.architectures[0] == "CLIPVisionModel" or config.architectures[0] == "CLIPModel": image_processor = CLIPImageProcessor.from_pretrained( mm_cfg.vision_encoder.from_pretrained, torch_dtype=torch.bfloat16 ) - elif "siglip" in mm_cfg.vision_encoder.from_pretrained: + 
elif config.architectures[0] == "SiglipVisionModel" or config.architectures[0] == "SiglipModel": image_processor = SiglipImageProcessor.from_pretrained( mm_cfg.vision_encoder.from_pretrained, torch_dtype=torch.bfloat16 ) diff --git a/nemo/collections/multimodal/speech_llm/models/__init__.py b/nemo/collections/multimodal/speech_llm/models/__init__.py index ec188828ec87..ee51bd94af2c 100644 --- a/nemo/collections/multimodal/speech_llm/models/__init__.py +++ b/nemo/collections/multimodal/speech_llm/models/__init__.py @@ -12,4 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from nemo.collections.multimodal.speech_llm.models.modular_models import ModularAudioGPTModel +from nemo.collections.multimodal.speech_llm.models.modular_models import ( + CrossAttendModularAudioGPTModel, + ModularAudioGPTModel, +) diff --git a/nemo/collections/multimodal/speech_llm/models/modular_models.py b/nemo/collections/multimodal/speech_llm/models/modular_models.py index cce74e7b6a1d..edabbfd82f87 100644 --- a/nemo/collections/multimodal/speech_llm/models/modular_models.py +++ b/nemo/collections/multimodal/speech_llm/models/modular_models.py @@ -57,13 +57,6 @@ from nemo.utils import AppState, logging, model_utils from nemo.utils.model_utils import inject_model_parallel_rank -try: - from apex.transformer.pipeline_parallel.utils import _reconfigure_microbatch_calculator, get_num_microbatches - - HAVE_APEX = True -except (ImportError, ModuleNotFoundError): - HAVE_APEX = False - try: from megatron.core import InferenceParams, parallel_state, tensor_parallel from megatron.core.models.gpt import GPTModel as MCoreGPTModel @@ -74,7 +67,17 @@ HAVE_MEGATRON_CORE = False -__all__ = ["ModularAudioGPTModel"] +try: + from megatron.core.num_microbatches_calculator import get_num_microbatches, reconfigure_num_microbatches_calculator + +except (ImportError, ModuleNotFoundError): + logging.warning("Megatron num_microbatches_calculator not found, using Apex version.") + from apex.transformer.pipeline_parallel.utils import ( + _reconfigure_microbatch_calculator as reconfigure_num_microbatches_calculator, + ) + from apex.transformer.pipeline_parallel.utils import get_num_microbatches + +__all__ = ["ModularAudioGPTModel", "CrossAttendModularAudioGPTModel"] default_inference_config = {'tokens_to_generate': 30} @@ -1196,7 +1199,7 @@ def predict_step(self, batch: dict, batch_idx: int, dataloader_idx: Optional[int response = generate(self, **inference_config) app_state = AppState() - _reconfigure_microbatch_calculator( + reconfigure_num_microbatches_calculator( rank=app_state.global_rank, rampup_batch_size=None, global_batch_size=global_batch_size_per_gpu * parallel_state.get_data_parallel_world_size(), @@ -1365,7 +1368,7 @@ def inference_epoch_end(self, outputs, mode, data_cfg): app_state = AppState() self._restore_activation_checkpointing_args() if hasattr(self, "_train_ds"): - _reconfigure_microbatch_calculator( + reconfigure_num_microbatches_calculator( rank=app_state.global_rank, rampup_batch_size=None, global_batch_size=self.cfg.data.train_ds.global_batch_size, @@ -1375,7 +1378,7 @@ def inference_epoch_end(self, outputs, mode, data_cfg): # When running `trainer.validate()`, the training dataset is not available. 
else: logging.warning('No training data found, reconfiguring microbatches based on validation batch sizes.') - _reconfigure_microbatch_calculator( + reconfigure_num_microbatches_calculator( rank=app_state.global_rank, rampup_batch_size=None, global_batch_size=data_cfg.global_batch_size, diff --git a/nemo/collections/multimodal/speech_llm/models/modular_t5_models.py b/nemo/collections/multimodal/speech_llm/models/modular_t5_models.py index a96ee823e197..fce31d031abd 100644 --- a/nemo/collections/multimodal/speech_llm/models/modular_t5_models.py +++ b/nemo/collections/multimodal/speech_llm/models/modular_t5_models.py @@ -38,6 +38,7 @@ MultiAudioPerceptionModule, ) from nemo.collections.nlp.models.language_modeling.megatron_t5_adapter_model import MegatronT5LoraModel +from nemo.collections.nlp.models.language_modeling.megatron_t5_model import MegatronT5Model from nemo.collections.nlp.models.language_modeling.megatron_t5_sft_model import MegatronT5SFTModel from nemo.collections.nlp.models.nlp_model import NLPModel from nemo.collections.nlp.modules.common.megatron.utils import ( @@ -50,19 +51,6 @@ from nemo.core.classes.mixins import adapter_mixins from nemo.utils import AppState, logging, model_utils -try: - from apex.transformer.pipeline_parallel.utils import ( - _reconfigure_microbatch_calculator, - get_current_global_batch_size, - get_micro_batch_size, - get_num_microbatches, - ) - - HAVE_APEX = True -except (ImportError, ModuleNotFoundError): - HAVE_APEX = False -from nemo.collections.nlp.models.language_modeling.megatron_t5_model import MegatronT5Model - try: from megatron.core import parallel_state, tensor_parallel from megatron.core.pipeline_parallel.schedules import get_forward_backward_func @@ -72,8 +60,27 @@ except (ImportError, ModuleNotFoundError): HAVE_MEGATRON_CORE = False +try: + from megatron.core.num_microbatches_calculator import ( + get_current_global_batch_size, + get_micro_batch_size, + get_num_microbatches, + reconfigure_num_microbatches_calculator, + ) + +except (ImportError, ModuleNotFoundError): + logging.warning("Megatron num_microbatches_calculator not found, using Apex version.") + from apex.transformer.pipeline_parallel.utils import ( + _reconfigure_microbatch_calculator as reconfigure_num_microbatches_calculator, + ) + from apex.transformer.pipeline_parallel.utils import ( + get_current_global_batch_size, + get_micro_batch_size, + get_num_microbatches, + ) + -__all__ = ["ModularizedAudioT5Model"] +__all__ = ["ModularizedAudioT5Model", "DecoderTextPromptModularizedAudioT5Model"] default_inference_config = {'tokens_to_generate': 30} @@ -815,7 +822,7 @@ def _reconfigure_and_process_inference_batch(self, batch, data_cfg): != data_cfg.global_batch_size // parallel_state.get_data_parallel_world_size() ): app_state = AppState() - _reconfigure_microbatch_calculator( + reconfigure_num_microbatches_calculator( rank=app_state.global_rank, rampup_batch_size=None, global_batch_size=global_batch_size_per_gpu * parallel_state.get_data_parallel_world_size(), @@ -825,7 +832,7 @@ def _reconfigure_and_process_inference_batch(self, batch, data_cfg): # NOTE: need to explicitly handle resetting for multi-validation else: app_state = AppState() - _reconfigure_microbatch_calculator( + reconfigure_num_microbatches_calculator( rank=app_state.global_rank, rampup_batch_size=None, global_batch_size=data_cfg.global_batch_size, @@ -1114,7 +1121,7 @@ def inference_epoch_end(self, outputs, mode, data_cfg): app_state = AppState() # TODO(zhehuai): add _restore_sequence_parallelism_args after sync 
to HEAD if hasattr(self, "_train_ds"): - _reconfigure_microbatch_calculator( + reconfigure_num_microbatches_calculator( rank=app_state.global_rank, rampup_batch_size=None, global_batch_size=self.cfg.data.train_ds.global_batch_size, @@ -1124,7 +1131,7 @@ def inference_epoch_end(self, outputs, mode, data_cfg): # When running `trainer.validate()`, the training dataset is not available. else: logging.warning('No training data found, reconfiguring microbatches based on validation batch sizes.') - _reconfigure_microbatch_calculator( + reconfigure_num_microbatches_calculator( rank=app_state.global_rank, rampup_batch_size=None, global_batch_size=data_cfg.global_batch_size, diff --git a/nemo/collections/multimodal/speech_llm/modules/__init__.py b/nemo/collections/multimodal/speech_llm/modules/__init__.py index d9562652ce84..7effb0894da7 100644 --- a/nemo/collections/multimodal/speech_llm/modules/__init__.py +++ b/nemo/collections/multimodal/speech_llm/modules/__init__.py @@ -17,4 +17,5 @@ AudioPerceptionModule, MultiAudioPerceptionModule, MultiFeatureAggregator, + TransformerCrossAttention, ) diff --git a/nemo/collections/multimodal/speech_llm/modules/common/__init__.py b/nemo/collections/multimodal/speech_llm/modules/common/__init__.py new file mode 100644 index 000000000000..8b137891791f --- /dev/null +++ b/nemo/collections/multimodal/speech_llm/modules/common/__init__.py @@ -0,0 +1 @@ + diff --git a/nemo/collections/multimodal/speech_llm/modules/common/audio_text_generation_utils.py b/nemo/collections/multimodal/speech_llm/modules/common/audio_text_generation_utils.py index 136418031586..4399c4174dd3 100644 --- a/nemo/collections/multimodal/speech_llm/modules/common/audio_text_generation_utils.py +++ b/nemo/collections/multimodal/speech_llm/modules/common/audio_text_generation_utils.py @@ -27,25 +27,25 @@ model_inference_strategy_dispatcher, ) from nemo.collections.nlp.modules.common.transformer.text_generation import OutputType -from nemo.utils import AppState +from nemo.utils import AppState, logging try: - from apex.transformer.pipeline_parallel.utils import _reconfigure_microbatch_calculator + from megatron.core import parallel_state, tensor_parallel - HAVE_APEX = True + HAVE_MEGATRON_CORE = True except (ImportError, ModuleNotFoundError): - HAVE_APEX = False + HAVE_MEGATRON_CORE = False try: - from megatron.core import parallel_state, tensor_parallel - - HAVE_MEGATRON_CORE = True + from megatron.core.num_microbatches_calculator import reconfigure_num_microbatches_calculator except (ImportError, ModuleNotFoundError): - - HAVE_MEGATRON_CORE = False + logging.warning("Megatron num_microbatches_calculator not found, using Apex version.") + from apex.transformer.pipeline_parallel.utils import ( + _reconfigure_microbatch_calculator as reconfigure_num_microbatches_calculator, + ) __all__ = [ "get_computeprob_response", @@ -520,7 +520,7 @@ def sample_sequence_batch( ): app_state = AppState() micro_batch_size = context_tokens.shape[0] - _reconfigure_microbatch_calculator( + reconfigure_num_microbatches_calculator( rank=app_state.global_rank, rampup_batch_size=None, global_batch_size=micro_batch_size, diff --git a/nemo/collections/multimodal/speech_llm/modules/perception_modules.py b/nemo/collections/multimodal/speech_llm/modules/perception_modules.py index a42c7d06cba0..021ac1ff3dad 100644 --- a/nemo/collections/multimodal/speech_llm/modules/perception_modules.py +++ b/nemo/collections/multimodal/speech_llm/modules/perception_modules.py @@ -29,7 +29,7 @@ from nemo.core.neural_types import 
AcousticEncodedRepresentation, AudioSignal, LengthsType, NeuralType, SpectrogramType from nemo.utils.decorators import experimental -__all__ = ["AudioPerceptionModule", "MultiAudioPerceptionModule"] +__all__ = ["AudioPerceptionModule", "MultiAudioPerceptionModule", "TransformerCrossAttention"] class AudioPerceptionModule(NeuralModule, Exportable): diff --git a/nemo/collections/nlp/data/information_retrieval/bert_embedding_dataset.py b/nemo/collections/nlp/data/information_retrieval/bert_embedding_dataset.py index 3c57b1af4cca..8bca618dce3d 100644 --- a/nemo/collections/nlp/data/information_retrieval/bert_embedding_dataset.py +++ b/nemo/collections/nlp/data/information_retrieval/bert_embedding_dataset.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from random import choices, sample from typing import Mapping, Optional import datasets @@ -50,7 +51,7 @@ def __init__( num_hard_negatives: int = 4, ): """ - file_path: Path to a JSONL dataset with (query,pos_doc,neg_doc) triplets in jsonl format. + file_path: Path to a JSONL dataset with (query,pos_doc,neg_doc) triplets in jsonl format. tokenizer: Tokenizer for the dataset. Instance of a class that inherits TokenizerSpec (ex: YTTM, SentencePiece). max_seq_length (int): maximum sequence length for each dataset examples. Examples will either be truncated to fit this length or dropped if they cannot be truncated. min_seq_length (int): min length of each data example in the dataset. Data examples will be dropped if they do not meet the min length requirements. @@ -132,7 +133,10 @@ def __getitem__(self, idx): if isinstance(idx, np.uint32): idx = idx.item() - assert idx < len(self.indexed_dataset) + if idx is not None: + assert idx < len(self.indexed_dataset) + else: + idx = -1 # idx may < 0 because we pad_samples_to_global_batch_size, e.g. 
id = -1 if idx < 0: idx = len(self) + idx @@ -159,10 +163,16 @@ def _process_example(self, example): if self.data_type == 'train': q = self.tokenizer.text_to_ids("query: " + example['query'].strip()) d = self.tokenizer.text_to_ids("passage: " + example['pos_doc'].strip()) - nd = [ - self.tokenizer.text_to_ids("passage: " + example['neg_doc'][i].strip()) - for i in range(self.num_hard_negatives) - ] + # handle cases where the required number of hard negatives are not present + if len(example['neg_doc']) < self.num_hard_negatives: + nd = example['neg_doc'] + # sample rest with replacement + nd = nd + choices(example['neg_doc'], k=self.num_hard_negatives - len(example['neg_doc'])) + else: + # sample without replacement + nd = sample(example['neg_doc'], k=self.num_hard_negatives) + assert len(nd) == self.num_hard_negatives, "Error in sampling required number of hard negatives" + nd = [self.tokenizer.text_to_ids("passage: " + ex.strip()) for ex in nd] elif self.data_type == 'query': q = self.tokenizer.text_to_ids("query: " + example['query'].strip()) @@ -292,6 +302,7 @@ def collate_fn(self, batch): 'input_ids': input_ids, 'token_type_ids': torch.zeros_like(input_ids), 'attention_mask': attention_mask, + 'metadata': metadata, } return processed_batch diff --git a/nemo/collections/nlp/data/information_retrieval/gpt_embedding_dataset.py b/nemo/collections/nlp/data/information_retrieval/gpt_embedding_dataset.py index e697d5ec3bf6..3a2a8152313e 100644 --- a/nemo/collections/nlp/data/information_retrieval/gpt_embedding_dataset.py +++ b/nemo/collections/nlp/data/information_retrieval/gpt_embedding_dataset.py @@ -27,7 +27,7 @@ from nemo.core.classes import Dataset from nemo.utils import logging -__all__ = ['GPTEmbeddingDataset'] +__all__ = ['GPTEmbeddingDataset', 'GPTRerankerDataset'] class GPTEmbeddingDataset(Dataset): @@ -49,7 +49,7 @@ def __init__( data_type: str = 'train', # train, query or doc ): """ - file_path: Path to a JSONL dataset with (query,pos_doc,neg_doc) triplets in jsonl format. + file_path: Path to a JSONL dataset with (query,pos_doc,neg_doc) triplets in jsonl format. tokenizer: Tokenizer for the dataset. Instance of a class that inherits TokenizerSpec (ex: YTTM, SentencePiece). max_seq_length (int): maximum sequence length for each dataset examples. Examples will either be truncated to fit this length or dropped if they cannot be truncated. min_seq_length (int): min length of each data example in the dataset. Data examples will be dropped if they do not meet the min length requirements. @@ -279,3 +279,138 @@ def collate_fn(self, batch): } return processed_batch + + +class GPTRerankerDataset(GPTEmbeddingDataset): + def __init__( + self, + file_path: str, + tokenizer: TokenizerSpec, + max_seq_length: int = 1024, + min_seq_length: int = 1, + add_bos: bool = False, + add_eos: bool = True, + max_num_samples: int = None, + seed: int = 1234, + index_mapping_dir: str = None, + virtual_tokens: int = 0, + memmap_workers: Optional[int] = None, + truncation_method: str = 'right', + special_tokens: Optional[Mapping[str, str]] = None, # special tokens, a dictory of {token_type: token} + data_type: str = 'train', # train, query or doc + ): + """ + file_path: Path to a JSONL dataset with (query,pos_doc,neg_doc) triplets in jsonl format. + tokenizer: Tokenizer for the dataset. Instance of a class that inherits TokenizerSpec (ex: YTTM, SentencePiece). + max_seq_length (int): maximum sequence length for each dataset examples. 
Examples will either be truncated to fit this length or dropped if they cannot be truncated. + min_seq_length (int): min length of each data example in the dataset. Data examples will be dropped if they do not meet the min length requirements. + add_bos (bool): Whether to add a beginning of sentence token to each data example + add_eos (bool): Whether to add an end of sentence token to each data example + seed: Random seed for data shuffling. + max_num_samples: Maximum number of samples to load. This can be > dataset length if you want to oversample data. If None, all samples will be loaded. + index_mapping_dir: Directory to save the index mapping to. If None, will write to the same folder as the dataset. + truncation_method: Truncation from which position. Options: ['left', 'right'] + special_tokens: special tokens for the chat prompts, a dictionary of {token_type: token}. Default: {'system_turn_start': '', 'turn_start': '', 'label_start': '', 'end_of_turn': '\n', "end_of_name": "\n"} + """ + super().__init__( + file_path=file_path, + tokenizer=tokenizer, + max_seq_length=max_seq_length, + min_seq_length=min_seq_length, + add_bos=add_bos, + add_eos=add_eos, + max_num_samples=max_num_samples, + seed=seed, + index_mapping_dir=index_mapping_dir, + virtual_tokens=virtual_tokens, + memmap_workers=memmap_workers, + truncation_method=truncation_method, + special_tokens=special_tokens, + data_type=data_type, + ) + + def _process_example(self, example): + """ + Create an example by concatenating text and answer. + Truncation is carried out when needed, but it is performed only on the prompt side. + BOS, EOS, and SEP, are added if specified. + """ + metadata = {k: v for k, v in example.items()} + if self.data_type == 'train': + qd = self.tokenizer.text_to_ids( + "query: " + example['query'].strip() + " passage: " + example['pos_doc'].strip() + ) + qnd = self.tokenizer.text_to_ids( + "query: " + example['query'].strip() + " passage: " + example['neg_doc'].strip() + ) + else: + qd = self.tokenizer.text_to_ids( + "query: " + example['query'].strip() + " passage: " + example['pos_doc'].strip() + ) + qnd = [] + + if self.virtual_tokens: + # (@adithyare) we are going to insert "pad/eos" tokens in the beginning of the text and context + # these pad/eos tokens are placeholders for virtual tokens for ptuning (if used) + qd = [self.tokenizer.eos_id] * self.virtual_tokens + qd # type: ignore + qnd = [self.tokenizer.eos_id] * self.virtual_tokens + qnd # type: ignore + + if self.add_bos: + qd = [self.tokenizer.bos_id] + qd # type: ignore + qnd = [self.tokenizer.bos_id] + qnd # type: ignore + + # TODO: (@adithyare) should probably add a warning before truncation + qd = qd[: self.max_seq_length - 1] + qnd = qnd[: self.max_seq_length - 1] + + if self.add_eos: + qd = qd + [self.tokenizer.eos_id] # type: ignore + qnd = qnd + [self.tokenizer.eos_id] # type: ignore + + processed_example = { + 'query_pos_doc': qd, + 'query_neg_doc': qnd, + 'metadata': metadata, + } + + return processed_example + + def collate_fn(self, batch): + input_ids = [] + metadata = [] + lengths = [] + max_length = -1 + for item in batch: + metadata.append(item['metadata']) + if self.data_type == 'train': + input_ids.append(item['query_pos_doc']) + lengths.append(len(item['query_pos_doc'])) + input_ids.append(item['query_neg_doc']) + lengths.append(len(item['query_neg_doc'])) + max_length = max(max_length, len(item['query_pos_doc']), len(item['query_neg_doc'])) + else: + input_ids.append(item['query_pos_doc']) + 
lengths.append(len(item['query_pos_doc'])) + max_length = max(max_length, len(item['query_pos_doc'])) + + max_length = min(self.max_seq_length, self._ceil_to_nearest(max_length, 16)) + assert max_length <= self.max_seq_length + + attention_mask = [self._create_attention_mask(max_length) for _ in input_ids] + attention_mask = torch.stack(attention_mask) + position_ids = [list(range(max_length)) for _ in input_ids] + position_ids = torch.LongTensor(position_ids) + input_ids = torch.LongTensor( + self._collate_item(input_ids, max_length=max_length, pad_id=self.tokenizer.eos_id) + ) + lengths = torch.LongTensor(lengths) - 1 # subtract 1 to account for the eos token + + processed_batch = { + 'tokens': input_ids, + 'attention_mask': attention_mask, + 'loss_mask': lengths, + 'position_ids': position_ids, + 'metadata': metadata, + } + + return processed_batch diff --git a/nemo/collections/nlp/data/language_modeling/megatron/data_samplers.py b/nemo/collections/nlp/data/language_modeling/megatron/data_samplers.py index 4a8b989a7b6d..622e2d759266 100644 --- a/nemo/collections/nlp/data/language_modeling/megatron/data_samplers.py +++ b/nemo/collections/nlp/data/language_modeling/megatron/data_samplers.py @@ -100,13 +100,16 @@ def get_start_end_idx(self): end_idx = start_idx + self.micro_batch_size return start_idx, end_idx + def _get_padding_indices(self, pad_samples_num): + return range(-1, -pad_samples_num - 1, -1) + def __iter__(self): batch = [] # Last batch will be dropped if drop_last is not set False indices = range(self.consumed_samples, self.total_samples) if (not self.drop_last) and self.pad_samples_to_global_batch_size: pad_samples_num = -len(indices) % self.global_batch_size - pad_indices = [None] * pad_samples_num + pad_indices = self._get_padding_indices(pad_samples_num) indices = chain(indices, pad_indices) for idx in indices: @@ -125,6 +128,11 @@ def __iter__(self): yield batch[start_idx:end_idx] +class MegatronCorePretrainingSampler(MegatronPretrainingSampler): + def _get_padding_indices(self, pad_samples_num): + return [None] * pad_samples_num + + class MegatronPretrainingRandomSampler(BaseMegatronSampler): def __init__( self, diff --git a/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py b/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py index e16543a7568d..2e21c57dddd3 100644 --- a/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py +++ b/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py @@ -57,6 +57,7 @@ def __init__( tokens_to_generate: int = 0, memmap_workers: Optional[int] = None, hf_dataset: bool = False, + global_sample_mapping: bool = False, truncation_method: str = 'right', special_tokens: Optional[Mapping[str, str]] = None, # special tokens, a dictory of {token_type: token} is_test: bool = False, @@ -83,6 +84,7 @@ def __init__( index_mapping_dir: Directory to save the index mapping to. If None, will write to the same folder as the dataset. prompt_template: Prompt template to inject via an fstring. Formatted like Q: {context_key}\n\nA: {label_key} hf_dataset: Whether to load the json file with the HuggingFace dataset. otherwise, will load the jsonl file with the JSONLMemMapDataset. + global_sample_mapping: Whether to shuffle all data together, or shuffle the dataset within each epoch truncation_method: Truncation from which position. Options: ['left', 'right'] special_tokens: special tokens for the chat prompts, a dictionary of {token_type: token}. 
Default: {'system_turn_start': '', 'turn_start': '', 'label_start': '', 'end_of_turn': '\n', "end_of_name": "\n"} is_test: Whether this dataset is the test split. @@ -109,6 +111,7 @@ def __init__( self.tokens_to_generate = tokens_to_generate self.memmap_workers = memmap_workers self.hf_dataset = hf_dataset + self.global_sample_mapping = global_sample_mapping self.truncation_method = truncation_method self.is_test = is_test self.output_original_text = output_original_text @@ -176,7 +179,11 @@ def _maybe_validate_prompt_template(self): def _build_samples_mapping(self): if self.max_num_samples is not None: - osm = OnlineSampleMapping(dataset_size=len(self.indexed_dataset), num_samples=self.max_num_samples) + osm = ( + OnlineSampleMapping(dataset_size=len(self.indexed_dataset), num_samples=self.max_num_samples) + if not self.global_sample_mapping + else None + ) self.samples_mapping = get_samples_mapping( indexed_dataset=self.indexed_dataset, data_prefix=self.file_path, @@ -324,25 +331,17 @@ def _multiple_truncation(self, template_ids: List[List[int]], template_ids_keys: logging.warning(f'{key} is not long enough to truncate.') truncation_length = len(ids) - if self.truncation_method == 'left': - window_offset = truncation_length - elif self.truncation_method == 'right': - window_offset = 0 - else: - raise ValueError(f'{self.truncation_method} is not supported') - - window_length = len(ids) - truncation_length - template_ids[i] = ids[window_offset : window_offset + window_length] - else: - # If truncation_field is empty, we truncate template_ids (List[List[int]]) to make total ids < self.max_seq_length. - logging.warning( - f'`truncation_field` is empty, we truncate input from {self.truncation_method} based on truncation_method.' - ) + truncation_length_total -= truncation_length + template_ids[i] = self._truncation(ids, len(ids) - truncation_length) + + if truncation_length_total > 0: template_ids_lengths = [len(ids) for ids in template_ids] if self.truncation_method == 'left': iters = range(0, len(template_ids_lengths), 1) elif self.truncation_method == 'right': iters = range(len(template_ids_lengths) - 1, -1, -1) + # We need to truncate more to let context_ids + tokens_to_generate < self.max_seq_length + truncation_length_total += min(len(label_ids), self.tokens_to_generate) else: raise ValueError(f'{self.truncation_method} is not supported') @@ -350,22 +349,27 @@ def _multiple_truncation(self, template_ids: List[List[int]], template_ids_keys: for i in iters: if template_ids_lengths[i] >= truncation_length_total: template_ids_lengths[i] -= truncation_length_total - if self.truncation_method == 'left': - template_ids[i] = template_ids[i][-template_ids_lengths[i] :] - elif self.truncation_method == 'right': - template_ids[i] = template_ids[i][: template_ids_lengths[i]] - else: - raise ValueError(f'{self.truncation_method} is not supported') + template_ids[i] = self._truncation(template_ids[i], template_ids_lengths[i]) break else: truncation_length_total -= template_ids_lengths[i] template_ids_lengths[i] = 0 - template_ids[i] = [] + template_ids[i] = self._truncation(template_ids[i], template_ids_lengths[i]) context_ids = [i for ids in template_ids[:-1] for i in ids] label_ids = template_ids[-1] return context_ids, label_ids + def _truncation(self, ids, expect_length): + if expect_length == 0: + return [] + elif self.truncation_method == 'left': + return ids[-expect_length:] + elif self.truncation_method == 'right': + return ids[:expect_length] + else: + raise 
ValueError(f'{self.truncation_method} is not supported') + def _process_example(self, example): """ Create an example by concatenating text and answer. @@ -406,17 +410,6 @@ def _process_example(self, example): if self.add_eos: input_ids = input_ids + [self.tokenizer.eos_id] - if len(input_ids) > self.max_seq_length: - # this only happens if tuncation_field is not enough to truncate. - # context_ids can be empty if we truncate contexts. - # answer_ids can be empty if we truncate answers. - logging.warning( - f'After truncation, input ids length {len(input_ids)} still exceeds max sequence length {self.max_seq_length}' - ) - context_ids = context_ids[: self.max_seq_length] - input_ids = input_ids[: self.max_seq_length] - answer_ids = input_ids[len(context_ids) :] - # store metadata in dataset, in case user may have keys required in the prediction json files metadata = {k: v for k, v in example.items() if k not in self.prompt_template_keys} if self.output_original_text: diff --git a/nemo/collections/nlp/data/language_modeling/megatron/retro_dataset.py b/nemo/collections/nlp/data/language_modeling/megatron/retro_dataset.py index 7d604c0b51bc..d8dafd69c658 100644 --- a/nemo/collections/nlp/data/language_modeling/megatron/retro_dataset.py +++ b/nemo/collections/nlp/data/language_modeling/megatron/retro_dataset.py @@ -45,11 +45,14 @@ from nemo.collections.nlp.modules.common.megatron.utils import get_ltor_masks_and_position_ids - HAVE_MEGATRON_CORE = True + HAVE_TE_AND_MEGATRON_CORE = True except (ImportError, ModuleNotFoundError): - HAVE_MEGATRON_CORE = False + HAVE_TE_AND_MEGATRON_CORE = False + from typing import Any + + RetroConfig = Any class RETRODataset(Dataset): @@ -129,57 +132,64 @@ def build_train_valid_test_datasets( tokenizer, ): - # gpt dataset - train_ds, valid_ds, test_ds = gpt_train_valid_test_datasets_provider(cfg, train_valid_test_num_samples, tokenizer) + if HAVE_TE_AND_MEGATRON_CORE: - gpt_datasets = { - "train": (train_ds, train_valid_test_num_samples[0]), - "valid": (valid_ds, train_valid_test_num_samples[1]), - "test": (test_ds, train_valid_test_num_samples[2]), - } + # gpt dataset + train_ds, valid_ds, test_ds = gpt_train_valid_test_datasets_provider( + cfg, train_valid_test_num_samples, tokenizer + ) - retro_train_ds, retro_valid_ds, retro_test_ds = get_retro_datasets( - config=retro_config, - gpt_datasets=gpt_datasets, - sample_length=seq_length, - eod_token_id=tokenizer.eos_id, - ) + gpt_datasets = { + "train": (train_ds, train_valid_test_num_samples[0]), + "valid": (valid_ds, train_valid_test_num_samples[1]), + "test": (test_ds, train_valid_test_num_samples[2]), + } - train_ds = ( - RETRODataset( - cfg=cfg, - retro_config=retro_config, - tokenizer=tokenizer, - mcore_retro_dataset=retro_train_ds, - number_samples_with_neighbors=train_valid_test_num_samples[0], + retro_train_ds, retro_valid_ds, retro_test_ds = get_retro_datasets( + config=retro_config, + gpt_datasets=gpt_datasets, + sample_length=seq_length, + eod_token_id=tokenizer.eos_id, ) - if retro_train_ds - else None - ) - valid_ds = ( - RETRODataset( - cfg=cfg, - retro_config=retro_config, - tokenizer=tokenizer, - mcore_retro_dataset=retro_valid_ds, - number_samples_with_neighbors=train_valid_test_num_samples[1], + + train_ds = ( + RETRODataset( + cfg=cfg, + retro_config=retro_config, + tokenizer=tokenizer, + mcore_retro_dataset=retro_train_ds, + number_samples_with_neighbors=train_valid_test_num_samples[0], + ) + if retro_train_ds + else None ) - if retro_valid_ds - else None - ) - test_ds = ( - RETRODataset( - 
cfg=cfg, - retro_config=retro_config, - tokenizer=tokenizer, - mcore_retro_dataset=retro_test_ds, - number_samples_with_neighbors=train_valid_test_num_samples[2], + valid_ds = ( + RETRODataset( + cfg=cfg, + retro_config=retro_config, + tokenizer=tokenizer, + mcore_retro_dataset=retro_valid_ds, + number_samples_with_neighbors=train_valid_test_num_samples[1], + ) + if retro_valid_ds + else None + ) + test_ds = ( + RETRODataset( + cfg=cfg, + retro_config=retro_config, + tokenizer=tokenizer, + mcore_retro_dataset=retro_test_ds, + number_samples_with_neighbors=train_valid_test_num_samples[2], + ) + if retro_test_ds + else None ) - if retro_test_ds - else None - ) - return train_ds, valid_ds, test_ds + return train_ds, valid_ds, test_ds + else: + logging.warn('Megatron core is not installed. Returning None') + return def gpt_train_valid_test_datasets_provider(cfg, train_val_test_num_samples, tokenizer): diff --git a/nemo/collections/nlp/models/dialogue/dialogue_s2s_generation_model.py b/nemo/collections/nlp/models/dialogue/dialogue_s2s_generation_model.py index 73f09f62b1d5..48f3e5127a88 100644 --- a/nemo/collections/nlp/models/dialogue/dialogue_s2s_generation_model.py +++ b/nemo/collections/nlp/models/dialogue/dialogue_s2s_generation_model.py @@ -35,12 +35,13 @@ from nemo.utils.decorators import deprecated_warning try: - from apex.transformer.pipeline_parallel.utils import _reconfigure_microbatch_calculator - - HAVE_APEX = True -except: - HAVE_APEX = False + from megatron.core.num_microbatches_calculator import reconfigure_num_microbatches_calculator +except (ImportError, ModuleNotFoundError): + logging.warning("Megatron num_microbatches_calculator not found, using Apex version.") + from apex.transformer.pipeline_parallel.utils import ( + _reconfigure_microbatch_calculator as reconfigure_num_microbatches_calculator, + ) __all__ = ['DialogueS2SGenerationModel'] @@ -237,7 +238,7 @@ def generate_candidates(self, input_ids, attn_masks, labels): generated_tokens = self.language_model.generate(**param_dict) elif self.cfg.library == 'megatron': - _reconfigure_microbatch_calculator( + reconfigure_num_microbatches_calculator( rank=0, # This doesn't matter since it is only used for logging rampup_batch_size=None, global_batch_size=1, diff --git a/nemo/collections/nlp/models/information_retrieval/megatron_bert_embedding_model.py b/nemo/collections/nlp/models/information_retrieval/megatron_bert_embedding_model.py index 102ab5ec0f84..5e38b61938c9 100644 --- a/nemo/collections/nlp/models/information_retrieval/megatron_bert_embedding_model.py +++ b/nemo/collections/nlp/models/information_retrieval/megatron_bert_embedding_model.py @@ -13,30 +13,18 @@ # limitations under the License. 
import logging +import os -try: - from apex.transformer.pipeline_parallel.utils import get_num_microbatches - - HAVE_APEX = True +import numpy as np -except (ImportError, ModuleNotFoundError): - - HAVE_APEX = False -try: - from megatron.core import parallel_state - from megatron.core.pipeline_parallel.schedules import get_forward_backward_func - from megatron.core.transformer.module import Float16Module as MCoreFloat16Module - - HAVE_MEGATRON_CORE = True -except (ImportError, ModuleNotFoundError): - TransformerConfig = ApexGuardDefaults - HAVE_MEGATRON_CORE = False import torch from megatron.core.models.bert.bert_layer_specs import bert_layer_with_transformer_engine_spec from omegaconf import DictConfig, OmegaConf, open_dict from omegaconf.dictconfig import DictConfig from pytorch_lightning.trainer.trainer import Trainer +from torch.distributed import all_gather as all_gather_no_backprop +from torch.distributed.nn.functional import all_gather as all_gather_with_backprop from nemo.collections.nlp.data.information_retrieval.bert_embedding_dataset import BertEmbeddingDataset from nemo.collections.nlp.data.language_modeling.megatron.data_samplers import ( @@ -58,17 +46,35 @@ from nemo.collections.nlp.parts.utils_funcs import get_last_rank from nemo.utils import logging + try: from megatron.core import parallel_state + from megatron.core.pipeline_parallel.schedules import get_forward_backward_func + from megatron.core.transformer.module import Float16Module as MCoreFloat16Module HAVE_MEGATRON_CORE = True except (ImportError, ModuleNotFoundError): - + TransformerConfig = ApexGuardDefaults ModelParallelConfig = ApexGuardDefaults HAVE_MEGATRON_CORE = False +try: + from megatron.core.num_microbatches_calculator import get_num_microbatches + +except (ImportError, ModuleNotFoundError): + logging.warning("Megatron num_microbatches_calculator not found, using Apex version.") + from apex.transformer.pipeline_parallel.utils import get_num_microbatches + + +def listify(tensor): + l_tensor = [] + for t in tensor: + r = t[:].unsqueeze(0).cpu() + l_tensor.append(r) + return l_tensor + class MegatronBertEmbeddingModel(MegatronBertModel): """ @@ -82,6 +88,10 @@ def __init__(self, cfg: DictConfig, trainer: Trainer): self.cross_entropy_loss = torch.nn.CrossEntropyLoss(label_smoothing=cfg.get('label_smoothing', 0.0)) softmax_temp = cfg.get('softmax_temp', 0.05) self.scale = 1.0 / softmax_temp + self.hard_negatives_to_train = self.cfg.data.get("hard_negatives_to_train", 4) + self.global_inbatch_negatives = self.cfg.get("global_inbatch_negatives", True) + self.backprop_type = self.cfg.get("backprop_type", "local") + assert self.backprop_type in ["local", "global"], "Backprop type must be `local` or `global`" def model_provider_func(self, pre_process, post_process): cfg = self.cfg @@ -149,34 +159,61 @@ def model_provider_func(self, pre_process, post_process): return model - def build_train_valid_test_datasets(self): + def build_train_valid_test_datasets(self, is_train=True): self._train_ds = None self._validation_ds = None self._test_ds = None - self._train_ds = BertEmbeddingDataset( - self.cfg.data.data_train, - tokenizer=self.tokenizer, - add_bos=True, - num_hard_negatives=self.cfg.data.get("hard_negatives_to_train", 4), - max_seq_length=self.cfg.encoder_seq_length, - ) - if self.cfg.data.data_validation: - self._validation_ds = BertEmbeddingDataset( - self.cfg.data.data_validation, + if is_train: + self._train_ds = BertEmbeddingDataset( + self.cfg.data.data_train, tokenizer=self.tokenizer, add_bos=True, 
num_hard_negatives=self.cfg.data.get("hard_negatives_to_train", 4), max_seq_length=self.cfg.encoder_seq_length, ) + if self.cfg.data.data_validation: + self._validation_ds = BertEmbeddingDataset( + self.cfg.data.data_validation, + tokenizer=self.tokenizer, + add_bos=True, + num_hard_negatives=self.cfg.data.get("hard_negatives_to_train", 4), + max_seq_length=self.cfg.encoder_seq_length, + ) + + else: + logging.info(f'Building test dataset') + if self.cfg.data.data_test.query_file_names is None or self.cfg.data.data_test.doc_file_names is None: + return [] + + query_dataset = BertEmbeddingDataset( + file_path=self.cfg.data.data_test.query_file_names[0], + tokenizer=self.tokenizer, + max_seq_length=self.cfg.encoder_seq_length, + add_bos=True, + add_eos=True, + data_type="query", + ) + doc_dataset = BertEmbeddingDataset( + file_path=self.cfg.data.data_test.doc_file_names[0], + tokenizer=self.tokenizer, + max_seq_length=self.cfg.encoder_seq_length, + add_bos=True, + add_eos=True, + data_type="doc", + ) + + self._test_ds = [query_dataset, doc_dataset] if self._train_ds is not None: logging.info(f'Length of train dataset: {len(self._train_ds)}') if self._validation_ds is not None: logging.info(f'Length of val dataset: {len(self._validation_ds)}') if self._test_ds is not None: - logging.info(f'Length of test dataset: {len(self._test_ds)}') + logging.info(f'Length of test query dataset: {len(self._test_ds[0])}') + logging.info(f'Length of test doc dataset: {len(self._test_ds[1])}') + logging.info(f'Finished building SBert datasets.') return self._train_ds, self._validation_ds, self._test_ds @@ -210,6 +247,9 @@ def setup(self, stage=None): if stage == 'predict': return + elif stage == 'test': + self.build_train_valid_test_datasets(is_train=False) + self.setup_test_data(self.cfg.data) else: # TODO: consider adding a ModelPT guard to check if model is being restored. 
# allowing restored models to optionally setup datasets @@ -300,7 +340,8 @@ def build_pretraining_data_loader(self, dataset, consumed_samples): global_batch_size=self.cfg.global_batch_size, data_parallel_rank=parallel_state.get_data_parallel_rank(), data_parallel_size=parallel_state.get_data_parallel_world_size(), - drop_last=self.cfg.get('drop_last', True), + drop_last=self.cfg.get('drop_last', False), + pad_samples_to_global_batch_size=not self.cfg.get('drop_last', False), ) elif self.cfg.data.dataloader_type == 'cyclic': batch_sampler = MegatronPretrainingRandomSampler( @@ -309,7 +350,8 @@ def build_pretraining_data_loader(self, dataset, consumed_samples): micro_batch_size=self.cfg.micro_batch_size, data_parallel_rank=parallel_state.get_data_parallel_rank(), data_parallel_size=parallel_state.get_data_parallel_world_size(), - drop_last=self.cfg.get('drop_last', True), + drop_last=self.cfg.get('drop_last', False), + pad_samples_to_global_batch_size=not self.cfg.get('drop_last', False), ) else: raise ValueError('cfg.data.dataloader_type must be "single" or "cyclic"') @@ -345,6 +387,24 @@ def setup_validation_data(self, cfg): ) self._validation_dl = self.build_pretraining_data_loader(self._validation_ds, consumed_samples) + def setup_eval_dataloader(self, datasets): + dataloaders = [] + for dataset in datasets: + eval_dl = self.build_pretraining_data_loader( + dataset=dataset, + consumed_samples=0, + ) + dataloaders.append(eval_dl) + return dataloaders + + def setup_test_data(self, cfg): + if self._test_ds: + logging.info( + f'Setting up test dataloader with len(len(self._test_ds)): {len(self._test_ds[0])}, {len(self._test_ds[1])}' + ) + self._test_dl = self.setup_eval_dataloader(self._test_ds) + return + def training_step(self, dataloader_iter): self._optimizer.zero_grad() @@ -435,14 +495,18 @@ def training_step(self, dataloader_iter): self.log('lr', lr, batch_size=1) self.log('global_step', self.trainer.global_step, prog_bar=True, batch_size=1) self.log( - 'consumed_samples', self._compute_consumed_samples_after_training_step(), prog_bar=True, batch_size=1, + 'consumed_samples', + self._compute_consumed_samples_after_training_step(), + prog_bar=True, + batch_size=1, ) return loss_mean[0] def get_forward_output_and_loss_func(self): def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_layers=None): - batches = next(dataloader_iter)[0] + batches, _, dl_idx = next(dataloader_iter) + metadata = batches.pop('metadata') batches = {k: v.cuda(non_blocking=True) for k, v in batches.items()} if self.mcore_bert: @@ -466,15 +530,170 @@ def loss_func(output_tensor): loss = lm_loss reduced_loss = average_losses_across_data_parallel_group([loss, lm_loss]) - return loss, {'loss': reduced_loss} + if 'hs' in loss_dict: + # metadata = batches.get('metadata', [{}] * len(batches['input_ids'])) + return loss, { + 'loss': reduced_loss, + 'd_hs': loss_dict['hs'], + 'q_hs': loss_dict['hs'], + 'metadata': metadata, + 'dl_idx': dl_idx, + } + else: + return loss, {'loss': reduced_loss} return output_tensor, loss_func return fwd_output_and_loss_func - def loss_func(self, output_tensor): + def validation_step(self, dataloader_iter): + prefix = "test" if self.trainer.testing else "val" + if self.cfg.data.dataloader_type == "LDDL": + seq_length = dataloader_iter.iterator.get_seqlen() + else: + seq_length = self.cfg.encoder_seq_length + + fwd_bwd_function = get_forward_backward_func() - chunks = output_tensor.chunk(self.cfg.micro_batch_size) + losses_reduced_per_micro_batch = fwd_bwd_function( + 
forward_step_func=self.get_forward_output_and_loss_func(), + data_iterator=self._make_data_iterator_list(dataloader_iter), + model=self.model, + num_microbatches=get_num_microbatches(), + forward_only=True, + seq_length=seq_length, + micro_batch_size=self.cfg.micro_batch_size, + ) + + if losses_reduced_per_micro_batch: + loss_tensors_list = [loss_reduced['loss'] for loss_reduced in losses_reduced_per_micro_batch] + loss_tensor = torch.vstack(loss_tensors_list) + loss_mean = loss_tensor.mean(axis=0) + else: + loss_mean = torch.tensor([0.0]).cuda() + + loss = loss_mean[0] + if prefix == 'val': + self.validation_step_outputs.append(loss) + else: + assert len(losses_reduced_per_micro_batch) == 1 + dataloader_idx = losses_reduced_per_micro_batch[0]['dl_idx'] + self.test_step_outputs[dataloader_idx].append(losses_reduced_per_micro_batch[0]) + return loss + + def on_test_epoch_end(self): + for dataloader_idx, output in enumerate(self.test_step_outputs): + self.gather_and_maybe_write_predictions(output, self.cfg.data.data_test, 'test', dataloader_idx) + + def gather_and_maybe_write_predictions(self, output, data_cfg, mode, dataloader_idx=0): + if not data_cfg.get("write_embeddings_to_file", False): + return True + gathered_output_batches = [None for _ in range(parallel_state.get_data_parallel_world_size())] + torch.distributed.all_gather_object( + gathered_output_batches, + [ + { + 'q_hs': batch['q_hs'], + 'd_hs': batch['d_hs'], + 'metadata': batch['metadata'], + } + for batch in output + ], + group=parallel_state.get_data_parallel_group(), + ) + + # Remove duplicate examples due to distributed sampler. + deduplicated_outputs = { + 'q_hs': [], + 'd_hs': [], + 'metadata': [], + } + total_size, skipped = 0, 0 + for rank in range(0, parallel_state.get_data_parallel_world_size()): + for batch in gathered_output_batches[rank]: + l_q_hs = listify(batch['q_hs']) + l_d_hs = listify(batch['d_hs']) + l_m = batch['metadata'] + assert len(l_m) == len(l_q_hs) == len(l_d_hs) + for q_hs, d_hs, metadata in zip( + l_q_hs, + l_d_hs, + l_m, + ): + total_size += 1 + if not metadata.get("__AUTOGENERATED__", False): + deduplicated_outputs['q_hs'].append(q_hs) + deduplicated_outputs['d_hs'].append(d_hs) + deduplicated_outputs['metadata'].append(metadata) + else: + skipped += 1 + + logging.info( + f"{total_size-skipped} deduplicated outputs in dataloader:{dataloader_idx}, (skipped {skipped} autogenerated examples)." + ) + + # Write predictions to file + if self.global_rank == 0 and data_cfg.get("write_embeddings_to_file", False): + logging.info( + f"Total deduplicated inference data size: {total_size} to {len(deduplicated_outputs['metadata'])}" + ) + + # Check if the user provided a prefix path to the file(s) they want to write. + if not hasattr(data_cfg, "output_file_path_prefix") or data_cfg.output_file_path_prefix is None: + raise ValueError( + f"Cannot write predictions to file when output_file_path_prefix is not set or present in the yaml config file." 
+ ) + filename_log_key = f"{mode}_{data_cfg.names[dataloader_idx]}" + consumed_samples = self._compute_consumed_samples_after_training_step() + fldr_path = f"{data_cfg.output_file_path_prefix}/consumed_samples{consumed_samples}/{filename_log_key}" + self.write_embeddings_to_file(deduplicated_outputs, fldr_path, dataloader_idx) + return deduplicated_outputs, total_size + + def write_embeddings_to_file(self, outputs, output_file_path, d_idx): + emb_type = 'query' if d_idx == 0 else 'doc' + hs = torch.cat(outputs['q_hs' if d_idx == 0 else 'd_hs'], dim=0) + hs_npy = hs.float().numpy() + emb_fldr = f"{output_file_path}" + os.makedirs(emb_fldr, exist_ok=True) + with open(f"{output_file_path}/{emb_type}.ids", "w") as f: + for m in outputs['metadata']: + f.write(m[f"{emb_type}_id"] + "\n") + np.save(f"{emb_fldr}/{emb_type}.npy", hs_npy) + return True + + def inference_loss_func(self, eos_tensors): + hs = eos_tensors + _blank = torch.zeros(1, device=hs.device, dtype=hs.dtype)[0] + return { + 'hs': eos_tensors, + 'lm loss': _blank, + } + + def _gather_global_inbatch_representations(self, local_tensor): + local_tensor = local_tensor.contiguous() + if self.backprop_type == 'local': + global_tensors = [ + torch.zeros_like(local_tensor) for _ in range(parallel_state.get_data_parallel_world_size()) + ] + all_gather_no_backprop(global_tensors, local_tensor, group=parallel_state.get_data_parallel_group()) + global_tensors[parallel_state.get_data_parallel_rank()] = local_tensor + global_tensors = torch.cat(global_tensors, dim=0) + + else: + global_tensors = all_gather_with_backprop(local_tensor) + global_tensors = torch.cat(global_tensors, dim=0) + + return global_tensors + + def loss_func(self, output_tensor): + if self.global_inbatch_negatives and self.trainer.training: + output_tensor = self._gather_global_inbatch_representations(output_tensor) + if self.trainer.testing: + return self.inference_loss_func(output_tensor) + + num_tensors_per_example = 2 + self.hard_negatives_to_train + bs = output_tensor.shape[0] // num_tensors_per_example + chunks = output_tensor.chunk(bs) queries = torch.stack([item[0] for item in chunks]) # shape (bs, embedding_dim) positives = torch.stack([item[1] for item in chunks]) # shape (bs, embedding_dim) @@ -483,16 +702,21 @@ def loss_func(self, output_tensor): ) # shape (bs, bs); each positive is negative for other queries. hard_negs = [ - torch.stack([item[i + 2] for item in chunks]) - for i in range(self.cfg.data.get("hard_negatives_to_train", 4)) + torch.stack([item[i + 2] for item in chunks]) for i in range(self.hard_negatives_to_train) ] # List of length "num_negatives", each tensor of shape (bs, embedding_dim) hard_negs_scores = ( - torch.multiply(queries.unsqueeze(0).repeat(len(hard_negs), 1, 1), torch.stack(hard_negs),).sum(axis=-1).T + torch.multiply( + queries.unsqueeze(0).repeat(len(hard_negs), 1, 1), + torch.stack(hard_negs), + ) + .sum(axis=-1) + .T ) # shape = (bs, num_negatives); Hard negatives are not shared between queries. 
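For reference, the scoring assembled in this loss_func (together with the concatenation, clamping and scaling that follow just below) amounts to a standard in-batch-negatives contrastive loss with additional per-query hard negatives. A minimal standalone sketch, assuming L2-normalized embeddings; the function name, the default scale value, and the arange(bs) target are illustrative assumptions, not lines taken from this hunk:

import torch
import torch.nn.functional as F

def contrastive_loss_sketch(queries, positives, hard_negatives, scale=20.0):
    # queries, positives: (bs, dim); hard_negatives: list of (bs, dim) tensors.
    # scale plays the role of self.scale in the model code above.
    # (bs, bs): every other query's positive doubles as an in-batch negative.
    pos_inbatch_scores = queries @ positives.T
    # (bs, num_negatives): hard negatives are specific to each query, not shared.
    hard_scores = torch.stack([(queries * hn).sum(-1) for hn in hard_negatives], dim=1)
    scores = torch.cat([pos_inbatch_scores, hard_scores], dim=1).clamp(-1.0, 1.0) * scale
    # Column i of the in-batch block holds query i's true positive.
    labels = torch.arange(queries.shape[0], device=queries.device)
    return F.cross_entropy(scores, labels)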
scores = torch.cat([pos_inbatch_negs_scores, hard_negs_scores], axis=1) + scores = scores.clamp(-1.0, 1.0) scores *= self.scale labels = torch.tensor( diff --git a/nemo/collections/nlp/models/information_retrieval/megatron_gpt_embedding_model.py b/nemo/collections/nlp/models/information_retrieval/megatron_gpt_embedding_model.py index 67fd2b1b6c62..c7565f45358e 100644 --- a/nemo/collections/nlp/models/information_retrieval/megatron_gpt_embedding_model.py +++ b/nemo/collections/nlp/models/information_retrieval/megatron_gpt_embedding_model.py @@ -36,11 +36,6 @@ except (ImportError, ModuleNotFoundError): HAVE_MEGATRON_CORE = False -try: - - HAVE_APEX = True -except (ImportError, ModuleNotFoundError): - HAVE_APEX = False def listify(tensor): @@ -52,6 +47,17 @@ def listify(tensor): return l_tensor +def _gather_global_inbatch_representations(local_eos_tensor): + local_eos_tensor = local_eos_tensor.contiguous() + global_eos_tensors = [ + torch.zeros_like(local_eos_tensor) for _ in range(parallel_state.get_data_parallel_world_size()) + ] + torch.distributed.all_gather(global_eos_tensors, local_eos_tensor, group=parallel_state.get_data_parallel_group()) + global_eos_tensors[parallel_state.get_data_parallel_rank()] = local_eos_tensor + global_eos_tensors = torch.cat(global_eos_tensors, dim=0) + return global_eos_tensors + + class MegatronGPTEmbeddingModel(MegatronGPTSFTModel): def __init__(self, cfg: DictConfig, trainer: Trainer): super().__init__(cfg, trainer=trainer) @@ -412,25 +418,20 @@ def inference_loss_func(self, loss_mask, num_valid_tokens_in_ub, eos_tensors): hs = eos_tensors hs = torch.nn.functional.normalize(hs, dim=1) _blank = torch.zeros(1, device=hs.device, dtype=hs.dtype)[0] - return _blank, hs, hs, _blank, _blank, _blank - - def _gather_global_inbatch_representations(self, local_eos_tensor): - local_eos_tensor = local_eos_tensor.contiguous() - global_eos_tensors = [ - torch.zeros_like(local_eos_tensor) for _ in range(parallel_state.get_data_parallel_world_size()) - ] - torch.distributed.all_gather( - global_eos_tensors, local_eos_tensor, group=parallel_state.get_data_parallel_group() - ) - global_eos_tensors[parallel_state.get_data_parallel_rank()] = local_eos_tensor - global_eos_tensors = torch.cat(global_eos_tensors, dim=0) - return global_eos_tensors + return { + "loss": _blank, + "query_hs": hs, + "pos_doc_hs": hs, + "pos_cs": _blank, + "neg_cs": _blank, + "diff_cs": _blank, + } def loss_func(self, loss_mask, num_valid_tokens_in_ub, output_tensor): idx = torch.arange(output_tensor.shape[1], device=output_tensor.device) eos_tensors = output_tensor[loss_mask, idx, :] if self.global_inbatch_negatives and self.trainer.training: - eos_tensors = self._gather_global_inbatch_representations(eos_tensors) + eos_tensors = _gather_global_inbatch_representations(eos_tensors) if not self.trainer.training: return self.inference_loss_func(loss_mask, num_valid_tokens_in_ub, eos_tensors) bs = eos_tensors.shape[0] // 3 @@ -464,4 +465,11 @@ def loss_func(self, loss_mask, num_valid_tokens_in_ub, output_tensor): query_hs = query_hs.clone().detach() pos_doc_hs = pos_doc_hs.clone().detach() diff_cs = pos_cs - neg_cs - return loss, query_hs, pos_doc_hs, pos_cs, neg_cs, diff_cs + return { + "loss": loss, + "query_hs": query_hs, + "pos_doc_hs": pos_doc_hs, + "pos_cs": pos_cs, + "neg_cs": neg_cs, + "diff_cs": diff_cs, + } diff --git a/nemo/collections/nlp/models/information_retrieval/megatron_gpt_reranker_model.py b/nemo/collections/nlp/models/information_retrieval/megatron_gpt_reranker_model.py new file 
mode 100644 index 000000000000..e316871fe607 --- /dev/null +++ b/nemo/collections/nlp/models/information_retrieval/megatron_gpt_reranker_model.py @@ -0,0 +1,301 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools +import os + +import numpy as np +import torch +from omegaconf import DictConfig, ListConfig +from pytorch_lightning.trainer.trainer import Trainer + +from nemo.collections.nlp.data.information_retrieval.gpt_embedding_dataset import GPTRerankerDataset +from nemo.collections.nlp.data.language_modeling.megatron.base_dataset_utils import ( + get_datasets_weights_and_num_samples, +) +from nemo.collections.nlp.data.language_modeling.megatron.blendable_dataset import BlendableDataset +from nemo.collections.nlp.models.information_retrieval.megatron_gpt_embedding_model import ( + MegatronGPTEmbeddingModel, + _gather_global_inbatch_representations, +) +from nemo.utils import logging + +try: + from megatron.core import parallel_state + + HAVE_MEGATRON_CORE = True + +except (ImportError, ModuleNotFoundError): + + HAVE_MEGATRON_CORE = False + + +def listify(tensor): + l_tensor = [] + for t in tensor: + for rid in range(t.shape[0]): + r = t[rid, :].unsqueeze(0).cpu() + l_tensor.append(r) + return l_tensor + + +class MegatronGPTRerankerModel(MegatronGPTEmbeddingModel): + def __init__(self, cfg: DictConfig, trainer: Trainer): + self.reward_model_loss = cfg.get("reward_model_loss", False) + super().__init__(cfg, trainer=trainer) + + def model_provider_func(self, pre_process, post_process): + # (@adithyare) We need post_process to be False to get hidden states in the loss_func + return super().model_provider_func(pre_process, post_process=False) + + def maybe_setup_test(self): + if hasattr(self.cfg.data, 'test_ds') and self.cfg.data.test_ds.get('file_names', None) is not None: + self._test_dl = self.setup_eval_dataloader(self._test_ds, self.cfg.data.test_ds) + return + + def maybe_build_test(self): + if hasattr(self.cfg.data, 'test_ds') and self.cfg.data.test_ds.get('file_names', None) is not None: + logging.info('Building GPT Reranker test datasets.') + # Wrap this in a list since the general finetuning parent class supports multi-validation. + self._test_ds = self._build_dataset(self.cfg.data.test_ds, is_train=False) + + def _build_dataset(self, data_cfg, is_train=True): + packed_sequence = data_cfg.get("packed_sequence", False) + + # Determine if we are using a single dataset or a list of datasets. + if is_train: + # Construct the data prefix list for `get_datasets_weights_and_num_samples()` + # that is of the format [weight1,file_name1,weight2,file_name2,...] + if data_cfg.concat_sampling_probabilities is None or not isinstance( + data_cfg.concat_sampling_probabilities, ListConfig + ): + raise ValueError( + ( + f"concat_sampling_probabilities must be a ListConfig with the same number of files in file_names." 
+ f"Found: {data_cfg.concat_sampling_probabilities}" + ) + ) + + if len(data_cfg.get('concat_sampling_probabilities', None)) != len(data_cfg.file_names): + raise ValueError( + ( + f"concat_sampling_probabilities must be of the same size as file_names.", + f"Provided size {len(data_cfg.concat_sampling_probabilities)}, number of datasets {len(data_cfg.file_names)}", + ) + ) + + data_prefix = [] + for weight, prefix in zip(data_cfg.concat_sampling_probabilities, data_cfg.file_names): + data_prefix.append(weight) + data_prefix.append(prefix) + + if self.trainer.max_steps is None or self.trainer.max_steps <= 0: + raise ValueError( + f'Trainer max_steps must be set to a positive integer. Found {self.trainer.max_steps}' + ) + num_train_samples = [self.trainer.max_steps * data_cfg.global_batch_size] + _, _, num_train_samples_per_dataset = get_datasets_weights_and_num_samples(data_prefix, num_train_samples) + num_train_samples_after_blend = sum([x[0] for x in num_train_samples_per_dataset]) + else: + num_train_samples_per_dataset = [[None]] * len(data_cfg.file_names) + + # Check dataset max_seq_legnth and max_position_embeddings size + if ( + self.cfg.get('position_embedding_type', None) in [None, 'learned_absolute'] + and data_cfg.max_seq_length > self.cfg.max_position_embeddings + ): + logging.warning( + f"Set dataset max_seq_length to max_position_embeddings {self.cfg.max_position_embeddings} if using learned_absolute position embedding" + ) + data_cfg.max_seq_length = self.cfg.max_position_embeddings + + # TE requires that the first input dim is divisible by 8 and the second by 16 for fp8 + # When using sequence parallel, sequence will further be split by TP size + pad_seq_length_to_mult = ( + 8 * self.cfg.get('tensor_model_parallel_size', 1) if self.cfg.get('sequence_parallel', False) else 16 + ) + pad_seq_length_to_mult *= self.cfg.get('context_parallel_size', 1) + + datasets = [] + for file_path, num_samples in zip(data_cfg.file_names, num_train_samples_per_dataset): + dataset = GPTRerankerDataset( + file_path=file_path, + tokenizer=self.tokenizer, + max_seq_length=data_cfg.max_seq_length, + min_seq_length=data_cfg.min_seq_length, + add_bos=data_cfg.get('add_bos', False), + add_eos=data_cfg.get('add_eos', True), + max_num_samples=num_samples[0], + seed=data_cfg.get('seed', 1234), + index_mapping_dir=data_cfg.get('index_mapping_dir', None), + virtual_tokens=self.virtual_tokens, + memmap_workers=data_cfg.get( + 'memmap_workers', None + ), # used to set num. of workers to create the memmap index files + truncation_method=data_cfg.get( + 'truncation_method', 'right' + ), # used to choose truncation method. Options: ['random', 'left', 'right'] + special_tokens=self.cfg.data.get( + 'chat_prompt_tokens', None + ), # special tokens for the chat prompts, a dictionary of {token_type: token}. 
Default: {'system_turn_start': '', 'turn_start': '', 'label_start': '', 'end_of_turn': '\n', "end_of_name": "\n"} + data_type="train" if is_train else "validation", + ) + datasets.append(dataset) + if is_train: + if packed_sequence: + num_train_samples_after_blend = sum(len(dataset) for dataset in datasets) + dataset = BlendableDataset( + datasets=datasets, weights=data_cfg.concat_sampling_probabilities, size=num_train_samples_after_blend + ) + return dataset + else: + return datasets + + def training_step_fwd_bwd_step_call(self, dataloader_iter, forward_only): + loss_mean, non_loss_tensors = self.fwd_bwd_step(dataloader_iter, forward_only) + logit_diff = non_loss_tensors['logit_diff'][0].item() + self.log("logit_diff", logit_diff, prog_bar=True, rank_zero_only=True, batch_size=1) + return loss_mean + + def inference_step_validation_call(self, batch, batch_idx, data_cfg, dataloader_idx=0): + metadata = batch.get('metadata', [{}] * len(batch['tokens'])) + loss, non_loss_tensors = self.local_validation_step(itertools.chain([dataloader_idx], [batch])) + outputs = { + 'loss': loss, + 'metadata': metadata, # [dict] + 'query_pos_doc_logit': non_loss_tensors['query_pos_doc_logit'], # [batch_size, hidden_size] + } + return outputs + + def inference_loss_func(self, loss_mask, num_valid_tokens_in_ub, eos_tensors): + query_pos_doc_hs = eos_tensors + _blank = torch.zeros(1, device=query_pos_doc_hs.device, dtype=query_pos_doc_hs.dtype)[0] + return { + "loss": _blank, + "query_pos_doc_logit": query_pos_doc_hs, + "query_neg_doc_logit": _blank, + "logit_diff": _blank, + } + + def loss_func(self, loss_mask, num_valid_tokens_in_ub, output_tensor): + idx = torch.arange(output_tensor.shape[1], device=output_tensor.device) + eos_tensors = output_tensor[loss_mask, idx, :] # (bs x 1) + if self.global_inbatch_negatives and self.trainer.training: + eos_tensors = _gather_global_inbatch_representations(eos_tensors) + if not self.trainer.training: + return self.inference_loss_func(loss_mask, num_valid_tokens_in_ub, eos_tensors) + bs = eos_tensors.shape[0] // 2 + query_pos_doc_hs = eos_tensors[::2, :] # every second tensor from idx 0 is a query w pos_doc (bs x 1) + query_neg_doc_hs = eos_tensors[1::2, :] # every second tensor from idx 1 is a query w negative doc (bs x 1) + + if self.reward_model_loss: + loss = -torch.nn.functional.logsigmoid(query_pos_doc_hs - query_neg_doc_hs).mean() + else: + cs = torch.cat([query_pos_doc_hs, query_neg_doc_hs], dim=1) # (bs x 2) + cs = cs / self.temperature + labels = torch.zeros(bs, device=cs.device).long() + loss = torch.nn.functional.cross_entropy(cs, labels) + + cp_size = self.cfg.get('context_parallel_size', 1) + if cp_size > 1: + torch.distributed.all_reduce(loss, group=parallel_state.get_context_parallel_group()) + query_pos_doc_hs = query_pos_doc_hs.clone().detach() + query_neg_doc_hs = query_neg_doc_hs.clone().detach() + logit_diffs = torch.mean(query_pos_doc_hs - query_neg_doc_hs) + return { + "loss": loss, + "query_pos_doc_logit": query_pos_doc_hs, + "query_neg_doc_logit": query_neg_doc_hs, + "logit_diff": logit_diffs, + } + + def gather_and_maybe_write_predictions(self, output, data_cfg, mode, averaged_metric, dataloader_idx=0): + if not data_cfg.get("write_embeddings_to_file", False): + return True + gathered_output_batches = [None for _ in range(parallel_state.get_data_parallel_world_size())] + torch.distributed.all_gather_object( + gathered_output_batches, + [ + { + 'query_pos_doc_logit': batch['query_pos_doc_logit'], + 'metadata': batch['metadata'], + } + for batch 
in output + ], + group=parallel_state.get_data_parallel_group(), + ) + + # Remove duplicate examples due to distributed sampler. + deduplicated_outputs = { + 'query_pos_doc_logit': [], + 'metadata': [], + } + total_size, skipped = 0, 0 + for rank in range(0, parallel_state.get_data_parallel_world_size()): + for batch in gathered_output_batches[rank]: + l_q_hs = listify(batch['query_pos_doc_logit']) + l_m = batch['metadata'] + assert len(l_m) == len(l_q_hs) + for q_hs, metadata in zip( + l_q_hs, + l_m, + ): + total_size += 1 + if not metadata.get("__AUTOGENERATED__", False): + deduplicated_outputs['query_pos_doc_logit'].append(q_hs) + deduplicated_outputs['metadata'].append(metadata) + else: + skipped += 1 + + logging.info( + f"{total_size-skipped} deduplicated outputs in dataloader:{dataloader_idx}, (skipped {skipped} autogenerated examples)." + ) + # Compute metric score + metric_name = self.val_metric_name if mode == 'validation' else self.test_metric_name + assert metric_name == "loss", "Only loss is supported for now." + # avg_pos_cs = torch.tensor(deduplicated_outputs['avg_pos_cs']).mean().item() + # avg_neg_cs = torch.tensor(deduplicated_outputs['avg_neg_cs']).mean().item() + # diff_cs = torch.tensor(deduplicated_outputs['diff_cs']).mean().item() + # self.log('val_avg_pos_cs', avg_pos_cs, prog_bar=True, rank_zero_only=True, batch_size=1) + # self.log('val_avg_neg_cs', avg_neg_cs, prog_bar=True, rank_zero_only=True, batch_size=1) + # self.log('val_diff_cs', diff_cs, prog_bar=True, rank_zero_only=True, batch_size=1) + + # Write predictions to file + if self.global_rank == 0 and data_cfg.get("write_embeddings_to_file", False): + logging.info( + f"Total deduplicated inference data size: {total_size} to {len(deduplicated_outputs['metadata'])}" + ) + + # Check if the user provided a prefix path to the file(s) they want to write. + if not hasattr(data_cfg, "output_file_path_prefix") or data_cfg.output_file_path_prefix is None: + raise ValueError( + f"Cannot write predictions to file when output_file_path_prefix is not set or present in the yaml config file." 
+ ) + # (@adithyare) We are not using the log key to write the embeddings to file + filename_log_key = self._determine_log_key(data_cfg, dataloader_idx, None, mode) + consumed_samples = self._compute_consumed_samples_after_training_step() + fldr_path = f"{data_cfg.output_file_path_prefix}/consumed_samples{consumed_samples}/{filename_log_key}" + self.write_embeddings_to_file(deduplicated_outputs, fldr_path, dataloader_idx) + return deduplicated_outputs, total_size + + def write_embeddings_to_file(self, outputs, output_file_path, d_idx): + hs = torch.cat(outputs['query_pos_doc_logit'], dim=0) + hs_npy = hs.float().numpy() + emb_fldr = f"{output_file_path}" + os.makedirs(emb_fldr, exist_ok=True) + with open(f"{output_file_path}/logits.ids", "w") as f: + for m in outputs['metadata']: + f.write(f"{m['query_id'].strip()} {m['doc_id']}\n") + np.save(f"{emb_fldr}/logits.npy", hs_npy) + return True diff --git a/nemo/collections/nlp/models/language_modeling/megatron/gpt_full_te_layer_autocast_spec.py b/nemo/collections/nlp/models/language_modeling/megatron/gpt_full_te_layer_autocast_spec.py index 6cce2b42be9c..f3299d488fd0 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron/gpt_full_te_layer_autocast_spec.py +++ b/nemo/collections/nlp/models/language_modeling/megatron/gpt_full_te_layer_autocast_spec.py @@ -35,7 +35,9 @@ try: from megatron.core import parallel_state, tensor_parallel + from megatron.core.fusions.fused_layer_norm import FusedLayerNorm from megatron.core.transformer.spec_utils import ModuleSpec + from megatron.core.transformer.transformer_block import TransformerBlockSubmodules, get_num_layers_to_build from megatron.core.transformer.transformer_layer import BaseTransformerLayer from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint @@ -322,8 +324,10 @@ def sharded_state_dict(self, prefix: str = '', sharded_offsets: tuple = (), meta # Use this spec to use the full Transformer layer from Transformer Engine -def get_gpt_full_te_layer_autocast_spec() -> ModuleSpec: +def get_gpt_full_te_layer_autocast_spec(transformer_config) -> ModuleSpec: if not HAVE_MEGATRON_CORE or not HAVE_TE: raise ImportError(IMPORT_ERROR) - - return ModuleSpec(module=TETransformerLayerAutocast) + num_layers = get_num_layers_to_build(transformer_config) + return TransformerBlockSubmodules( + layer_specs=[ModuleSpec(module=TETransformerLayerAutocast)] * num_layers, layer_norm=FusedLayerNorm + ) diff --git a/nemo/collections/nlp/models/language_modeling/megatron/gpt_layer_modelopt_spec.py b/nemo/collections/nlp/models/language_modeling/megatron/gpt_layer_modelopt_spec.py index f001e8f58d25..d3f770cde91b 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron/gpt_layer_modelopt_spec.py +++ b/nemo/collections/nlp/models/language_modeling/megatron/gpt_layer_modelopt_spec.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
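The two write_embeddings_to_file variants above persist one .npy array plus a parallel .ids text file per dataloader. A minimal loader sketch for the reranker output, assuming the folder layout produced by that method; load_reranker_logits is a hypothetical helper name:

import numpy as np

def load_reranker_logits(emb_fldr):
    # logits.npy: one row per example, saved with np.save above.
    logits = np.load(f"{emb_fldr}/logits.npy")
    # logits.ids: one "<query_id> <doc_id>" pair per line, in the same order as the rows.
    with open(f"{emb_fldr}/logits.ids") as f:
        pairs = [tuple(line.split()) for line in f]
    assert len(pairs) == logits.shape[0]
    return dict(zip(pairs, logits))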
-from nemo.collections.nlp.modules.common.megatron.utils import ApexGuardDefaults - try: from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear @@ -31,6 +29,8 @@ except (ImportError, ModuleNotFoundError) as e: + from nemo.collections.nlp.modules.common.megatron.utils import ApexGuardDefaults + TransformerLayer = TransformerLayerSubmodules = ApexGuardDefaults MLP = MLPSubmodules = ModuleSpec = IdentityOp = ApexGuardDefaults AttnMaskType = DotProductAttention = TENorm = ApexGuardDefaults diff --git a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py index 4ded9a42db4f..7042b0d35ad9 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py @@ -47,16 +47,6 @@ from nemo.utils import AppState, logging, str_to_dtype from nemo.utils.get_rank import is_global_rank_zero -try: - from apex.transformer.pipeline_parallel.utils import get_num_microbatches - - HAVE_APEX = True - -except (ImportError, ModuleNotFoundError): - - HAVE_APEX = False - - try: from megatron.core import ModelParallelConfig, parallel_state from megatron.core.distributed import DistributedDataParallel as McoreDDP @@ -72,6 +62,13 @@ HAVE_MEGATRON_CORE = False +try: + from megatron.core.num_microbatches_calculator import get_current_global_batch_size, get_num_microbatches + +except (ImportError, ModuleNotFoundError): + logging.warning("Megatron num_microbatches_calculator not found, using Apex version.") + from apex.transformer.pipeline_parallel.utils import get_current_global_batch_size, get_num_microbatches + try: from megatron.core import Timers @@ -303,9 +300,12 @@ def _wrap_model_for_O2(self): if type(self).__name__ == 'MegatronGPTModel': nemo_args['share_token_embeddings'] = self.cfg.get('share_embeddings_and_output_weights', True) - mcore_args = { - 'config': self.transformer_config, - } + if is_mcore_model: + mcore_args = { + 'config': self.transformer_config, + } + else: + mcore_args = None args = mcore_args if is_mcore_model else nemo_args # Model wrapper to convert both model and inputs to half precision @@ -385,8 +385,11 @@ def _enable_nvidia_optimizations(self): # NVIDIA container version check nvidia_torch_version = os.getenv('NVIDIA_PYTORCH_VERSION', None) - # Support DLFW master container - if nvidia_torch_version == 'master': + def is_official_release_version(nvidia_torch_version): + return re.fullmatch("[0-9][0-9]\.[0-9][0-9].*", nvidia_torch_version) # "YY.MM.*" + + # Support DLFW dev container + if not is_official_release_version(nvidia_torch_version): nvidia_torch_version = datetime.now().strftime('%y.%m') if nvidia_torch_version is not None: @@ -395,7 +398,7 @@ def _enable_nvidia_optimizations(self): except Exception: NVIDIA_TORCH_MAJOR = 0 try: - NVIDIA_TORCH_MINOR = int(nvidia_torch_version.split('.')[1]) + NVIDIA_TORCH_MINOR = int(nvidia_torch_version.split('.')[1][:2]) except Exception: NVIDIA_TORCH_MINOR = 0 @@ -914,9 +917,7 @@ def compute_consumed_samples(self, steps_since_resume=0): app_state = AppState() if self.cfg.get('rampup_batch_size', None): - from apex.transformer.pipeline_parallel.utils import _GLOBAL_NUM_MICROBATCHES_CALCULATOR - - current_global_batch_size = getattr(_GLOBAL_NUM_MICROBATCHES_CALCULATOR, 'current_global_batch_size', 1) + current_global_batch_size = get_current_global_batch_size() if 
get_current_global_batch_size() else 1 consumed_samples = self.prev_consumed_samples + self.if_first_step * current_global_batch_size else: consumed_samples = ( @@ -1271,6 +1272,8 @@ def find_frozen_submodules(model): # TODO: Currently the main parameter data type is kept in fp32 (when O2=False). This needs to be # extended to support lower precision main parameters. frozen_submodule_names, frozen_submodules = find_frozen_submodules(self.model) + for submodule in frozen_submodule_names: + logging.debug(f"Ignoring state {submodule} in FSDP.") self.trainer.strategy.kwargs['ignored_states'] = frozen_submodules # FSDP requires uniform status of require_grads # Diffusion models like SD has frozen parts and needs to be added to 'ignored_states' from sharding for FSDP to work diff --git a/nemo/collections/nlp/models/language_modeling/megatron_base_prompt_learning_model.py b/nemo/collections/nlp/models/language_modeling/megatron_base_prompt_learning_model.py index f6ee4b20183c..2a356012c728 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_base_prompt_learning_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_base_prompt_learning_model.py @@ -39,14 +39,6 @@ from nemo.utils import AppState, logging from nemo.utils.decorators import deprecated_warning -try: - from apex.transformer.pipeline_parallel.utils import _reconfigure_microbatch_calculator - - HAVE_APEX = True - -except (ImportError, ModuleNotFoundError): - HAVE_APEX = False - try: from megatron.core import ModelParallelConfig, parallel_state @@ -59,6 +51,16 @@ HAVE_MEGATRON_CORE = False +try: + from megatron.core.num_microbatches_calculator import reconfigure_num_microbatches_calculator + +except (ImportError, ModuleNotFoundError): + logging.warning("Megatron num_microbatches_calculator not found, using Apex version.") + from apex.transformer.pipeline_parallel.utils import ( + _reconfigure_microbatch_calculator as reconfigure_num_microbatches_calculator, + ) + + __all__ = ['MegatronBasePromptLearningModel'] @@ -387,7 +389,7 @@ def _reconfigure_and_process_inference_batch(self, global_batch_size_per_gpu, gb if global_batch_size_per_gpu != gbs // parallel_state.get_data_parallel_world_size(): # NOTE: This is reconfiguring to make sure there is no grad-acc for validation batches. 
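A few hunks above, _enable_nvidia_optimizations stops special-casing the literal 'master' tag and instead asks whether NVIDIA_PYTORCH_VERSION looks like an official "YY.MM" release; anything else is treated as a dev container and mapped to the current year.month. A small sketch of how that predicate behaves (the example version strings are illustrative):

import re

def is_official_release_version(nvidia_torch_version):
    # Same pattern as in the diff: two digits, a dot, two digits, then anything ("YY.MM.*").
    return re.fullmatch(r"[0-9][0-9]\.[0-9][0-9].*", nvidia_torch_version)

assert is_official_release_version("24.05")        # release container
assert is_official_release_version("24.05.01")     # point release still matches
assert not is_official_release_version("master")   # dev container, falls back to datetime.now().strftime('%y.%m')
assert not is_official_release_version("24.5")     # single-digit month does not count as an official tag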
app_state = AppState() - _reconfigure_microbatch_calculator( + reconfigure_num_microbatches_calculator( rank=app_state.global_rank, rampup_batch_size=None, global_batch_size=global_batch_size_per_gpu * parallel_state.get_data_parallel_world_size(), @@ -397,7 +399,7 @@ def _reconfigure_and_process_inference_batch(self, global_batch_size_per_gpu, gb def _reconfigure_batch_sizes(self, gbs: int, mbs: int): app_state = AppState() - _reconfigure_microbatch_calculator( + reconfigure_num_microbatches_calculator( rank=app_state.global_rank, rampup_batch_size=None, global_batch_size=gbs, diff --git a/nemo/collections/nlp/models/language_modeling/megatron_bert_model.py b/nemo/collections/nlp/models/language_modeling/megatron_bert_model.py index 984fca5f1259..093fb2b8d688 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_bert_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_bert_model.py @@ -46,15 +46,6 @@ from nemo.core.neural_types import ChannelType, MaskType, NeuralType from nemo.utils import logging -try: - from apex.transformer.pipeline_parallel.utils import get_num_microbatches - - HAVE_APEX = True - -except (ImportError, ModuleNotFoundError): - - HAVE_APEX = False - try: import logging @@ -77,6 +68,13 @@ TransformerConfig = ApexGuardDefaults HAVE_MEGATRON_CORE = False +try: + from megatron.core.num_microbatches_calculator import get_num_microbatches + +except (ImportError, ModuleNotFoundError): + logging.warning("Megatron num_microbatches_calculator not found, using Apex version.") + from apex.transformer.pipeline_parallel.utils import get_num_microbatches + class MegatronBertModel(MegatronBaseModel): """ @@ -214,8 +212,8 @@ def model_provider_func(self, pre_process, post_process): return model def _validate_trainer(self): - """ Certain trainer configurations can break training. - Here we try to catch them and raise an error. + """Certain trainer configurations can break training. + Here we try to catch them and raise an error. """ if self.trainer.accumulate_grad_batches > 1: raise ValueError( @@ -308,7 +306,11 @@ def forward( model = self.model if self.mcore_bert: - output_tensor = model(input_ids, attention_mask, tokentype_ids=token_type_ids,) + output_tensor = model( + input_ids, + attention_mask, + tokentype_ids=token_type_ids, + ) else: output_tensor = model( input_ids, @@ -423,21 +425,24 @@ def training_step(self, dataloader_iter): self.log('lr', lr, batch_size=1) self.log('global_step', self.trainer.global_step, prog_bar=True, batch_size=1) self.log( - 'consumed_samples', self._compute_consumed_samples_after_training_step(), prog_bar=True, batch_size=1, + 'consumed_samples', + self._compute_consumed_samples_after_training_step(), + prog_bar=True, + batch_size=1, ) return loss_mean[0] def _make_data_iterator_list(self, data_iterator: Iterator) -> List[Iterator]: - """ Convert data iterator into form expected by Megatron - With interleaved pipeline parallelism, Megatron expects a - list of one data iterator per model chunk. Each model - chunk independently gets data from its data iterator, so - we need to interact with the data iterator multiple times - for each microbatch step. Instead of incorporating this - logic into the data loader, we cache the iterator's output - to the first model chunk and reuse it in the other model - chunks. + """Convert data iterator into form expected by Megatron + With interleaved pipeline parallelism, Megatron expects a + list of one data iterator per model chunk. 
Each model + chunk independently gets data from its data iterator, so + we need to interact with the data iterator multiple times + for each microbatch step. Instead of incorporating this + logic into the data loader, we cache the iterator's output + to the first model chunk and reuse it in the other model + chunks. """ if not isinstance(self.model, list) or len(self.model) == 1: @@ -703,9 +708,9 @@ def build_train_valid_test_datasets(self): ] if self.trainer.limit_val_batches <= 1.0 and isinstance(self.trainer.limit_val_batches, float): - train_valid_test_num_samples[ - 1 - ] = 1 # This is to make sure we only have one epoch on every validation iteration + train_valid_test_num_samples[1] = ( + 1 # This is to make sure we only have one epoch on every validation iteration + ) self._train_ds, self._validation_ds, self._test_ds = dataset_utils.build_train_valid_test_datasets( cfg=self.cfg, @@ -739,20 +744,20 @@ def build_train_valid_test_datasets(self): return self._train_ds, self._validation_ds, self._test_ds def backward(self, *args, **kwargs): - """ LightningModule hook to do backward. - We want this to do nothing since we run backward in the fwd/bwd functions from megatron-core. - No need to call it here. + """LightningModule hook to do backward. + We want this to do nothing since we run backward in the fwd/bwd functions from megatron-core. + No need to call it here. """ return def optimizer_zero_grad(self, *args, **kwargs): - """ LightningModule hook to zero grad. - We want this to do nothing as we are zeroing grads during the training_step. + """LightningModule hook to zero grad. + We want this to do nothing as we are zeroing grads during the training_step. """ return def _append_sequence_parallel_module_grads(self, module, grads): - """ Helper method for allreduce_sequence_parallel_gradients""" + """Helper method for allreduce_sequence_parallel_gradients""" for param in module.parameters(): sequence_parallel_param = getattr(param, 'sequence_parallel', False) @@ -822,12 +827,12 @@ def setup(self, stage=None): self.setup_transformer_engine_tp_groups() def setup_transformer_engine_tp_groups(self): - """ This should be called after model parallel groups have been initialized - and only needs to be called when using Transformer Engine. + """This should be called after model parallel groups have been initialized + and only needs to be called when using Transformer Engine. """ for module in self.get_bert_module_list(): """Set TP group - Copied from: https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/transformer.py#L398 + Copied from: https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/transformer.py#L398 """ # Deep iterate but skip self to avoid infinite recursion. for index, child in enumerate(module.modules()): @@ -849,9 +854,9 @@ def get_bert_module_list(self): return [self.model] def allreduce_sequence_parallel_gradients(self): - """ All-reduce layernorm parameters across model parallel nodes when sequence parallelism is used. - Modified from megatron-lm: - https://gitlab-master.nvidia.com/ADLR/megatron-lm/-/blob/3f91f09bb2ab32f9904b47f46f19d2fc3f518ed8/megatron/training.py#L425 + """All-reduce layernorm parameters across model parallel nodes when sequence parallelism is used. 
+ Modified from megatron-lm: + https://gitlab-master.nvidia.com/ADLR/megatron-lm/-/blob/3f91f09bb2ab32f9904b47f46f19d2fc3f518ed8/megatron/training.py#L425 """ grads = [] @@ -931,10 +936,10 @@ def setup_test_data(self, cfg): self._test_dl = self.build_pretraining_data_loader(self._test_ds, consumed_samples) def transfer_batch_to_device(self, batch: Any, device: torch.device, dataloader_idx: int) -> Any: - """ PTL hook: https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#transfer-batch-to-device - When using pipeline parallelism, we need the global batch to remain on the CPU, - since the memory overhead will be too high when using a large number of microbatches. - Microbatches are transferred from CPU to GPU inside the pipeline. + """PTL hook: https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#transfer-batch-to-device + When using pipeline parallelism, we need the global batch to remain on the CPU, + since the memory overhead will be too high when using a large number of microbatches. + Microbatches are transferred from CPU to GPU inside the pipeline. """ return batch @@ -1154,10 +1159,10 @@ def on_load_checkpoint(self, checkpoint) -> None: parallel_state.set_virtual_pipeline_model_parallel_rank(0) def build_transformer_config(self) -> TransformerConfig: - """ Builds the megatron core gpt transformer config for the model. - For attributes in the nemo model config that are the same - as the megatron core TransformerConfig, we will use the value from the nemo model config. - For attributes in TransformerConfig that are not in the nemo model config, we add custom logic. + """Builds the megatron core gpt transformer config for the model. + For attributes in the nemo model config that are the same + as the megatron core TransformerConfig, we will use the value from the nemo model config. + For attributes in TransformerConfig that are not in the nemo model config, we add custom logic. """ activation = self.cfg.get('activation', 'gelu') assert activation == 'gelu', "Only gelu activation is support for BERT at the moment." 
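The Apex-to-Megatron-core migration running through these files (and through the megatron_gpt_model.py diff that follows) leans on one recurring compatibility shim: import the helpers from megatron.core.num_microbatches_calculator when available, otherwise fall back to the Apex implementations and alias the renamed function. Collected in one place, the pattern looks roughly like this; the exact set of names imported varies per file:

from nemo.utils import logging

try:
    from megatron.core.num_microbatches_calculator import (
        get_current_global_batch_size,
        get_micro_batch_size,
        get_num_microbatches,
        reconfigure_num_microbatches_calculator,
    )
except (ImportError, ModuleNotFoundError):
    # Older stacks only ship the Apex version; alias the renamed helper so the
    # rest of the module can use the Megatron-core name unconditionally.
    logging.warning("Megatron num_microbatches_calculator not found, using Apex version.")
    from apex.transformer.pipeline_parallel.utils import (
        _reconfigure_microbatch_calculator as reconfigure_num_microbatches_calculator,
    )
    from apex.transformer.pipeline_parallel.utils import (
        get_current_global_batch_size,
        get_micro_batch_size,
        get_num_microbatches,
    )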
diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index 4f9722d900f6..22ee37ec361b 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -32,6 +32,7 @@ from nemo.collections.common.parts.utils import extend_instance from nemo.collections.nlp.data.language_modeling.megatron.data_samplers import ( + MegatronCorePretrainingSampler, MegatronPretrainingRandomSampler, MegatronPretrainingSampler, ) @@ -44,7 +45,6 @@ from nemo.collections.nlp.models.language_modeling.megatron.gpt_layer_modelopt_spec import get_gpt_layer_modelopt_spec from nemo.collections.nlp.models.language_modeling.megatron.gpt_model import GPTModel from nemo.collections.nlp.models.language_modeling.megatron_base_model import MegatronBaseModel -from nemo.collections.nlp.modules.common.hyena.hyena_spec import get_gpt_layer_with_te_and_hyena_spec from nemo.collections.nlp.modules.common.megatron.build_model import build_model from nemo.collections.nlp.modules.common.megatron.module import Float16Module from nemo.collections.nlp.modules.common.megatron.utils import ( @@ -76,16 +76,6 @@ from nemo.utils import logging from nemo.utils.te_utils import is_float8tensor -try: - import apex.transformer.pipeline_parallel.utils - from apex.transformer.pipeline_parallel.utils import get_num_microbatches - - HAVE_APEX = True - -except (ImportError, ModuleNotFoundError): - - HAVE_APEX = False - try: from megatron.core import InferenceParams, parallel_state, tensor_parallel from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder @@ -121,10 +111,27 @@ HAVE_MEGATRON_CORE = False +try: + from megatron.core.num_microbatches_calculator import ( + get_current_global_batch_size, + get_num_microbatches, + update_num_microbatches, + ) + +except (ImportError, ModuleNotFoundError): + logging.warning("Megatron num_microbatches_calculator not found, using Apex version.") + from apex.transformer.pipeline_parallel.utils import ( + get_current_global_batch_size, + get_num_microbatches, + update_num_microbatches, + ) + try: import transformer_engine from transformer_engine.pytorch import module as te_module + from nemo.collections.nlp.modules.common.hyena.hyena_spec import get_gpt_layer_with_te_and_hyena_spec + HAVE_TE = True except (ImportError, ModuleNotFoundError): @@ -144,7 +151,12 @@ def mcore_supports_moe() -> bool: return False -def get_specs(spec_name, num_experts=None, moe_grouped_gemm=False, use_te=True, hyena_cfg: Dict = None): +## TODO: This function will not work if TE is not installed +def get_specs(spec_name, transformer_config=None, use_te=True, hyena_cfg: Dict = None): + # else cases for backwards compatibility with neva + num_experts = transformer_config.num_moe_experts if transformer_config else None + moe_grouped_gemm = transformer_config.moe_grouped_gemm if transformer_config else False + if num_experts is not None: assert mcore_supports_moe(), "Megatron-core >= v0.5.0 is required for MoE" @@ -154,7 +166,7 @@ def get_specs(spec_name, num_experts=None, moe_grouped_gemm=False, use_te=True, "": get_gpt_layer_local_spec(num_experts, moe_grouped_gemm), "te_gpt": get_gpt_layer_with_transformer_engine_spec(num_experts, moe_grouped_gemm), "megatron_falcon_gpt": get_falcon_layer_spec(), - "megatron_gpt_full_te_layer_autocast": get_gpt_full_te_layer_autocast_spec(), + "megatron_gpt_full_te_layer_autocast": 
get_gpt_full_te_layer_autocast_spec(transformer_config), "modelopt": get_gpt_layer_modelopt_spec(num_experts), "te_gpt_hyena": get_gpt_layer_with_te_and_hyena_spec(hyena_cfg), } @@ -276,10 +288,6 @@ class MegatronGPTModel(MegatronBaseModel, TextGeneration): """ def __init__(self, cfg: DictConfig, trainer: Trainer): - if not HAVE_APEX: - raise ImportError( - "Apex was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt." - ) if not HAVE_MEGATRON_CORE: logging.warning( "megatron-core was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt." @@ -391,7 +399,6 @@ def __init__(self, cfg: DictConfig, trainer: Trainer): self.log_memory_usage = bool(int(os.getenv("NEMO_LOG_MEMORY_USAGE", 0))) self.loss_broadcast_src_rank = None data_cfg = cfg.get('data', {}) - self.return_output_tensors = data_cfg.get('return_output_tensors', False) self.validation_drop_last = data_cfg.get('validation_drop_last', True) self.sample_weight = data_cfg.get('sample_weight', 'token') self.validation_param_sync_overlap = self.cfg.get('validation_param_sync_overlap', False) @@ -426,8 +433,7 @@ def model_provider_func(self, pre_process, post_process): config=self.transformer_config, transformer_layer_spec=get_specs( self.spec_name, - self.transformer_config.num_moe_experts, - self.transformer_config.moe_grouped_gemm, + self.transformer_config, self.transformer_engine, self.cfg.get('hyena', None), ), @@ -539,6 +545,7 @@ def setup_mcore_distributed_parallel(self): # mcore bucket_size is based on num of parameters, therefore not # using bucket_cap_mb to configure bucket_size here bucket_size=self.cfg.optim.get('ddp_bucket_size', None), + average_in_collective=self.cfg.optim.get('average_in_collective', True), ) self.model = [ McoreDDP( @@ -624,11 +631,7 @@ def make_parameter_bucket(module: torch.nn.Module) -> List[torch.nn.Parameter]: if self.cfg.get('virtual_pipeline_model_parallel_size', None) is not None: # Initialize a bucket for each virtual pipeline stage for module in self.model: - if isinstance(module, (Float16Module, MCoreFloat16Module)): - module = module.module - stage_bucket = [] - layers = module.decoder.layers if self.mcore_gpt else module.language_model.encoder.layers - buckets.extend(make_parameter_bucket(layer) for layer in layers) + buckets.append(make_parameter_bucket(module)) else: # Initialize a bucket for each Transformer layer modules = self.model if isinstance(self.model, list) else [self.model] @@ -788,8 +791,7 @@ def training_step(self, dataloader_iter): self.if_init_step = False if self.rampup_batch_size: - num_microbatch_calculator = apex.transformer.pipeline_parallel.utils._GLOBAL_NUM_MICROBATCHES_CALCULATOR - current_global_batch_size = num_microbatch_calculator.current_global_batch_size + current_global_batch_size = get_current_global_batch_size() # do validation and save the checkpoint when gbs is changed if self.prev_global_batch_size != current_global_batch_size and self.prev_global_batch_size: self.trainer.should_stop = True @@ -825,7 +827,9 @@ def training_step(self, dataloader_iter): ignore_virtual=True ): if ( - self.cfg.get('defer_embedding_wgrad_compute', False) and self.mcore_gpt + self.cfg.get('defer_embedding_wgrad_compute', False) + and self.mcore_gpt + and not self.use_mcore_dist_optim ): # Silently ignore the optimization if MCORE is not used module_list = self.get_model_module_list() if len(module_list) > 1: @@ -848,7 +852,9 @@ def training_step(self, 
dataloader_iter): ignore_virtual=True ): if ( - self.cfg.get('defer_embedding_wgrad_compute', False) and self.mcore_gpt + self.cfg.get('defer_embedding_wgrad_compute', False) + and self.mcore_gpt + and not self.use_mcore_dist_optim ): # Silently ignore the optimization if MCORE is not used module_list = self.get_model_module_list() if len(module_list) > 1: @@ -1275,24 +1281,47 @@ def loss_func(output_tensor): # Loss for a micro-batch (ub) loss_for_ub = self.loss_func(batch['loss_mask'], batch['num_valid_tokens_in_ub'], output_tensor) cp_size = parallel_state.get_context_parallel_world_size() - if self.return_output_tensors: + if isinstance(loss_for_ub, dict): # TODO: need a better way to check if loss_func is returning more stuff than just loss... (@adithyare) - loss_for_ub, q_hs, d_hs, pos_cs, neg_cs, diff_cs = loss_for_ub - reduced_loss = average_losses_across_data_parallel_group([loss_for_ub]) - pos_cs = average_losses_across_data_parallel_group([pos_cs]) - neg_cs = average_losses_across_data_parallel_group([neg_cs]) - diff_cs = average_losses_across_data_parallel_group([diff_cs]) - return ( - loss_for_ub * cp_size, - { - 'avg': reduced_loss, - 'query_hs': q_hs, - 'doc_hs': d_hs, - 'avg_pos_cs': pos_cs, - 'avg_neg_cs': neg_cs, - 'diff_cs': diff_cs, - }, - ) + + if set(loss_for_ub.keys()) == set( + ["loss", "query_hs", "pos_doc_hs", "pos_cs", "neg_cs", "diff_cs"] + ): # (adithyare) this check will be True for GPT Embedding models + loss = loss_for_ub['loss'] + reduced_loss = average_losses_across_data_parallel_group([loss]) + pos_cs = average_losses_across_data_parallel_group([loss_for_ub['pos_cs']]) + neg_cs = average_losses_across_data_parallel_group([loss_for_ub['neg_cs']]) + diff_cs = average_losses_across_data_parallel_group([loss_for_ub['diff_cs']]) + return ( + loss * cp_size, + { + 'avg': reduced_loss, + 'query_hs': loss_for_ub['query_hs'], + 'doc_hs': loss_for_ub['pos_doc_hs'], + 'avg_pos_cs': pos_cs, + 'avg_neg_cs': neg_cs, + 'diff_cs': diff_cs, + }, + ) + elif set(loss_for_ub.keys()) == set( + ["loss", "query_pos_doc_logit", "query_neg_doc_logit", "logit_diff"] + ): # (adithyare) this check will be True for GPT Reranker models + + loss = loss_for_ub['loss'] + reduced_loss = average_losses_across_data_parallel_group([loss]) + logit_diff = average_losses_across_data_parallel_group([loss_for_ub['logit_diff']]) + return ( + loss * cp_size, + { + 'avg': reduced_loss, + 'query_pos_doc_logit': loss_for_ub['query_pos_doc_logit'], + 'query_neg_doc_logit': loss_for_ub['query_neg_doc_logit'], + 'logit_diff': logit_diff, + }, + ) + else: + raise RuntimeError(f"Dict loss_for_ub has unknown key set {loss_for_ub.keys()}") + elif validation_step and not self.validation_drop_last: num_valid_tokens_in_ub = batch['num_valid_tokens_in_ub'] if loss_for_ub.isnan(): @@ -1535,6 +1564,7 @@ def build_train_valid_test_datasets(self): "create_attention_mask": not self.get_attention_mask_from_fusion, "mmap_bin_files": self.cfg.data.get("mmap_bin_files", True), "drop_last_partial_validation_sequence": self.cfg.data.get("validation_drop_last", True), + "num_dataset_builder_threads": self.cfg.data.get("num_dataset_builder_threads", 1), "add_extra_token_to_sequence": add_extra_token, } @@ -1583,8 +1613,13 @@ def build_pretraining_data_loader( logging.info(f'Building dataloader with consumed samples: {consumed_samples}') # Megatron sampler if hasattr(self.cfg.data, 'dataloader_type') and self.cfg.data.dataloader_type is not None: + data_sampler = ( + MegatronPretrainingSampler + if 
self.cfg.data.get('legacy_dataset', False) + else MegatronCorePretrainingSampler + ) if self.cfg.data.dataloader_type == 'single': - batch_sampler = MegatronPretrainingSampler( + batch_sampler = data_sampler( total_samples=len(dataset), consumed_samples=consumed_samples, micro_batch_size=self.cfg.micro_batch_size, @@ -1644,8 +1679,7 @@ def setup(self, stage=None): self.init_global_step = self.trainer.global_step if self.rampup_batch_size: - num_microbatch_calculator = apex.transformer.pipeline_parallel.utils._GLOBAL_NUM_MICROBATCHES_CALCULATOR - num_microbatch_calculator.update(self.init_consumed_samples, consistency_check=False) + update_num_microbatches(self.init_consumed_samples, consistency_check=False) self.prev_consumed_samples = self.init_consumed_samples if stage == 'predict': @@ -1662,6 +1696,12 @@ def setup(self, stage=None): # Override limit_val_batches to be a multiple of num microbatches to prevent val_step from exiting in between a step self._reconfigure_limit_batches(self.trainer.limit_val_batches, self._validation_dl, 'val') + # Data cache generation only + # Stops script execution after creating a data cache + if self.cfg.data.get('data_cache_generation_only', False): + self.trainer.num_sanity_val_steps = 0 + self.trainer.should_stop = True + if stage == 'fit': self.initialize_last_rank_embeddings() diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_prompt_learning_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_prompt_learning_model.py index acfc22439a7d..78f671142c1b 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_prompt_learning_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_prompt_learning_model.py @@ -46,27 +46,25 @@ from nemo.utils import AppState, logging from nemo.utils.decorators import deprecated_warning -try: - from apex.transformer.pipeline_parallel.utils import get_micro_batch_size, get_num_microbatches - - HAVE_APEX = True - -except (ImportError, ModuleNotFoundError): - HAVE_APEX = False - try: from megatron.core import InferenceParams, ModelParallelConfig, parallel_state, tensor_parallel from megatron.core.enums import ModelType from megatron.core.pipeline_parallel.schedules import get_forward_backward_func HAVE_MEGATRON_CORE = True - except (ImportError, ModuleNotFoundError): ModelParallelConfig = ApexGuardDefaults HAVE_MEGATRON_CORE = False +try: + from megatron.core.num_microbatches_calculator import get_micro_batch_size, get_num_microbatches + +except (ImportError, ModuleNotFoundError): + logging.warning("Megatron num_microbatches_calculator not found, using Apex version.") + from apex.transformer.pipeline_parallel.utils import get_micro_batch_size, get_num_microbatches + __all__ = ['MegatronGPTPromptLearningModel'] diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py index 28bcbf22ac33..9c2372ef38ca 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py @@ -38,20 +38,9 @@ from nemo.collections.nlp.parts.utils_funcs import get_last_rank from nemo.utils import AppState, logging -try: - from apex.transformer.pipeline_parallel.utils import ( - _reconfigure_microbatch_calculator, - get_current_global_batch_size, - get_micro_batch_size, - get_num_microbatches, - ) - - HAVE_APEX = True -except (ImportError, ModuleNotFoundError): - HAVE_APEX = False - try: 
from megatron.core import parallel_state + from megatron.core.distributed import finalize_model_grads from megatron.core.pipeline_parallel.schedules import get_forward_backward_func HAVE_MEGATRON_CORE = True @@ -60,6 +49,25 @@ HAVE_MEGATRON_CORE = False +try: + from megatron.core.num_microbatches_calculator import ( + get_current_global_batch_size, + get_micro_batch_size, + get_num_microbatches, + reconfigure_num_microbatches_calculator, + ) + +except (ImportError, ModuleNotFoundError): + logging.warning("Megatron num_microbatches_calculator not found, using Apex version.") + from apex.transformer.pipeline_parallel.utils import ( + _reconfigure_microbatch_calculator as reconfigure_num_microbatches_calculator, + ) + from apex.transformer.pipeline_parallel.utils import ( + get_current_global_batch_size, + get_micro_batch_size, + get_num_microbatches, + ) + __all__ = ['MegatronGPTSFTModel'] @@ -70,10 +78,6 @@ class MegatronGPTSFTModel(NLPAdapterModelMixin, MegatronGPTModel): """ def __init__(self, cfg: DictConfig, trainer: Trainer): - if not HAVE_APEX: - raise ImportError( - "Apex was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt." - ) super().__init__(cfg, trainer=trainer) self.sep_id = cfg.get('sep_id', 49704) if hasattr(self.cfg.data, "validation_ds"): @@ -298,6 +302,7 @@ def _build_dataset(self, data_cfg, is_train=True): prompt_template=data_cfg.get('prompt_template', None), ceil_to_power_2=data_cfg.get('ceil_to_power_2', False), get_attention_mask_from_fusion=data_cfg.get('get_attention_mask_from_fusion', False), + global_sample_mapping=data_cfg.get('global_sample_mapping', False), virtual_tokens=self.virtual_tokens, tokens_to_generate=data_cfg.get( 'tokens_to_generate', 0 @@ -374,11 +379,27 @@ def fwd_bwd_step(self, dataloader_iter, forward_only, first_val_step=None): ) grad_sync_func = self.reduce_overlap_gradients param_sync_func = self.sync_overlap_parameters + elif not forward_only and self.use_mcore_dist_optim: + if self.cfg.optim.get("overlap_grad_sync", False): + no_sync_func = [model_chunk.no_sync for model_chunk in self.model] + no_sync_func = no_sync_func[0] if len(self.model) == 1 else no_sync_func + + if self.cfg.optim.get("delay_grad_reduce", True): + grad_sync_func = [model_chunk.start_grad_sync for model_chunk in self.model] + grad_sync_func = grad_sync_func[0] if len(self.model) == 1 else grad_sync_func + if self.cfg.optim.get("overlap_param_sync", False) and self.cfg.optim.get("delay_param_gather", False): + param_sync_func = [ + lambda x, model_index=model_index: self._optimizer.finish_param_sync(model_index, x) + for model_index in range(len(self.model)) + ] + param_sync_func = param_sync_func[0] if len(self.model) == 1 else param_sync_func for module in self.get_model_module_list(): module.config.no_sync_func = no_sync_func module.config.grad_sync_func = grad_sync_func module.config.param_sync_func = param_sync_func + if self.use_mcore_dist_optim: + module.config.finalize_model_grads_func = finalize_model_grads fwd_bwd_function = get_forward_backward_func() @@ -644,7 +665,7 @@ def inference_epoch_end(self, outputs, mode, data_cfg): app_state = AppState() self._restore_activation_checkpointing_args() if hasattr(self, "_train_ds"): - _reconfigure_microbatch_calculator( + reconfigure_num_microbatches_calculator( rank=app_state.global_rank, rampup_batch_size=None, global_batch_size=self.cfg.data.train_ds.global_batch_size, @@ -654,7 +675,7 @@ def inference_epoch_end(self, outputs, mode, data_cfg): # When 
running `trainer.validate()`, the training dataset is not available. else: logging.warning('No training data found, reconfiguring microbatches based on validation batch sizes.') - _reconfigure_microbatch_calculator( + reconfigure_num_microbatches_calculator( rank=app_state.global_rank, rampup_batch_size=None, global_batch_size=data_cfg.global_batch_size, @@ -690,7 +711,7 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int] response = generate(self, **inference_config) app_state = AppState() - _reconfigure_microbatch_calculator( + reconfigure_num_microbatches_calculator( rank=app_state.global_rank, rampup_batch_size=None, global_batch_size=global_batch_size_per_gpu * parallel_state.get_data_parallel_world_size(), @@ -784,7 +805,7 @@ def _reconfigure_and_process_inference_batch(self, batch, data_cfg): != data_cfg.global_batch_size // parallel_state.get_data_parallel_world_size() ): app_state = AppState() - _reconfigure_microbatch_calculator( + reconfigure_num_microbatches_calculator( rank=app_state.global_rank, rampup_batch_size=None, global_batch_size=global_batch_size_per_gpu * parallel_state.get_data_parallel_world_size(), @@ -794,7 +815,7 @@ def _reconfigure_and_process_inference_batch(self, batch, data_cfg): # NOTE: need to explicitly handle resetting for multi-validation else: app_state = AppState() - _reconfigure_microbatch_calculator( + reconfigure_num_microbatches_calculator( rank=app_state.global_rank, rampup_batch_size=None, global_batch_size=data_cfg.global_batch_size, @@ -878,7 +899,7 @@ def setup_eval_dataloader(self, datasets, data_cfg): def on_validation_epoch_start(self): self._reset_activation_checkpointing_args() app_state = AppState() - _reconfigure_microbatch_calculator( + reconfigure_num_microbatches_calculator( rank=app_state.global_rank, rampup_batch_size=None, global_batch_size=self.cfg.data.validation_ds.global_batch_size, @@ -890,7 +911,7 @@ def on_validation_epoch_start(self): def on_test_epoch_start(self): self._reset_activation_checkpointing_args() app_state = AppState() - _reconfigure_microbatch_calculator( + reconfigure_num_microbatches_calculator( rank=app_state.global_rank, rampup_batch_size=None, global_batch_size=self.cfg.data.test_ds.global_batch_size, diff --git a/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py b/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py index 6609b1aff303..7b92b9e25d69 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py @@ -36,7 +36,6 @@ MegatronTokenLevelEncoderDecoderModule, ) from nemo.collections.nlp.modules.common.megatron.utils import ( - ApexGuardDefaults, average_losses_across_data_parallel_group, build_attention_mask_3d, get_params_for_weight_decay_optimization, @@ -48,19 +47,6 @@ from nemo.collections.nlp.parts.utils_funcs import get_last_rank from nemo.utils import AppState, logging -try: - from apex.transformer.pipeline_parallel.utils import ( - _reconfigure_microbatch_calculator, - get_micro_batch_size, - get_num_microbatches, - ) - - HAVE_APEX = True - -except (ImportError, ModuleNotFoundError): - - HAVE_APEX = False - try: from megatron.core import parallel_state, tensor_parallel from megatron.core.enums import ModelType @@ -81,6 +67,20 @@ HAVE_MEGATRON_CORE = False +try: + from megatron.core.num_microbatches_calculator import ( + get_micro_batch_size, + get_num_microbatches, + 
reconfigure_num_microbatches_calculator, + ) + +except (ImportError, ModuleNotFoundError): + logging.warning("Megatron num_microbatches_calculator not found, using Apex version.") + from apex.transformer.pipeline_parallel.utils import ( + _reconfigure_microbatch_calculator as reconfigure_num_microbatches_calculator, + ) + from apex.transformer.pipeline_parallel.utils import get_micro_batch_size, get_num_microbatches + __all__ = ["MegatronLMEncoderDecoderModel"] @@ -268,7 +268,7 @@ def model_provider_func(self, pre_process, post_process, add_encoder, add_decode if parallel_state.get_pipeline_model_parallel_world_size() > 1 and self.cfg.encoder.arch == 'perceiver': raise ValueError(f"Perceivers with pipeline parallel > 1 is not supported yet.") - if hasattr(self, 'mcore_t5') and self.mcore_t5: + if getattr(self, 'mcore_t5', False): assert HAVE_MEGATRON_CORE, "Cannot use MCore T5 since Megatron Core is not found" assert self.cfg.get( 'share_token_embeddings', True @@ -286,8 +286,18 @@ def model_provider_func(self, pre_process, post_process, add_encoder, add_decode en_block_spec = enc_dec_spec_fns[0](self.cfg.encoder.num_layers) de_block_spec = enc_dec_spec_fns[1](self.cfg.decoder.num_layers) + + encoder_config = copy.deepcopy(self.transformer_config) + encoder_config.num_layers = self.cfg.encoder.num_layers + if self.cfg.pipeline_model_parallel_size > 1: + assert ( + self.cfg.pipeline_model_parallel_split_rank is not None + ), "Need to know how to shard the encoder & decoder." + encoder_config.pipeline_model_parallel_size = self.cfg.pipeline_model_parallel_split_rank + model = MCoreT5Model( config=self.transformer_config, + encoder_config=encoder_config, transformer_encoder_layer_spec=en_block_spec, transformer_decoder_layer_spec=de_block_spec, vocab_size=self.padded_vocab_size, @@ -439,8 +449,6 @@ def training_step(self, dataloader_iter): for module in modules: if isinstance(module, (Float16Module, MCoreFloat16Module)): module = module.module - if not self.mcore_t5: - module = module.language_model if hasattr(module, 'embedding'): for param in module.embedding.parameters(): param.data_ptr() @@ -799,9 +807,76 @@ def fwd_output_only_func(dataloader_iter, model): batch = next(dataloader_iter) batch = [x.cuda(non_blocking=True) if torch.is_tensor(x) else x for x in batch] - # map batch and shared args into forward args - args = self._build_forward_args_from_kwargs(args_name=arg_names, args=batch, **kwargs) - output = model(*args).contiguous() + # processing forward args for mcore T5 + if self.mcore_t5: + # when run encoding + if output_name == "hiddens": + ( + encoder_input_ids, + encoder_attn_mask, + ) = batch + + # attn mask logic follows megatron.data.t5_dataset.py in Megatron-LM + encoder_attn_mask_3d = build_attention_mask_3d( + encoder_attn_mask, encoder_attn_mask, AttnMaskType.padding + ) + + output = model( + encoder_input_ids=encoder_input_ids, + decoder_input_ids=None, + encoder_attn_mask=encoder_attn_mask_3d, + decoder_attn_mask=None, + encoder_decoder_attn_mask=None, + lm_labels=None, + encoder_hidden_states=None, + output_encoder_hidden_only=True, + ).contiguous() + + # when run decoding + elif output_name == "logits": + ( + encoder_hidden_states, + encoder_attn_mask, + decoder_input_ids, + decoder_attn_mask, + ) = batch + + # attn mask logic follows megatron.data.t5_dataset.py in Megatron-LM + encoder_attn_mask_3d = build_attention_mask_3d( + encoder_attn_mask, encoder_attn_mask, AttnMaskType.padding + ) + decoder_attn_mask_3d = build_attention_mask_3d( + decoder_attn_mask, 
decoder_attn_mask, AttnMaskType.causal + ) + enc_dec_attn_mask_3d = build_attention_mask_3d( + decoder_attn_mask, encoder_attn_mask, AttnMaskType.padding + ) + + # re-transpose encoder_hidden_states from [batch, seq_len, hidden] to [seq_len, batch, hidden] + encoder_hidden_states = encoder_hidden_states.transpose(1, 0) + + output = model( + encoder_input_ids=None, + decoder_input_ids=decoder_input_ids, + encoder_attn_mask=encoder_attn_mask_3d, + decoder_attn_mask=decoder_attn_mask_3d, + encoder_decoder_attn_mask=enc_dec_attn_mask_3d, + lm_labels=None, + encoder_hidden_states=encoder_hidden_states, + output_encoder_hidden_only=False, + ).contiguous() + + else: + assert output_name in [ + "hiddens", + "logits", + ], "output_name argument must be either 'hiddens' or 'logits'" + + else: + # map batch and shared args into forward args + args = self._build_forward_args_from_kwargs(args_name=arg_names, args=batch, **kwargs) + + output = model(*args).contiguous() def id_func(output_tensor): if isinstance(output_tensor, dict): @@ -1159,7 +1234,7 @@ def dummy(): # Reconfigure microbatch sizes here because on model restore, this will contain the micro/global batch configuration used while training. if reconfigure_microbatch: - _reconfigure_microbatch_calculator( + reconfigure_num_microbatches_calculator( rank=0, # This doesn't matter since it is only used for logging rampup_batch_size=None, global_batch_size=1, @@ -1180,7 +1255,7 @@ def dummy(): # Reconfigure microbatch calculator here to set num microbatches to 1 while decoding since its not clear how to decode with "grad acc". # reconfigure back to how things were before encode if reconfigure_microbatch: - _reconfigure_microbatch_calculator( + reconfigure_num_microbatches_calculator( rank=app_state.global_rank, rampup_batch_size=None, global_batch_size=global_batch_per_gpu * parallel_state.get_data_parallel_world_size(), @@ -1191,8 +1266,12 @@ def dummy(): # build input arguments description if tokens_enc is not None: - batch_for_pipeline = [tokens_enc, enc_mask, batch_data] - arg_names = ['enc_input_ids', 'enc_attn_mask', 'batch_data'] + if self.mcore_t5 is True: + batch_for_pipeline = [tokens_enc, enc_mask] + arg_names = [] + else: + batch_for_pipeline = [tokens_enc, enc_mask, batch_data] + arg_names = ['enc_input_ids', 'enc_attn_mask', 'batch_data'] else: if encoder_input is None: raise ValueError("At least one of tokens_enc and encoder_input must be provided with not None value") @@ -1204,10 +1283,12 @@ def dummy(): batch_for_pipeline.append(encoder_input) arg_names.append('enc_input') - forward_step_func = self._get_forward_output_only_func( - arg_names=arg_names, output_name="hiddens", output_enc_hidden_only=True - ) - + if self.mcore_t5: + forward_step_func = self._get_forward_output_only_func(arg_names=arg_names, output_name="hiddens") + else: + forward_step_func = self._get_forward_output_only_func( + arg_names=arg_names, output_name="hiddens", output_enc_hidden_only=True + ) fwd_bwd_func = get_forward_backward_func() # Counter intuitively, we need to set decoder_sequence_length=encoder_seq_length @@ -1244,7 +1325,7 @@ def dummy(): # Reset microbatch calculator to what it was before decoding. 
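
The decode path above builds three boolean 3D masks with `build_attention_mask_3d` before calling the MCore T5 model: a padding mask for encoder self-attention, a causal mask for decoder self-attention, and a padding mask for cross-attention (decoder positions as queries, encoder positions as keys). A minimal, self-contained sketch of the shapes involved is below; it is an illustration of the idea, not the NeMo helper itself, and it uses the convention that `True` means "may attend":

```python
import torch

def padding_mask_3d(q_pad: torch.Tensor, k_pad: torch.Tensor) -> torch.Tensor:
    # [b, sq] and [b, sk] boolean "is a real token" masks -> [b, sq, sk]
    return q_pad.unsqueeze(2) & k_pad.unsqueeze(1)

def causal_mask_3d(pad: torch.Tensor) -> torch.Tensor:
    # Padding mask combined with a lower-triangular (causal) constraint.
    _, s = pad.shape
    tril = torch.tril(torch.ones(s, s, dtype=torch.bool, device=pad.device))
    return padding_mask_3d(pad, pad) & tril

enc_pad = torch.tensor([[True, True, True, False]])  # [b=1, s_enc=4]
dec_pad = torch.tensor([[True, True, False]])        # [b=1, s_dec=3]

enc_self = padding_mask_3d(enc_pad, enc_pad)  # encoder self-attention,  [1, 4, 4]
dec_self = causal_mask_3d(dec_pad)            # decoder self-attention,  [1, 3, 3]
cross    = padding_mask_3d(dec_pad, enc_pad)  # encoder-decoder attention, [1, 3, 4]
```

The real helper may use the opposite convention (True = masked); the point here is only the `[batch, query_len, key_len]` shape and which mask pair feeds each attention type.
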
if reconfigure_microbatch: - _reconfigure_microbatch_calculator( + reconfigure_num_microbatches_calculator( rank=app_state.global_rank, rampup_batch_size=None, global_batch_size=global_batch_per_gpu * parallel_state.get_data_parallel_world_size(), @@ -1320,7 +1401,7 @@ def dummy(): self.trainer.strategy.setup_environment() # Reconfigure microbatch sizes here because on model restore, this will contain the micro/global batch configuration used while training. - _reconfigure_microbatch_calculator( + reconfigure_num_microbatches_calculator( rank=0, # This doesn't matter since it is only used for logging rampup_batch_size=None, global_batch_size=1, @@ -1348,7 +1429,7 @@ def dummy(): # Reconfigure microbatch calculator here to set num microbatches to 1 while decoding since its not clear how to decode with "grad acc". # reconfigure back to how things were before decode # TODO: Check if the user is trying to do gradient acc and maybe throw error - _reconfigure_microbatch_calculator( + reconfigure_num_microbatches_calculator( rank=app_state.global_rank, rampup_batch_size=None, global_batch_size=global_batch_per_gpu * parallel_state.get_data_parallel_world_size(), @@ -1382,8 +1463,12 @@ def dummy(): dec_mask = predicted_tokens_dec != tokenizer.pad_id dec_mask[:, 0] = 1 # Make sure you never mask the first token even if it is . - batch_for_pipeline = [enc_output, enc_output_attn_mask, predicted_tokens_dec, dec_mask, batch_data] - arg_names = ['enc_output', 'enc_output_attn_mask', 'dec_input_ids', 'dec_attn_mask', 'batch_data'] + if self.mcore_t5: + batch_for_pipeline = [enc_output, enc_output_attn_mask, predicted_tokens_dec, dec_mask] + arg_names = [] + else: + batch_for_pipeline = [enc_output, enc_output_attn_mask, predicted_tokens_dec, dec_mask, batch_data] + arg_names = ['enc_output', 'enc_output_attn_mask', 'dec_input_ids', 'dec_attn_mask', 'batch_data'] forward_step_func = self._get_forward_output_only_func(arg_names=arg_names, output_name="logits") fwd_bwd_func = get_forward_backward_func() @@ -1437,10 +1522,10 @@ def dummy(): pad_profile = torch.zeros_like(scores).long() decoder_seq_lengths = torch.zeros_like(scores).fill_(predicted_tokens_dec.size(1) + 1) - # reconfigure batch size for apex since the tensor have been augmented with beam size + # reconfigure batch size since the tensor have been augmented with beam size global_batch_per_gpu = token_ids.shape[0] tensor_shape[1] = global_batch_per_gpu - _reconfigure_microbatch_calculator( + reconfigure_num_microbatches_calculator( rank=app_state.global_rank, rampup_batch_size=None, global_batch_size=global_batch_per_gpu * parallel_state.get_data_parallel_world_size(), @@ -1531,7 +1616,7 @@ def dummy(): ) # Reset microbatch calculator to what it was before decoding. - _reconfigure_microbatch_calculator( + reconfigure_num_microbatches_calculator( rank=app_state.global_rank, rampup_batch_size=None, global_batch_size=global_batch_per_gpu * parallel_state.get_data_parallel_world_size(), @@ -1586,7 +1671,7 @@ def complete(self, request: Dict): app_state = AppState() # The complete method only works with global batch = micro batch size = data parallel size = 1. 
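
The same reconfiguration pattern recurs throughout this file and the other models touched by this diff: before decoding, the microbatch calculator is pointed at the batch actually being processed, with the micro batch size equal to the per-GPU batch so exactly one microbatch is produced and no gradient accumulation happens. A sketch of that pattern, factored into a helper for readability (`reconfigure_for_generation` is an illustrative name, not a NeMo API); the import fallback mirrors the one introduced in this diff:

```python
try:
    from megatron.core.num_microbatches_calculator import reconfigure_num_microbatches_calculator
except (ImportError, ModuleNotFoundError):
    # Older stacks: fall back to the Apex implementation under the new name.
    from apex.transformer.pipeline_parallel.utils import (
        _reconfigure_microbatch_calculator as reconfigure_num_microbatches_calculator,
    )

from megatron.core import parallel_state
from nemo.utils import AppState

def reconfigure_for_generation(per_gpu_batch: int) -> None:
    """Point the calculator at a single-microbatch configuration for decode/inference."""
    app_state = AppState()
    dp = parallel_state.get_data_parallel_world_size()
    reconfigure_num_microbatches_calculator(
        rank=app_state.global_rank,
        rampup_batch_size=None,
        global_batch_size=per_gpu_batch * dp,  # total batch across data-parallel ranks
        micro_batch_size=per_gpu_batch,        # one microbatch per rank -> no grad accumulation
        data_parallel_size=dp,
    )
```

Restoring the training-time configuration afterwards uses the same call with the train (or validation) global/micro batch sizes, as the surrounding hunks show.
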
- _reconfigure_microbatch_calculator( + reconfigure_num_microbatches_calculator( rank=app_state.global_rank, rampup_batch_size=None, global_batch_size=1, @@ -1724,6 +1809,9 @@ def on_load_checkpoint(self, checkpoint) -> None: # addressing the current T5 mcore version's implementation of sharded_state_dict checkpoint_state_dict['lm_head.output_layer.bias'] = checkpoint_state_dict['output_layer.bias'] + checkpoint_state_dict['position_embeddings.weight'] = checkpoint_state_dict[ + 'embedding.position_embeddings.weight' + ] module.load_state_dict(checkpoint_state_dict, strict=True) else: diff --git a/nemo/collections/nlp/models/language_modeling/megatron_mamba_model.py b/nemo/collections/nlp/models/language_modeling/megatron_mamba_model.py index 5180bd12b35e..54dff1cd7887 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_mamba_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_mamba_model.py @@ -44,8 +44,6 @@ def model_provider_func(self, pre_process, post_process): self.transformer_config.gated_linear_unit = self.cfg.get('gated_linear_unit', False) self.transformer_config.layernorm_epsilon = self.cfg.get('layernorm_epsilon', 1e-5) - # TODO @ataghibakhsh: add mamba_ssm_ngroups=self.cfg.get('mamba_ssm_ngroups', 8) once MLM MR merged - model = MambaModel( config=self.transformer_config, max_sequence_length=self.cfg.get('encoder_seq_length', 4096), @@ -64,10 +62,6 @@ def forward(self, input_ids, position_ids=None, attention_mask=None, labels=None ) return output_tensor - def build_transformer_config(self): - transformer_config = super().build_transformer_config() - return transformer_config - def on_validation_epoch_end(self): averaged_loss = torch.tensor(0.0, dtype=torch.float32).cuda() diff --git a/nemo/collections/nlp/models/language_modeling/megatron_retro_model.py b/nemo/collections/nlp/models/language_modeling/megatron_retro_model.py index 2a8e5713573b..6dfe022d0275 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_retro_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_retro_model.py @@ -67,16 +67,6 @@ from nemo.core.neural_types import ChannelType, NeuralType from nemo.utils import logging -try: - import apex.transformer.pipeline_parallel.utils - from apex.transformer.pipeline_parallel.utils import get_num_microbatches - - HAVE_APEX = True - -except (ImportError, ModuleNotFoundError): - - HAVE_APEX = False - try: from megatron.core import InferenceParams, parallel_state from megatron.core.models.retro import RetroModel as MCoreRetroModel @@ -97,9 +87,17 @@ except (ImportError, ModuleNotFoundError): TransformerConfig = ApexGuardDefaults + RetroConfig = ApexGuardDefaults HAVE_MEGATRON_CORE = False +try: + from megatron.core.num_microbatches_calculator import get_num_microbatches + +except (ImportError, ModuleNotFoundError): + logging.warning("Megatron num_microbatches_calculator not found, using Apex version.") + from apex.transformer.pipeline_parallel.utils import get_num_microbatches + try: import transformer_engine from transformer_engine.pytorch import module as te_module diff --git a/nemo/collections/nlp/models/language_modeling/megatron_t0_model.py b/nemo/collections/nlp/models/language_modeling/megatron_t0_model.py index 4d4d80b71a98..cee1b11a160b 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_t0_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_t0_model.py @@ -27,23 +27,22 @@ from nemo.utils import AppState, logging try: - from apex.transformer.pipeline_parallel.utils 
import _reconfigure_microbatch_calculator + from megatron.core import parallel_state - HAVE_APEX = True + HAVE_MEGATRON_CORE = True except (ImportError, ModuleNotFoundError): - HAVE_APEX = False - + HAVE_MEGATRON_CORE = False try: - from megatron.core import parallel_state - - HAVE_MEGATRON_CORE = True + from megatron.core.num_microbatches_calculator import reconfigure_num_microbatches_calculator except (ImportError, ModuleNotFoundError): - - HAVE_MEGATRON_CORE = False + logging.warning("Megatron num_microbatches_calculator not found, using Apex version.") + from apex.transformer.pipeline_parallel.utils import ( + _reconfigure_microbatch_calculator as reconfigure_num_microbatches_calculator, + ) __all__ = ['MegatronT0Model'] @@ -162,7 +161,7 @@ def _reconfigure_and_process_inference_batch(self, batch): # This should happen only on the last batch of the validation/test dataset with drop_last=False. if global_batch_per_gpu != self.cfg.data.validation_ds.global_batch_size: app_state = AppState() - _reconfigure_microbatch_calculator( + reconfigure_num_microbatches_calculator( rank=app_state.global_rank, rampup_batch_size=None, global_batch_size=global_batch_per_gpu * parallel_state.get_data_parallel_world_size(), @@ -194,7 +193,10 @@ def build_train_valid_test_datasets(self, stage): logging.info(f'Length of train dataset: {len(self._train_ds)}') def build_data_loader( - self, dataset, data_cfg, consumed_samples=0, + self, + dataset, + data_cfg, + consumed_samples=0, ): """Buld dataloader given an input dataset.""" logging.info(f'Building dataloader with consumed samples: {consumed_samples}') @@ -224,13 +226,19 @@ def setup_training_dataloader(self): if hasattr(self, '_train_ds'): consumed_samples = self.compute_consumed_samples(0) self._train_dl = self.build_data_loader( - dataset=self._train_ds, data_cfg=self.cfg.data.train_ds, consumed_samples=consumed_samples, + dataset=self._train_ds, + data_cfg=self.cfg.data.train_ds, + consumed_samples=consumed_samples, ) def setup_eval_dataloader(self, datasets, data_cfg): dataloaders = [] for dataset in datasets: - eval_dl = self.build_data_loader(dataset=dataset, data_cfg=data_cfg, consumed_samples=0,) + eval_dl = self.build_data_loader( + dataset=dataset, + data_cfg=data_cfg, + consumed_samples=0, + ) dataloaders.append(eval_dl) return dataloaders diff --git a/nemo/collections/nlp/models/language_modeling/megatron_t5_prompt_learning_model.py b/nemo/collections/nlp/models/language_modeling/megatron_t5_prompt_learning_model.py index f13be45db836..1f54cb87428e 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_t5_prompt_learning_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_t5_prompt_learning_model.py @@ -35,19 +35,6 @@ from nemo.collections.nlp.parts.utils_funcs import get_last_rank from nemo.utils import AppState, logging -try: - from apex.transformer.pipeline_parallel.utils import ( - _reconfigure_microbatch_calculator, - get_micro_batch_size, - get_num_microbatches, - ) - - HAVE_APEX = True - -except (ImportError, ModuleNotFoundError): - - HAVE_APEX = False - try: from megatron.core import parallel_state from megatron.core.enums import ModelType @@ -60,26 +47,34 @@ HAVE_MEGATRON_CORE = False +try: + from megatron.core.num_microbatches_calculator import get_micro_batch_size, get_num_microbatches + +except (ImportError, ModuleNotFoundError): + logging.warning("Megatron num_microbatches_calculator not found, using Apex version.") + from apex.transformer.pipeline_parallel.utils import get_micro_batch_size, 
get_num_microbatches + + __all__ = ['MegatronT5PromptLearningModel'] class MegatronT5PromptLearningModel(MegatronBasePromptLearningModel): """ - Model class for prompt-tuning or p-tuning a pretrained Megatron T5 model. + Model class for prompt-tuning or p-tuning a pretrained Megatron T5 model. Prompt Tuning initalizes virtual prompt embeddings directly from a copy of certain token embeddings from the the pretrained T5 model's vocabulary - and directly tunes these embedding weights. The token embeddings used in - initalization are specified by the user in the config file. The model can - be prompt-tuned for multiple tasks at once. Virtual prompts are stored in a - prompt table and can be added or deleted without disrupting virtual prompts - for other tasks. + and directly tunes these embedding weights. The token embeddings used in + initalization are specified by the user in the config file. The model can + be prompt-tuned for multiple tasks at once. Virtual prompts are stored in a + prompt table and can be added or deleted without disrupting virtual prompts + for other tasks. P-tuning initializes an LSTM encoder model that generates virtual prompt embeddings for every task. Each task shares the same encoder. After p-tuning is compelete, the learned virtual prompts can be saved to the prompt table - using add_ptuned_prompts_to_prompt_table(). Thus, if a user wants to add a - new virtual prompt via p-tuning, they do not need to retrain on all previous + using add_ptuned_prompts_to_prompt_table(). Thus, if a user wants to add a + new virtual prompt via p-tuning, they do not need to retrain on all previous tasks. This gives p-tuning the same task flexiblity as prompt-tuning. """ @@ -93,7 +88,15 @@ def first_stage_of_pipeline(self): return False def forward( - self, input_ids, dec_input, enc_mask, dec_mask, position_ids, taskname_ids, labels=None, inference=False, + self, + input_ids, + dec_input, + enc_mask, + dec_mask, + position_ids, + taskname_ids, + labels=None, + inference=False, ): """ Special forward method for p-tuning/prompt-tuning pretrained @@ -174,8 +177,8 @@ def load_frozen_model(self, cfg, trainer): def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only): """ - Dataloader produces a global batch which is turned into a list of microbatches. - The list of microbatches is then piped through the pipeline using megatron-core fwd/bwd functions. + Dataloader produces a global batch which is turned into a list of microbatches. + The list of microbatches is then piped through the pipeline using megatron-core fwd/bwd functions. """ # Get seq length of batch batch = next(dataloader_iter) @@ -230,15 +233,15 @@ def loss_func(output_tensor): return fwd_output_and_loss_func def backward(self, *args, **kwargs): - """ LightningModule hook to do backward. - We want this to do nothing since we run backward in the fwd/bwd functions from megatron-core. - No need to call it here. + """LightningModule hook to do backward. + We want this to do nothing since we run backward in the fwd/bwd functions from megatron-core. + No need to call it here. """ return def optimizer_zero_grad(self, *args, **kwargs): - """ LightningModule hook to zero grad. - We want this to do nothing as we are zeroing grads during the training_step. + """LightningModule hook to zero grad. + We want this to do nothing as we are zeroing grads during the training_step. 
""" return @@ -291,9 +294,9 @@ def get_predictions(self, input_ids, enc_mask, encoder_input, labels): enc_mask=enc_mask, num_tokens_to_generate=self.decoder_seq_length, encoder_input=encoder_input, - bos_id=self.tokenizer.pad_id - if self.cfg.data.get('decoder_starts_with_pad', False) - else self.tokenizer.bos_id, + bos_id=( + self.tokenizer.pad_id if self.cfg.data.get('decoder_starts_with_pad', False) else self.tokenizer.bos_id + ), ) # Special ids to text function to handle stripping and special tokens with sentencepiece tokenizers. preds_text = MegatronT5SFTModel.ids_to_text(predicted_token_ids, self.tokenizer) @@ -385,7 +388,8 @@ def on_validation_epoch_end(self): gather_results_dedup = list(set(itertools.chain(*gather_results))) val_metric_dict = self.validation_metric.get_score( - [i[2] for i in gather_results_dedup], [i[1] for i in gather_results_dedup], + [i[2] for i in gather_results_dedup], + [i[1] for i in gather_results_dedup], ) for metric, val in val_metric_dict.items(): @@ -445,9 +449,9 @@ def build_virtual_prompt_dataset( drop_last=drop_last, num_workers=num_workers, pin_memory=pin_memory, - persistent_workers=True - if num_workers > 0 - else False, # (@adithyare and @eharper) We need to set this to True to get around issues with spawn=True + persistent_workers=( + True if num_workers > 0 else False + ), # (@adithyare and @eharper) We need to set this to True to get around issues with spawn=True ) print('build success', len(dataloader), dataset_paths) return dataset, dataloader @@ -477,9 +481,9 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> A enc_mask=enc_mask, num_tokens_to_generate=self.decoder_seq_length, encoder_input=encoder_input, - bos_id=self.tokenizer.pad_id - if self.cfg.data.get('decoder_starts_with_pad', False) - else self.tokenizer.bos_id, + bos_id=( + self.tokenizer.pad_id if self.cfg.data.get('decoder_starts_with_pad', False) else self.tokenizer.bos_id + ), ) # Special ids to text function to handle stripping and special tokens with sentencepiece tokenizers. 
preds_text = MegatronT5SFTModel.ids_to_text(predicted_token_ids, self.tokenizer) @@ -522,7 +526,7 @@ def on_predict_epoch_end(self) -> None: input_prediction_pair = [] correct = 0 - for (input, pred, label) in gather_results_dedup: + for input, pred, label in gather_results_dedup: input_prediction_pair.append((input, pred)) if label: if pred == label: diff --git a/nemo/collections/nlp/models/language_modeling/megatron_t5_sft_model.py b/nemo/collections/nlp/models/language_modeling/megatron_t5_sft_model.py index 2344dac3a64a..c70f44925d33 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_t5_sft_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_t5_sft_model.py @@ -31,18 +31,6 @@ from nemo.collections.nlp.parts.utils_funcs import get_last_rank from nemo.utils import AppState, logging -try: - from apex.transformer.pipeline_parallel.utils import ( - _reconfigure_microbatch_calculator, - get_current_global_batch_size, - get_micro_batch_size, - get_num_microbatches, - ) - - HAVE_APEX = True -except (ImportError, ModuleNotFoundError): - HAVE_APEX = False - try: from megatron.core import parallel_state from megatron.core.pipeline_parallel.schedules import get_forward_backward_func @@ -53,17 +41,32 @@ HAVE_MEGATRON_CORE = False +try: + from megatron.core.num_microbatches_calculator import ( + get_current_global_batch_size, + get_micro_batch_size, + get_num_microbatches, + reconfigure_num_microbatches_calculator, + ) + +except (ImportError, ModuleNotFoundError): + logging.warning("Megatron num_microbatches_calculator not found, using Apex version.") + from apex.transformer.pipeline_parallel.utils import ( + _reconfigure_microbatch_calculator as reconfigure_num_microbatches_calculator, + ) + from apex.transformer.pipeline_parallel.utils import ( + get_current_global_batch_size, + get_micro_batch_size, + get_num_microbatches, + ) + __all__ = ['MegatronT5SFTModel'] class MegatronT5SFTModel(NLPAdapterModelMixin, MegatronT5Model): - """ T5 Finetuning model in the same format as MegatronGPTSFTModel """ + """T5 Finetuning model in the same format as MegatronGPTSFTModel""" def __init__(self, cfg: DictConfig, trainer: Trainer): - if not HAVE_APEX: - raise ImportError( - "Apex was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt." 
- ) super().__init__(cfg, trainer=trainer) self.val_metric = self.test_metric = None if hasattr(self.cfg.data, "validation_ds"): @@ -176,7 +179,7 @@ def setup(self, stage=None): def on_validation_epoch_start(self): app_state = AppState() - _reconfigure_microbatch_calculator( + reconfigure_num_microbatches_calculator( rank=app_state.global_rank, rampup_batch_size=None, global_batch_size=self.cfg.data.validation_ds.global_batch_size, @@ -187,7 +190,7 @@ def on_validation_epoch_start(self): def on_test_epoch_start(self): app_state = AppState() - _reconfigure_microbatch_calculator( + reconfigure_num_microbatches_calculator( rank=app_state.global_rank, rampup_batch_size=None, global_batch_size=self.cfg.data.test_ds.global_batch_size, @@ -270,7 +273,7 @@ def _reconfigure_and_process_inference_batch(self, batch, ds_config): != ds_config.global_batch_size // parallel_state.get_data_parallel_world_size() ): app_state = AppState() - _reconfigure_microbatch_calculator( + reconfigure_num_microbatches_calculator( rank=app_state.global_rank, rampup_batch_size=None, global_batch_size=global_batch_size_per_gpu * parallel_state.get_data_parallel_world_size(), @@ -280,7 +283,7 @@ def _reconfigure_and_process_inference_batch(self, batch, ds_config): # NOTE: need to explicitly handle resetting for multi-validation else: app_state = AppState() - _reconfigure_microbatch_calculator( + reconfigure_num_microbatches_calculator( rank=app_state.global_rank, rampup_batch_size=None, global_batch_size=ds_config.global_batch_size, @@ -290,8 +293,8 @@ def _reconfigure_and_process_inference_batch(self, batch, ds_config): def fwd_bwd_step(self, dataloader_iter, forward_only): """ - Dataloader produces a global batch which is turned into a list of microbatches. - The list of microbatches is then piped through the pipeline using Apex fwd/bwd functions. + Dataloader produces a global batch which is turned into a list of microbatches. + The list of microbatches is then piped through the pipeline using Apex fwd/bwd functions. """ # If tuple, 1st element in it is the batch since dataloader_iter returns batch, batch_idx, dataloader_idx batch = next(dataloader_iter) @@ -562,7 +565,7 @@ def inference_epoch_end(self, outputs, mode, data_cfg): app_state = AppState() if hasattr(self, "_train_ds"): - _reconfigure_microbatch_calculator( + reconfigure_num_microbatches_calculator( rank=app_state.global_rank, rampup_batch_size=None, global_batch_size=self.cfg.data.train_ds.global_batch_size, @@ -572,7 +575,7 @@ def inference_epoch_end(self, outputs, mode, data_cfg): # When running `trainer.validate()`, the training dataset is not available. 
else: logging.warning('No training data found, reconfiguring microbatches based on validation batch sizes.') - _reconfigure_microbatch_calculator( + reconfigure_num_microbatches_calculator( rank=app_state.global_rank, rampup_batch_size=None, global_batch_size=data_cfg.global_batch_size, @@ -605,7 +608,13 @@ def on_test_epoch_end(self): # return super().on_test_epoch_end() def build_data_loader( - self, dataset, global_batch_size, shuffle, num_workers, pin_memory, drop_last, + self, + dataset, + global_batch_size, + shuffle, + num_workers, + pin_memory, + drop_last, ): """Buld dataloader given an input dataset.""" @@ -652,9 +661,11 @@ def setup_eval_data(self, datasets, data_cfg): for dataset in datasets: eval_dl = self.build_data_loader( dataset, - global_batch_size=self.cfg.data.test_ds.global_batch_size - if hasattr(self.cfg.data, "test_ds") - else self.cfg.data.validation_ds.global_batch_size, + global_batch_size=( + self.cfg.data.test_ds.global_batch_size + if hasattr(self.cfg.data, "test_ds") + else self.cfg.data.validation_ds.global_batch_size + ), shuffle=data_cfg.shuffle, num_workers=data_cfg.num_workers, pin_memory=data_cfg.pin_memory, diff --git a/nemo/collections/nlp/models/machine_translation/megatron_nmt_model.py b/nemo/collections/nlp/models/machine_translation/megatron_nmt_model.py index 5a41682a4b5b..4461b417f311 100644 --- a/nemo/collections/nlp/models/machine_translation/megatron_nmt_model.py +++ b/nemo/collections/nlp/models/machine_translation/megatron_nmt_model.py @@ -53,27 +53,27 @@ from nemo.utils import AppState, logging, timers try: - from apex.transformer.pipeline_parallel.utils import ( - _reconfigure_microbatch_calculator, - get_micro_batch_size, - get_num_microbatches, - ) + from megatron.core import parallel_state - HAVE_APEX = True + HAVE_MEGATRON_CORE = True except (ImportError, ModuleNotFoundError): - HAVE_APEX = False + HAVE_MEGATRON_CORE = False try: - from megatron.core import parallel_state - from megatron.core.pipeline_parallel.schedules import get_forward_backward_func - - HAVE_MEGATRON_CORE = True + from megatron.core.num_microbatches_calculator import ( + get_micro_batch_size, + get_num_microbatches, + reconfigure_num_microbatches_calculator, + ) except (ImportError, ModuleNotFoundError): - - HAVE_MEGATRON_CORE = False + logging.warning("Megatron num_microbatches_calculator not found, using Apex version.") + from apex.transformer.pipeline_parallel.utils import ( + _reconfigure_microbatch_calculator as reconfigure_num_microbatches_calculator, + ) + from apex.transformer.pipeline_parallel.utils import get_micro_batch_size, get_num_microbatches __all__ = ["MegatronNMTModel"] @@ -210,17 +210,21 @@ def _build_tokenizer(self): self.encoder_tokenizer, self.decoder_tokenizer = MTEncDecModel.setup_enc_dec_tokenizers( encoder_tokenizer_library=self.encoder_tokenizer_library, encoder_tokenizer_model=encoder_tokenizer_model, - encoder_bpe_dropout=self._cfg.encoder_tokenizer.get('bpe_dropout', 0.0) - if self._cfg.encoder_tokenizer.get('bpe_dropout', 0.0) is not None - else 0.0, + encoder_bpe_dropout=( + self._cfg.encoder_tokenizer.get('bpe_dropout', 0.0) + if self._cfg.encoder_tokenizer.get('bpe_dropout', 0.0) is not None + else 0.0 + ), encoder_model_name=self._cfg.encoder_tokenizer.get('type', None), encoder_r2l=self._cfg.encoder_tokenizer.get('r2l', False), decoder_tokenizer_library=self.decoder_tokenizer_library, encoder_tokenizer_vocab_file=self._cfg.encoder_tokenizer.get('vocab_file', None), decoder_tokenizer_model=decoder_tokenizer_model, - 
decoder_bpe_dropout=self._cfg.decoder_tokenizer.get('bpe_dropout', 0.0) - if self._cfg.decoder_tokenizer.get('bpe_dropout', 0.0) is not None - else 0.0, + decoder_bpe_dropout=( + self._cfg.decoder_tokenizer.get('bpe_dropout', 0.0) + if self._cfg.decoder_tokenizer.get('bpe_dropout', 0.0) is not None + else 0.0 + ), decoder_model_name=self._cfg.encoder_tokenizer.get('type', None), decoder_r2l=self._cfg.decoder_tokenizer.get('r2l', False), encoder_sentencepiece_legacy=self._cfg.encoder_tokenizer.get('sentencepiece_legacy', False), @@ -252,10 +256,14 @@ def _build_vocab(self): f"NMT-XLM objective requires sentencepiece tokenizer, but got decoder tokenizer library : {self.cfg.decoder_tokenizer.library}" ) MegatronT5Model.add_special_tokens_to_tokenizer( - tokenizer=self.encoder_tokenizer, tokenizer_cfg=self.cfg.encoder_tokenizer, dataset_type='ul2', + tokenizer=self.encoder_tokenizer, + tokenizer_cfg=self.cfg.encoder_tokenizer, + dataset_type='ul2', ) MegatronT5Model.add_special_tokens_to_tokenizer( - tokenizer=self.decoder_tokenizer, tokenizer_cfg=self.cfg.decoder_tokenizer, dataset_type='ul2', + tokenizer=self.decoder_tokenizer, + tokenizer_cfg=self.cfg.decoder_tokenizer, + dataset_type='ul2', ) # Set up pre and post processors as well. @@ -277,7 +285,10 @@ def _build_vocab(self): else: # After this call, the model will have self.source_processor and self.target_processor objects self.source_processor, self.target_processor = MTEncDecModel.setup_pre_and_post_processing_utils( - self.src_language, self.tgt_language, self.encoder_tokenizer_library, self.decoder_tokenizer_library, + self.src_language, + self.tgt_language, + self.encoder_tokenizer_library, + self.decoder_tokenizer_library, ) self.multilingual_ids = [None] @@ -289,8 +300,8 @@ def _build_vocab(self): def fwd_bwd_step(self, dataloader_iter, forward_only): """ - Dataloader produces a global batch which is turned into a list of microbatches. - The list of microbatches is then piped through the pipeline using Apex fwd/bwd functions. + Dataloader produces a global batch which is turned into a list of microbatches. + The list of microbatches is then piped through the pipeline using Apex fwd/bwd functions. """ # If tuple, 1st element in it is the batch since dataloader_iter returns batch, batch_idx, dataloader_idx batch = next(dataloader_iter) @@ -322,7 +333,7 @@ def eval_step(self, dataloader_iter): # Eval step requires text datasets so we need to reconfigure MBS on each batch. app_state = AppState() - _reconfigure_microbatch_calculator( + reconfigure_num_microbatches_calculator( rank=app_state.global_rank, rampup_batch_size=None, global_batch_size=batch['text_enc'].size(0) * parallel_state.get_data_parallel_world_size(), @@ -351,13 +362,19 @@ def eval_step(self, dataloader_iter): # Post-process the translations and inputs to log. 
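
The per-batch reconfiguration in `eval_step` above (and in the `_reconfigure_and_process_inference_batch` methods earlier in this diff) exists because eval dataloaders are typically built with `drop_last=False`, so the final batch can be smaller than the configured global batch size. A worked example of the arithmetic, with purely illustrative numbers:

```python
data_parallel_size = 2
configured_global_batch = 16
per_gpu = configured_global_batch // data_parallel_size   # 8 samples per rank per step

samples_per_rank = 19                        # drop_last=False leaves a tail batch
tail = samples_per_rank % per_gpu            # 3 samples in the last batch on each rank

# For that tail batch the calculator is re-pointed at the real size:
global_batch_for_tail = tail * data_parallel_size          # 6
micro_batch_for_tail = tail                                # one microbatch per rank
num_microbatches = global_batch_for_tail // (micro_batch_for_tail * data_parallel_size)
assert num_microbatches == 1                               # no accumulation during eval
```
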
preds = self.postprocess_outputs( - outputs=predicted_tokens_ids, tokenizer=self.decoder_tokenizer, processor=target_processor, + outputs=predicted_tokens_ids, + tokenizer=self.decoder_tokenizer, + processor=target_processor, ) labels = self.postprocess_outputs( - outputs=labels, tokenizer=self.decoder_tokenizer, processor=target_processor, + outputs=labels, + tokenizer=self.decoder_tokenizer, + processor=target_processor, ) encoder_inputs = self.postprocess_outputs( - outputs=tokens_enc, tokenizer=self.encoder_tokenizer, processor=source_processor, + outputs=tokens_enc, + tokenizer=self.encoder_tokenizer, + processor=source_processor, ) loss_dict = { @@ -537,7 +554,7 @@ def eval_epoch_end(self, outputs, mode): app_state = AppState() if hasattr(self, "_train_ds"): - _reconfigure_microbatch_calculator( + reconfigure_num_microbatches_calculator( rank=app_state.global_rank, rampup_batch_size=None, global_batch_size=self._cfg.train_ds.global_batch_size, @@ -781,12 +798,12 @@ def build_memmap_dataset_from_config(self, cfg: DictConfig): tgt_file=tgt_file, num_samples=num_samples, prepend_id=multilingual_ids[idx], - src_language=self.src_language - if not isinstance(self.src_language, ListConfig) - else self.src_language[idx], - tgt_language=self.tgt_language - if not isinstance(self.tgt_language, ListConfig) - else self.tgt_language[idx], + src_language=( + self.src_language if not isinstance(self.src_language, ListConfig) else self.src_language[idx] + ), + tgt_language=( + self.tgt_language if not isinstance(self.tgt_language, ListConfig) else self.tgt_language[idx] + ), ) datasets.append(dataset) dataset = BlendableDataset( @@ -808,7 +825,7 @@ def list_available_models(self): def on_validation_epoch_start(self): app_state = AppState() - _reconfigure_microbatch_calculator( + reconfigure_num_microbatches_calculator( rank=app_state.global_rank, rampup_batch_size=None, global_batch_size=parallel_state.get_data_parallel_world_size(), diff --git a/nemo/collections/nlp/modules/common/huggingface/huggingface_utils.py b/nemo/collections/nlp/modules/common/huggingface/huggingface_utils.py index cf692e07749d..d8f6936f7126 100644 --- a/nemo/collections/nlp/modules/common/huggingface/huggingface_utils.py +++ b/nemo/collections/nlp/modules/common/huggingface/huggingface_utils.py @@ -16,12 +16,6 @@ from typing import List, Optional from transformers import ( - ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST, - BERT_PRETRAINED_MODEL_ARCHIVE_LIST, - CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST, - DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST, - GPT2_PRETRAINED_MODEL_ARCHIVE_LIST, - ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST, AlbertConfig, AutoModel, BertConfig, @@ -41,6 +35,74 @@ __all__ = ["get_huggingface_lm_model", "get_huggingface_pretrained_lm_models_list", "VOCAB_FILE_NAME"] +# Manually specify the model archive lists since these are now removed in HF +# https://github.com/huggingface/transformers/blob/v4.40-release/src/transformers/models/deprecated/_archive_maps.py +ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "albert/albert-base-v1", + "albert/albert-large-v1", + "albert/albert-xlarge-v1", + "albert/albert-xxlarge-v1", + "albert/albert-base-v2", + "albert/albert-large-v2", + "albert/albert-xlarge-v2", + "albert/albert-xxlarge-v2", +] + +BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "google-bert/bert-base-uncased", + "google-bert/bert-large-uncased", + "google-bert/bert-base-cased", + "google-bert/bert-large-cased", + "google-bert/bert-base-multilingual-uncased", + "google-bert/bert-base-multilingual-cased", + 
"google-bert/bert-base-chinese", + "google-bert/bert-base-german-cased", + "google-bert/bert-large-uncased-whole-word-masking", + "google-bert/bert-large-cased-whole-word-masking", + "google-bert/bert-large-uncased-whole-word-masking-finetuned-squad", + "google-bert/bert-large-cased-whole-word-masking-finetuned-squad", + "google-bert/bert-base-cased-finetuned-mrpc", + "google-bert/bert-base-german-dbmdz-cased", + "google-bert/bert-base-german-dbmdz-uncased", + "cl-tohoku/bert-base-japanese", + "cl-tohoku/bert-base-japanese-whole-word-masking", + "cl-tohoku/bert-base-japanese-char", + "cl-tohoku/bert-base-japanese-char-whole-word-masking", + "TurkuNLP/bert-base-finnish-cased-v1", + "TurkuNLP/bert-base-finnish-uncased-v1", + "wietsedv/bert-base-dutch-cased", +] +CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "almanach/camembert-base", + "Musixmatch/umberto-commoncrawl-cased-v1", + "Musixmatch/umberto-wikipedia-uncased-v1", +] + +DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "distilbert-base-uncased", + "distilbert-base-uncased-distilled-squad", + "distilbert-base-cased", + "distilbert-base-cased-distilled-squad", + "distilbert-base-german-cased", + "distilbert-base-multilingual-cased", + "distilbert-base-uncased-finetuned-sst-2-english", +] +GPT2_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "openai-community/gpt2", + "openai-community/gpt2-medium", + "openai-community/gpt2-large", + "openai-community/gpt2-xl", + "distilbert/distilgpt2", +] +ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "FacebookAI/roberta-base", + "FacebookAI/roberta-large", + "FacebookAI/roberta-large-mnli", + "distilbert/distilroberta-base", + "openai-community/roberta-base-openai-detector", + "openai-community/roberta-large-openai-detector", +] + HUGGINGFACE_MODELS = { "BertModel": { @@ -94,7 +156,9 @@ def get_huggingface_lm_model( - pretrained_model_name: str, config_dict: Optional[dict] = None, config_file: Optional[str] = None, + pretrained_model_name: str, + config_dict: Optional[dict] = None, + config_file: Optional[str] = None, ): """ Returns lm model instantiated with Huggingface @@ -135,7 +199,9 @@ def get_huggingface_lm_model( raise ValueError(f"Use HuggingFace API directly in NeMo for {pretrained_model_name}") -def get_huggingface_pretrained_lm_models_list(include_external: bool = False,) -> List[str]: +def get_huggingface_pretrained_lm_models_list( + include_external: bool = False, +) -> List[str]: """ Returns the list of pretrained HuggingFace language models diff --git a/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py b/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py index 2f00f5907ad8..48b6afa788ae 100644 --- a/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py +++ b/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py @@ -14,17 +14,21 @@ import torch import torch.nn.functional as F +from megatron.core import InferenceParams from megatron.core.fusions.fused_bias_geglu import bias_geglu_impl from megatron.core.fusions.fused_bias_gelu import bias_gelu_impl from megatron.core.fusions.fused_bias_swiglu import bias_swiglu_impl from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding from megatron.core.models.common.embeddings.rotary_pos_embedding import apply_rotary_pos_emb +from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.transformer.attention import SelfAttention from megatron.core.transformer.custom_layers.transformer_engine import SplitAlongDim from 
megatron.core.transformer.mlp import MLP from megatron.core.transformer.moe.experts import SequentialMLP +from megatron.core.transformer.transformer_block import TransformerBlock from megatron.core.transformer.transformer_layer import TransformerLayer from megatron.core.utils import make_viewless_tensor +from torch import Tensor from nemo.collections.nlp.modules.common.megatron.adapters.parallel_adapters import ( AdapterName, @@ -37,6 +41,7 @@ LoraMoeHto4HAdapterConfig, LoraUnfusedHto4HAdapterConfig, LoraUnfusedKQVAdapterConfig, + MLPHeadAdapterConfig, MLPInfusedAdapterConfig, ParallelLinearAdapterConfig, PromptEncoderAdapterConfig, @@ -61,6 +66,34 @@ def mcore_register_adapters(self): raise NotImplementedError("Mcore mixins should implement setup_adapters on a subclass of MyBase") +class MCoreTransformerBlockMixin(TransformerBlock, MCoreAdapterModuleMixin): + def mcore_register_adapters(self): + """ + Setup NeMo (canonical) Adapter to this MCore layer. + """ + self.set_accepted_adapter_types([MLPHeadAdapterConfig._target_]) + + def forward( + self, + hidden_states: Tensor, + attention_mask: Tensor, + context: Tensor = None, + context_mask: Tensor = None, + rotary_pos_emb: Tensor = None, + inference_params: InferenceParams = None, + packed_seq_params: PackedSeqParams = None, + ): + hidden_states = super().forward( + hidden_states, attention_mask, context, context_mask, rotary_pos_emb, inference_params, packed_seq_params + ) + + mlp_head_adapter = self.get_adapter_module(AdapterName.MLP_HEAD_ADAPTER) + if mlp_head_adapter and self.adapter_cfg[AdapterName.MLP_HEAD_ADAPTER]['enabled']: + hidden_states = mlp_head_adapter(hidden_states) + + return hidden_states + + class MCoreSelfAttentionMixin(SelfAttention, MCoreAdapterModuleMixin): def mcore_register_adapters(self): """ diff --git a/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py b/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py index 9ab1da7136a1..7167eefda637 100644 --- a/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py +++ b/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py @@ -77,6 +77,7 @@ class AdapterName(str, enum.Enum): PTUNING_ADAPTER = "ptuning_adapter" LORA_KQV_ADAPTER = "lora_kqv_adapter" LORA_UNFUSED_KQV_ADAPTER = "lora_unfused_kqv_adapter" + MLP_HEAD_ADAPTER = "mlp_head_adapter" LORA_KV_ADAPTER = "lora_kv_adapter" LORA_Q_ADAPTER = "lora_q_adapter" MM_LINEAR_ADAPTER = "mm_linear_adapter" @@ -256,7 +257,7 @@ def __init__( te_version = packaging.version.Version(version("transformer-engine")) if te_version >= packaging.version.Version("1.5.0dev") and ( - not self.input_is_parallel and model_parallel_config.tp_comm_disable_qkv + not self.input_is_parallel and getattr(model_parallel_config, "tp_comm_overlap_disable_qkv", False) ): # TE 1.5 introduces the option `return_layernorm_output_gathered`, so the all gather # in the forward method is not needed, so set self._sequence_parallel to False @@ -388,6 +389,57 @@ class ParallelLinearAdapterConfig(AdapterConfig): _target_: str = "{0}.{1}".format(ParallelLinearAdapter.__module__, ParallelLinearAdapter.__name__) +class MLPHeadAdapter(nn.Module, AdapterModuleUtil): + def __init__( + self, + in_features: int, + out_features: int, + input_is_parallel: bool = False, + model_parallel_config: Optional[ModelParallelConfig] = None, + **kwargs, + ): + super().__init__() + if model_parallel_config is None: + model_parallel_config = ModelParallelConfig() + self._sequence_parallel = 
model_parallel_config.sequence_parallel + model_parallel_config.sequence_parallel = False # SP is irrelevant for the lora linear layer + + if input_is_parallel: + self.linear = RowParallelLinear( + in_features, + out_features, + config=model_parallel_config, + input_is_parallel=True, + skip_bias_add=True, + bias=False, + init_method=init.xavier_normal_, + ) + else: + self.linear = ColumnParallelLinear( + in_features, + out_features, + config=model_parallel_config, + bias=False, + gather_output=True, + init_method=init.xavier_normal_, + disable_grad_reduce=self._sequence_parallel, + ) + + # Setup adapter strategy + self.setup_adapter_strategy(adapter_mixin_strategies.ReturnResultAdapterStrategy()) + + def forward(self, x): + x, _ = self.linear(x) + return x + + +@dataclass +class MLPHeadAdapterConfig(AdapterConfig): + in_features: int + out_features: int + _target_: str = "{0}.{1}".format(MLPHeadAdapter.__module__, MLPHeadAdapter.__name__) + + class LoraKQVAdapter(ParallelLinearAdapter): """ Lora Adapters are the same arch as regular adapters but with potentially different input and output feature sizes @@ -777,14 +829,21 @@ def set_inference_table(self, prompt_representation: torch.Tensor): self.is_inference_ready = True return True - def clear_inference_table(self): + def clear_inference_table( + self, + ): self.inference_table.fill_(0.0) self.is_inference_ready = False - def get_inference_table(self): + def get_inference_table( + self, + ): return self.inference_table.data - def inner_forward(self): + def inner_forward( + self, + ): + input_embeds = self.embedding(self.indices).unsqueeze(0) intermediate_parallel, bias_parallel = self.first(input_embeds) intermediate_parallel = fused_bias_gelu(intermediate_parallel, bias_parallel) diff --git a/nemo/collections/nlp/modules/common/megatron/megatron_init.py b/nemo/collections/nlp/modules/common/megatron/megatron_init.py index 55e386bb22e5..d8fac724e63c 100644 --- a/nemo/collections/nlp/modules/common/megatron/megatron_init.py +++ b/nemo/collections/nlp/modules/common/megatron/megatron_init.py @@ -21,10 +21,9 @@ try: from apex.transformer.log_util import set_logging_level - from apex.transformer.microbatches import ConstantNumMicroBatches - from apex.transformer.pipeline_parallel.utils import setup_microbatch_calculator HAVE_APEX = True + except (ImportError, ModuleNotFoundError): HAVE_APEX = False @@ -44,10 +43,38 @@ set_virtual_pipeline_model_parallel_rank, ) + HAVE_MEGATRON_CORE = True + except (ImportError, ModuleNotFoundError): HAVE_MEGATRON_CORE = False +try: + from megatron.core.num_microbatches_calculator import ( + ConstantNumMicroBatchesCalculator, + get_current_global_batch_size, + get_micro_batch_size, + get_num_microbatches, + init_num_microbatches_calculator, + ) + + MCORE_MB_CALCULATOR = True + +except (ImportError, ModuleNotFoundError): + logging.warning("Megatron num_microbatches_calculator not found, using Apex version.") + from apex.transformer.microbatches import ConstantNumMicroBatches as ConstantNumMicroBatchesCalculator + from apex.transformer.pipeline_parallel.utils import ( + get_current_global_batch_size, + get_micro_batch_size, + get_num_microbatches, + ) + from apex.transformer.pipeline_parallel.utils import ( + setup_microbatch_calculator as init_num_microbatches_calculator, + ) + + MCORE_MB_CALCULATOR = False + + try: from apex.transformer.parallel_state import set_virtual_pipeline_model_parallel_world_size @@ -136,29 +163,51 @@ def initialize_model_parallel_for_nemo( if global_batch_size and micro_batch_size is 
not None: # TODO: add rampup_batch_size here when we have it implemented - from apex.transformer.pipeline_parallel.utils import _GLOBAL_NUM_MICROBATCHES_CALCULATOR - - if _GLOBAL_NUM_MICROBATCHES_CALCULATOR is None: - setup_microbatch_calculator( - rank=global_rank, - global_batch_size=global_batch_size, - micro_batch_size=micro_batch_size, - data_parallel_size=app_state.data_parallel_size, - rampup_batch_size=rampup_batch_size, - ) + if MCORE_MB_CALCULATOR: + from megatron.core.num_microbatches_calculator import _GLOBAL_NUM_MICROBATCHES_CALCULATOR + + if _GLOBAL_NUM_MICROBATCHES_CALCULATOR is None: + init_num_microbatches_calculator( + rank=global_rank, + global_batch_size=global_batch_size, + micro_batch_size=micro_batch_size, + data_parallel_size=app_state.data_parallel_size, + rampup_batch_size=rampup_batch_size, + ) + else: + if isinstance(_GLOBAL_NUM_MICROBATCHES_CALCULATOR, ConstantNumMicroBatchesCalculator): + assert get_current_global_batch_size() == global_batch_size + assert get_micro_batch_size() == micro_batch_size + assert get_num_microbatches() == global_batch_size // ( + micro_batch_size * app_state.data_parallel_size + ) + else: + raise Exception("Microbatch calculator already initialized.") else: - if isinstance(_GLOBAL_NUM_MICROBATCHES_CALCULATOR, ConstantNumMicroBatches): - assert _GLOBAL_NUM_MICROBATCHES_CALCULATOR.current_global_batch_size == global_batch_size - assert _GLOBAL_NUM_MICROBATCHES_CALCULATOR.micro_batch_size == micro_batch_size - assert _GLOBAL_NUM_MICROBATCHES_CALCULATOR.num_micro_batches == global_batch_size // ( - micro_batch_size * app_state.data_parallel_size + from apex.transformer.pipeline_parallel.utils import _GLOBAL_NUM_MICROBATCHES_CALCULATOR + + if _GLOBAL_NUM_MICROBATCHES_CALCULATOR is None: + init_num_microbatches_calculator( + rank=global_rank, + global_batch_size=global_batch_size, + micro_batch_size=micro_batch_size, + data_parallel_size=app_state.data_parallel_size, + rampup_batch_size=rampup_batch_size, ) else: - raise Exception("Microbatch calculator already initialized.") + if isinstance(_GLOBAL_NUM_MICROBATCHES_CALCULATOR, ConstantNumMicroBatchesCalculator): + assert get_current_global_batch_size() == global_batch_size + assert get_micro_batch_size() == micro_batch_size + assert get_num_microbatches() == global_batch_size // ( + micro_batch_size * app_state.data_parallel_size + ) + else: + raise Exception("Microbatch calculator already initialized.") app_state._is_megatron_initialized = True - set_logging_level(apex_transformer_log_level) + if HAVE_APEX: + set_logging_level(apex_transformer_log_level) def _set_random_seed(seed_): diff --git a/nemo/collections/nlp/modules/common/megatron/transformer.py b/nemo/collections/nlp/modules/common/megatron/transformer.py index cb23c4a6b1fd..c5907873bac3 100644 --- a/nemo/collections/nlp/modules/common/megatron/transformer.py +++ b/nemo/collections/nlp/modules/common/megatron/transformer.py @@ -1525,7 +1525,12 @@ def forward( It indicates if the current step in the forward pass is the first in a gradient accumulation cycle. 
If set, FP8 weights are cached and some minor optimizations are applied to fuse_wgrad_accumulation """ - from apex.transformer.pipeline_parallel.utils import _GLOBAL_NUM_MICROBATCHES_CALCULATOR + try: + from megatron.core.num_microbatches_calculator import _GLOBAL_NUM_MICROBATCHES_CALCULATOR + + except (ImportError, ModuleNotFoundError): + logging.warning("Megatron num_microbatches_calculator not found, using Apex version.") + from apex.transformer.pipeline_parallel.utils import _GLOBAL_NUM_MICROBATCHES_CALCULATOR num_micro_batches = getattr(_GLOBAL_NUM_MICROBATCHES_CALCULATOR, 'num_micro_batches', 1) diff --git a/nemo/collections/nlp/modules/common/text_generation_strategy.py b/nemo/collections/nlp/modules/common/text_generation_strategy.py index 8f8fe313a5e3..8b9d7cf712c4 100644 --- a/nemo/collections/nlp/modules/common/text_generation_strategy.py +++ b/nemo/collections/nlp/modules/common/text_generation_strategy.py @@ -26,14 +26,7 @@ from nemo.collections.nlp.modules.common.lm_utils import pad_batch from nemo.collections.nlp.modules.common.megatron.module import Float16Module from nemo.collections.nlp.modules.common.megatron.utils import get_ltor_masks_and_position_ids - -try: - from apex.transformer.pipeline_parallel.utils import get_num_microbatches - - HAVE_APEX = True - -except (ImportError, ModuleNotFoundError): - HAVE_APEX = False +from nemo.utils import logging try: from megatron.core.pipeline_parallel.schedules import get_forward_backward_func @@ -46,6 +39,13 @@ HAVE_MEGATRON_CORE = False +try: + from megatron.core.num_microbatches_calculator import get_num_microbatches + +except (ImportError, ModuleNotFoundError): + logging.warning("Megatron num_microbatches_calculator not found, using Apex version.") + from apex.transformer.pipeline_parallel.utils import get_num_microbatches + # the text representation of eos_id, it applies for all tokenizers END_OF_SEQ = '<|endoftext|>' @@ -584,6 +584,7 @@ def __init__(self, model): media_type=getattr(self.data_cfg, 'media_type', 'image'), num_frames=getattr(self.data_cfg, 'num_frames', 1), mm_mlp_adapter_type=getattr(self.cfg.mm_cfg, 'mm_mlp_adapter_type', 'linear'), + use_lita=getattr(self.cfg.mm_cfg, 'use_lita', False), ) if self.multimodal_cfg['crop_size'] is None: image_processor = CLIPImageProcessor.from_pretrained( @@ -605,6 +606,21 @@ def __init__(self, model): width_num_patches += 1 self.num_media_latents = height_num_patches * width_num_patches + # add config for lita + if self.multimodal_cfg['use_lita']: + if self.cfg.mm_cfg.get('lita'): + lita = { + 'lita_video_arch': getattr(self.cfg.mm_cfg.lita, 'lita_video_arch', 'temporal_spatial_pool'), + 'visual_token_format': getattr(self.cfg.mm_cfg.lita, 'visual_token_format', 'v1'), + 'sample_frames': getattr(self.cfg.mm_cfg.lita, 'sample_frames', 1), + } + self.multimodal_cfg['lita'] = lita + else: + self.multimodal_cfg['use_lita'] = False + raise Warning( + 'Use lita has been set True but Lita config not found in the config file' + 'LITA will be disabled for this run.' + ) def clip_max_len(self, maxlen: int) -> int: """clip the max len based on the LM model max sequence length""" @@ -687,6 +703,7 @@ def prepare_batch_at_step( # not using type2use. 
uncomment it if it is used # if type_ids is not None: # types2use = type_ids[:, context_length - 1].view(batch_size, -1) + media = None """Prepare batch for each of the inference steps""" attention_mask_repeat = None diff --git a/nemo/collections/nlp/modules/common/text_generation_utils.py b/nemo/collections/nlp/modules/common/text_generation_utils.py index cd02f5409679..a5215b12bfae 100644 --- a/nemo/collections/nlp/modules/common/text_generation_utils.py +++ b/nemo/collections/nlp/modules/common/text_generation_utils.py @@ -31,29 +31,31 @@ DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IMAGE_PATCH_TOKEN, + DEFAULT_VID_END_TOKEN, + DEFAULT_VID_START_TOKEN, ) from nemo.collections.nlp.modules.common.megatron.utils import get_ltor_masks_and_position_ids from nemo.collections.nlp.modules.common.text_generation_strategy import model_inference_strategy_dispatcher from nemo.collections.nlp.modules.common.transformer.text_generation import LengthParam, OutputType, SamplingParam -from nemo.utils import AppState +from nemo.utils import AppState, logging try: - from apex.transformer.pipeline_parallel.utils import _reconfigure_microbatch_calculator + from megatron.core import parallel_state, tensor_parallel - HAVE_APEX = True + HAVE_MEGATRON_CORE = True except (ImportError, ModuleNotFoundError): - HAVE_APEX = False + HAVE_MEGATRON_CORE = False try: - from megatron.core import parallel_state, tensor_parallel - - HAVE_MEGATRON_CORE = True + from megatron.core.num_microbatches_calculator import reconfigure_num_microbatches_calculator except (ImportError, ModuleNotFoundError): - - HAVE_MEGATRON_CORE = False + logging.warning("Megatron num_microbatches_calculator not found, using Apex version.") + from apex.transformer.pipeline_parallel.utils import ( + _reconfigure_microbatch_calculator as reconfigure_num_microbatches_calculator, + ) __all__ = [ "get_default_sampling_params", @@ -144,7 +146,75 @@ def megatron_gpt_generate(model, inputs, tokenizer, length_params, sampling_para return output +def decode_time_tokens(tokenizer, text: str, duration: float, time_tokens: list[str], time_token_ids: list[int]): + """Decode the time tokens .... in the text to the actual time in seconds. + TO DO: to do time decoding on output ids instead of text + + Args: + text (str): _description_ + duration (float): the total length of the video in seconds + time_tokens (list[str]): list of time tokens [, , , ..] + time_token_ids (list[str]): list of time token ids [32004, 32005, ....] 
+ """ + output_ids = tokenizer.text_to_ids(text) + num_time_tokens = len(time_token_ids) + # the original code is len(output_ids) - 1 + indices = [j for j in range(len(output_ids)) if output_ids[j] in time_token_ids] + last_processed = -1 + new_output_ids = [] + for j in range(len(indices)): + pred_seq = [int(output_ids[k]) for k in range(last_processed + 1, indices[j])] + new_output_ids.extend(pred_seq) + max_offset = num_time_tokens - 1 + time_token = tokenizer.ids_to_tokens([output_ids[indices[j]]])[0] + time_idx = time_tokens.index(time_token) + time = float(time_idx) * duration / max_offset + time = min(max(time, 0), duration) + time = round(time, 2) + # time_str = '<' + str(time) + '>' + time_str = '<%s>' % str(time) + new_output_ids.extend(tokenizer.text_to_ids(time_str)) + + last_processed = indices[j] + pred_seq = [int(x) for x in output_ids[last_processed + 1 :]] + new_output_ids.extend(pred_seq) + output_ids = new_output_ids + decoded_text = tokenizer.ids_to_text(output_ids) + return decoded_text + + +def encode_time_str(text: str, duration: float, num_time_tokens: int = 100, time_token_template: str = ""): + """ + Encode the common time expression to its time token expression + """ + + def time_to_string(time): + # time is normalized in [0, 1] + max_offset = float(num_time_tokens - 1) + time = int(np.round(max_offset * time)) + return time_token_template.format(t=time) + + def repl(match): + value = float(match.group(1)) / duration + return time_to_string(value) + f"" + + text = re.sub(r"<([\d.]{1,20})s>", repl, text) + text = re.sub(r"\s([\d.]{1,20})s[\s|\.|,|>]", repl, text) + text = re.sub(r"\s([\d.]{1,20}) seconds", repl, text) + text = re.sub(r"\s([\d.]{1,20}) second", repl, text) + + # This is to remove the timestamps from the text + text = re.sub(r"", "", text) + return text.strip() + + def megatron_neva_generate(model, prompt_dict_list, length_params, sampling_params, inference_config, **strategy_args): + use_lita = model.cfg.mm_cfg.get('use_lita', False) + if use_lita: + num_time_tokens = model.cfg.data.get('num_time_tokens', 100) + TIME_TOKEN_TEMPLATE = "" + time_tokens = [TIME_TOKEN_TEMPLATE.format(t=i) for i in range(num_time_tokens)] + time_token_ids = model.tokenizer.tokens_to_ids(time_tokens) model_type = model.cfg.mm_cfg.llm.get("model_type", "nvgpt") conv_template = model.cfg.data.get("conv_template", "nvgpt") @@ -152,6 +222,14 @@ def megatron_neva_generate(model, prompt_dict_list, length_params, sampling_para for idx, prompt_dict in enumerate(prompt_dict_list): # determine the media type in the prompt_dict media_type_token = inference_config.inference.get("media_type", "image") + if use_lita: + if prompt_dict.get("duration") is not None: + duration = prompt_dict.get("duration") + prompt_dict['prompt'] = encode_time_str( + prompt_dict['prompt'], duration, num_time_tokens, TIME_TOKEN_TEMPLATE + ) + else: + print("duration field is not in prompt file, skipping time encoding.") response = generate( model, inputs=prompt_dict.get('prompt'), @@ -184,7 +262,12 @@ def megatron_neva_generate(model, prompt_dict_list, length_params, sampling_para r'|', r'\|' ) ) - combined_pattern = re.compile(f'{pattern.pattern}|{pattern_nvgpt.pattern}') + + if use_lita: + pattern_lita = re.compile(rf'{DEFAULT_IM_START_TOKEN[model_type]}(.)+{DEFAULT_IM_END_TOKEN[model_type]}') + combined_pattern = re.compile(f'{pattern_lita.pattern}') + else: + combined_pattern = re.compile(f'{pattern.pattern}|{pattern_nvgpt.pattern}') clean_text = re.sub(combined_pattern, f"<{media_type_token}>", 
response['sentences'][0]) clean_response = clean_text @@ -204,10 +287,18 @@ def megatron_neva_generate(model, prompt_dict_list, length_params, sampling_para clean_response = clean_response.rsplit("[/INST] ", 1)[-1] elif conv_template == "llama_3": clean_response = clean_response.rsplit("assistant<|end_header_id|>\n\n", 1)[-1] - clean_response = clean_response.rstrip("<|eot_id|>") + clean_response = re.sub(r"(<\|eot_id\|>)+$", "", clean_response) elif conv_template == "v1": clean_response = clean_response.rsplit("ASSISTANT: ", 1)[-1] + if use_lita: + if prompt_dict.get("duration", None) is not None: + duration = prompt_dict.get("duration") + clean_response = decode_time_tokens( + model.tokenizer, clean_response, duration, time_tokens, time_token_ids + ) + else: + print("duration field is not in prompt file, skipping time decoding.") clean_response = clean_response.strip() response["clean_text"] = clean_text response["clean_response"] = clean_response @@ -703,6 +794,9 @@ def generate( if random_seed is not None: seed_everything(random_seed) + if hasattr(model, 'get_attention_mask_from_fusion') and model.get_attention_mask_from_fusion: + compute_attention_mask = False + output = synced_generate( model, inference_strategy, @@ -811,7 +905,7 @@ def sample_sequence_batch( app_state = AppState() micro_batch_size = context_tokens.shape[0] - _reconfigure_microbatch_calculator( + reconfigure_num_microbatches_calculator( rank=app_state.global_rank, rampup_batch_size=None, global_batch_size=micro_batch_size, @@ -1003,7 +1097,7 @@ def tab_sample_sequence_batch( ): app_state = AppState() micro_batch_size = context_tokens.shape[0] - _reconfigure_microbatch_calculator( + reconfigure_num_microbatches_calculator( rank=app_state.global_rank, rampup_batch_size=None, global_batch_size=micro_batch_size, diff --git a/nemo/collections/nlp/modules/common/tokenizer_utils.py b/nemo/collections/nlp/modules/common/tokenizer_utils.py index d3ee69f75b25..4cbadd87fe52 100644 --- a/nemo/collections/nlp/modules/common/tokenizer_utils.py +++ b/nemo/collections/nlp/modules/common/tokenizer_utils.py @@ -22,6 +22,7 @@ from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer from nemo.collections.common.tokenizers.regex_tokenizer import RegExTokenizer from nemo.collections.common.tokenizers.tabular_tokenizer import TabularTokenizer +from nemo.collections.common.tokenizers.tiktoken_tokenizer import TiktokenTokenizer from nemo.collections.common.tokenizers.word_tokenizer import WordTokenizer from nemo.collections.nlp.modules.common.huggingface.huggingface_utils import get_huggingface_pretrained_lm_models_list from nemo.collections.nlp.modules.common.lm_utils import get_pretrained_lm_models_list @@ -122,6 +123,8 @@ def get_tokenizer( legacy=True, chat_template=chat_template, ) + elif tokenizer_name == 'tiktoken': + return nemo.collections.common.tokenizers.tiktoken_tokenizer.TiktokenTokenizer(vocab_file=vocab_file) elif tokenizer_name == 'word': return WordTokenizer(vocab_file=vocab_file, **special_tokens_dict) elif tokenizer_name == 'char': @@ -221,6 +224,8 @@ def get_nmt_tokenizer( ) elif library == 'tabular': return TabularTokenizer(vocab_file, delimiter=delimiter) + elif library == 'tiktoken': + return TiktokenTokenizer(vocab_file=vocab_file) else: raise NotImplementedError( 'Currently we only support "huggingface", "sentencepiece", "megatron", and "byte-level" tokenizer' diff --git a/nemo/collections/nlp/parts/megatron_trainer_builder.py b/nemo/collections/nlp/parts/megatron_trainer_builder.py 
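The LITA helpers added to text_generation_utils.py above quantize absolute timestamps into a fixed vocabulary of time tokens before generation and decode them back afterwards. Below is a minimal, self-contained sketch of that round trip; it assumes a time-token template of the form "<t{t}>" (tokens <t0> ... <t99>) and 100 time tokens, uses plain strings instead of the real tokenizer, and the duration value is made up for illustration.

import re

import numpy as np

NUM_TIME_TOKENS = 100               # illustrative; the real value comes from model.cfg.data.num_time_tokens
TIME_TOKEN_TEMPLATE = "<t{t}>"      # assumed template form


def seconds_to_time_token(seconds: float, duration: float) -> str:
    """Quantize an absolute timestamp into one of NUM_TIME_TOKENS discrete tokens (the core of encode_time_str)."""
    max_offset = float(NUM_TIME_TOKENS - 1)
    normalized = min(max(seconds / duration, 0.0), 1.0)
    return TIME_TOKEN_TEMPLATE.format(t=int(np.round(max_offset * normalized)))


def time_token_to_seconds(token: str, duration: float) -> float:
    """Invert the quantization (the core of decode_time_tokens); resolution is duration / (NUM_TIME_TOKENS - 1)."""
    index = int(re.fullmatch(r"<t(\d+)>", token).group(1))
    seconds = index * duration / (NUM_TIME_TOKENS - 1)
    return round(min(max(seconds, 0.0), duration), 2)


duration = 120.0                                      # a two-minute video
token = seconds_to_time_token(30.0, duration)         # -> "<t25>"
print(token, time_token_to_seconds(token, duration))  # -> <t25> 30.3 (quantization step is ~1.2 s here)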
index f4276fd1b8f9..b2c85cde4e98 100644 --- a/nemo/collections/nlp/parts/megatron_trainer_builder.py +++ b/nemo/collections/nlp/parts/megatron_trainer_builder.py @@ -21,6 +21,7 @@ from pytorch_lightning.callbacks import ModelSummary from pytorch_lightning.plugins.environments import TorchElasticEnvironment +from nemo.collections.common.metrics.perf_metrics import FLOPsMeasurementCallback from nemo.collections.nlp.parts.nlp_overrides import ( CustomProgressBar, FSDPMixedPrecisionPlugin, @@ -173,6 +174,10 @@ def _callbacks(self, callbacks: Optional[list]) -> list: if self.cfg.get('exp_manager', {}).get('checkpoint_callback_params', {}).get('async_save', False): callbacks.append(AsyncFinalizerCallback()) + + if self.cfg.get('exp_manager', {}).get('log_tflops_per_sec_per_gpu', True): + callbacks.append(FLOPsMeasurementCallback(self.cfg)) + return callbacks def create_trainer(self, callbacks=None) -> Trainer: diff --git a/nemo/collections/nlp/parts/mixins/nlp_adapter_mixins.py b/nemo/collections/nlp/parts/mixins/nlp_adapter_mixins.py index 45f4af3cfbf3..a0446f290826 100644 --- a/nemo/collections/nlp/parts/mixins/nlp_adapter_mixins.py +++ b/nemo/collections/nlp/parts/mixins/nlp_adapter_mixins.py @@ -30,8 +30,13 @@ HAVE_MEGATRON_CORE = False -from nemo.collections.nlp.modules.common.megatron.adapters.parallel_adapters import PromptEncoderAdapterConfig +from nemo.collections.nlp.modules.common.megatron.adapters.parallel_adapters import ( + MLPHeadAdapterConfig, + PromptEncoderAdapterConfig, +) + from nemo.collections.nlp.parts.nlp_overrides import NLPSaveRestoreConnector + from nemo.collections.nlp.parts.peft_config import ( PEFT_CONFIG_MAP, CanonicalAdaptersPEFTConfig, @@ -82,9 +87,11 @@ def __init__(self, *args, **kwargs): self.ptuning_only_and_non_first_stage = False super().__init__(*args, **kwargs) - self.use_mcore_gpt = hasattr(self, 'mcore_gpt') and self.mcore_gpt - if self.use_mcore_gpt: - assert HAVE_MEGATRON_CORE, "You set `mcore_gpt` as True but megatron core is not found." + self.use_mcore_gpt = getattr(self, 'mcore_gpt', False) + self.use_mcore_t5 = getattr(self, 'mcore_t5', False) + + if self.use_mcore_gpt or self.use_mcore_t5: + assert HAVE_MEGATRON_CORE, "You set `mcore_gpt` or `mcore_t5` as True but megatron core is not found." def _unwrap_model(self): if not hasattr(self, "model"): @@ -126,6 +133,8 @@ def _check_and_add_adapter(self, name, module, peft_name, peft_cfg, name_key_to_ mcore_target, f'model.{mcore_target}', f'model.module.{mcore_target}', + f'enc_dec_model.{mcore_target}', + f'enc_dec_model.module.{mcore_target}', ]: # simple string match for now if not isinstance(module, IdentityOp): swap_mcore_mixin(module, mcore_mixin) @@ -151,6 +160,11 @@ def _get_layers_from_model(self, model): layers = model.module.decoder.layers else: layers = model.decoder.layers + elif self.use_mcore_t5: + if self.cfg.megatron_amp_O2: + layers = model.module.encoder.layers + model.module.decoder.layers + else: + layers = model.encoder.layers + model.decoder.layers else: if self.cfg.megatron_amp_O2: layers = model.module.language_model.encoder.layers @@ -161,15 +175,20 @@ def _get_layers_from_model(self, model): def _check_and_add_peft_cfg(self, peft_cfg): layer_selection = peft_cfg.layer_selection - assert not self.use_mcore_gpt or hasattr( peft_cfg, 'name_key_to_mcore_mixins' ), f"{peft_cfg.__class__.__name__} is not supported in megatron core mode yet." 
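The FLOPsMeasurementCallback registered above in MegatronTrainerBuilder._callbacks() is on by default and is gated by a new exp_manager flag. A small sketch of that gate, with a made-up config that opts out (only the relevant key is shown):

from omegaconf import OmegaConf

cfg = OmegaConf.create({"exp_manager": {"log_tflops_per_sec_per_gpu": False}})  # opt out of TFLOPs/sec/GPU logging

# Mirrors the gating above: when the key is absent the default is True and the callback is appended.
if cfg.get("exp_manager", {}).get("log_tflops_per_sec_per_gpu", True):
    print("FLOPsMeasurementCallback would be added")
else:
    print("FLOPs logging disabled")  # this branch runs for the config above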
- name_key_to_mcore_mixins = peft_cfg.name_key_to_mcore_mixins if self.use_mcore_gpt else None + name_key_to_mcore_mixins = ( + peft_cfg.name_key_to_mcore_mixins if (self.use_mcore_gpt or self.use_mcore_t5) else None + ) for adapter_name, adapter_cfg in peft_cfg.get_config_dict().items(): - # self.mcore_gpt means is GPT and not T5 - if hasattr(self, 'mcore_gpt') and not isinstance(adapter_cfg, PromptEncoderAdapterConfig): + # mixin for mcore models + if ( + (hasattr(self, 'mcore_gpt') or getattr(self, 'mcore_t5', False)) + and not isinstance(adapter_cfg, PromptEncoderAdapterConfig) + and not isinstance(adapter_cfg, MLPHeadAdapterConfig) + ): if layer_selection is not None: logging.info( f"Layer selection {layer_selection} is enabled for the current model (" @@ -204,8 +223,6 @@ def add_adapter(self, peft_cfgs: Union[PEFTConfig, List[PEFTConfig]]): peft_cfgs: One or more PEFTConfig objects that specify the PEFT method configuration """ - if self.cfg.get('virtual_pipeline_model_parallel_size', None): - raise ValueError('Virtual pipeline model parallel is not supported when using PEFT') if self.cfg.optim.name == "distributed_fused_adam": raise ValueError('distributed_fused_adam is not supported for PEFT. Please use fused_adam') @@ -352,8 +369,10 @@ def load_adapters( assert filepath.endswith( '.nemo' ), "Inferring peft scheme is only supported for .nemo checkpoints. Please supply the `peft_cfgs` argument." - peft_cfgs = [PEFT_CONFIG_MAP[conf.peft.peft_scheme](conf)] + peft_cfg_cls_lst = [PEFT_CONFIG_MAP[s] for s in conf.peft.peft_scheme.split(",")] + peft_cfgs = [_peft_cfg(conf) for _peft_cfg in peft_cfg_cls_lst] if getattr(self, 'megatron_amp_O2', False): + state_dict = {replace_prefix(k, 'model.', 'model.module.'): v for k, v in state_dict.items()} self.add_adapter(peft_cfgs) if not self.ptuning_only_and_non_first_stage: @@ -421,8 +440,8 @@ def state_dict(self, destination=None, prefix=None, keep_vars=False): return super().state_dict() def sharded_state_dict(self, prefix: str = ''): - use_mcore_gpt = hasattr(self, 'mcore_gpt') and self.mcore_gpt - if not use_mcore_gpt or (self.use_peft and self.setup_complete): + use_mcore = (getattr(self, 'mcore_gpt', False)) or (getattr(self, 'mcore_t5', False)) + if not use_mcore or (self.use_peft and self.setup_complete): return None else: return super().sharded_state_dict(prefix=prefix) @@ -450,7 +469,8 @@ def on_load_checkpoint(self, checkpoint) -> None: if not self.ptuning_only_and_non_first_stage: # same as super().on_load_checkpoint() but strict=False and only check unexpected keys # mcore uses distributed checkpointing - if hasattr(self, 'mcore_gpt') and self.mcore_gpt: + use_mcore = (getattr(self, 'mcore_gpt', False)) or (getattr(self, 'mcore_t5', False)) + if use_mcore: for index, module in enumerate(self.get_model_module_list()): if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None: checkpoint_state_dict = checkpoint['state_dict'][f'model_{index}'] diff --git a/nemo/collections/nlp/parts/nlp_overrides.py b/nemo/collections/nlp/parts/nlp_overrides.py index e251690831cb..b00b2ac28c3b 100644 --- a/nemo/collections/nlp/parts/nlp_overrides.py +++ b/nemo/collections/nlp/parts/nlp_overrides.py @@ -74,7 +74,6 @@ from nemo.utils.model_utils import ckpt_to_dir, inject_model_parallel_rank, uninject_model_parallel_rank try: - from apex.transformer.pipeline_parallel.utils import get_num_microbatches from nemo.core.optim.distributed_adam import MegatronDistributedFusedAdam from nemo.core.optim.mcore_optim import 
McoreDistributedOptimizer @@ -116,6 +115,13 @@ HAVE_MEGATRON_CORE = False +try: + from megatron.core.num_microbatches_calculator import get_num_microbatches + +except (ImportError, ModuleNotFoundError): + logging.warning("Megatron num_microbatches_calculator not found, using Apex version.") + from apex.transformer.pipeline_parallel.utils import get_num_microbatches + try: from modelopt.torch.opt.plugins import restore_sharded_modelopt_state, save_sharded_modelopt_state @@ -701,6 +707,7 @@ def __init__( nccl_communicator_config_path: Optional[str] = None, sharp: bool = False, set_buffer_dtype: Optional[str] = None, + extra_fsdp_wrap_module: Optional[set] = None, **kwargs: Union[Any, Dict[str, Any]], ) -> None: if not HAVE_APEX: @@ -730,6 +737,11 @@ def __init__( ParallelTransformerLayer, BasicTransformerBlock, } + + # if extra wrap modules are provided, use them + if extra_fsdp_wrap_module is not None: + self.fsdp_wrap_module.update(extra_fsdp_wrap_module) + kwargs['auto_wrap_policy'] = functools.partial( transformer_auto_wrap_policy, transformer_layer_cls=self.fsdp_wrap_module ) @@ -1306,8 +1318,14 @@ def dummy(): else: # Extract the nemo file into the temporary directory + filter_fn = None + if return_config: + filter_fn = lambda name: '.yaml' in name + members = self._filtered_tar_info(restore_path, filter_fn=filter_fn) self._unpack_nemo_file( - path2file=restore_path, out_folder=tmpdir, extract_config_only=return_config is True + path2file=restore_path, + out_folder=tmpdir, + members=members, ) # remove model weights extension tmp_model_weights_ckpt = os.path.join(tmpdir, self.model_weights_ckpt) diff --git a/nemo/collections/nlp/parts/peft_config.py b/nemo/collections/nlp/parts/peft_config.py index 50c97e349885..25f303fc22fb 100644 --- a/nemo/collections/nlp/parts/peft_config.py +++ b/nemo/collections/nlp/parts/peft_config.py @@ -24,6 +24,7 @@ MCoreMLPMixin, MCoreSelfAttentionMixin, MCoreSequentialMLPMixin, + MCoreTransformerBlockMixin, MCoreTransformerLayerMixin, ) except (ImportError, ModuleNotFoundError): @@ -41,6 +42,7 @@ LoraMoeHto4HAdapterConfig, LoraUnfusedHto4HAdapterConfig, LoraUnfusedKQVAdapterConfig, + MLPHeadAdapterConfig, MLPInfusedAdapterConfig, ParallelLinearAdapterConfig, ParallelLinearAdapterWeightTyingConfig, @@ -127,6 +129,21 @@ def __init__(self, cfg): self.tunable_base_param_names = selective_cfg.get("tunable_base_param_names", []) +class MLPHeadPEFTConfig(PEFTConfig): + def __init__(self, cfg): + config_args = {"in_features": cfg.hidden_size, "out_features": cfg.peft.mlp_head_tuning.out_features} + mlp_head_cfg = MLPHeadAdapterConfig(**config_args) + + name_key_to_cfg = { + AdapterName.MLP_HEAD_ADAPTER: mlp_head_cfg, + } + self.name_key_to_mcore_mixins = { + AdapterName.MLP_HEAD_ADAPTER: [("decoder", MCoreTransformerBlockMixin)], + } + + super().__init__(cfg.peft.mlp_head_tuning, name_key_to_cfg) + + class LoraPEFTConfig(PEFTConfig): def __init__(self, cfg): lora_cfg = cfg.peft.lora_tuning @@ -170,7 +187,7 @@ def __init__(self, cfg): elif module == PEFT_MODULE_MAP["dense_module"]: adapter_cfg = self._create_lora_config( - cfg, lora_cfg, cfg.hidden_size, cfg.hidden_size, LoraDenseAttentionAdapterConfig + cfg, lora_cfg, projection_size, cfg.hidden_size, LoraDenseAttentionAdapterConfig ) name_key_to_cfg[AdapterName.LORA_DENSE_ATTENTION_ADAPTER] = adapter_cfg name_key_to_mcore_mixins[AdapterName.LORA_DENSE_ATTENTION_ADAPTER] = [ @@ -401,6 +418,7 @@ def __init__(self, cfg): "ia3": IA3PEFTConfig, "ptuning": PtuningPEFTConfig, "lora": LoraPEFTConfig, + "mlp_head": 
MLPHeadPEFTConfig, "qlora": QLoraPEFTConfig, "selective": SelectivePEFTConfig, 'none': None, diff --git a/nemo/collections/tts/g2p/models/ja_jp_ipa.py b/nemo/collections/tts/g2p/models/ja_jp_ipa.py new file mode 100644 index 000000000000..c57d463b51b2 --- /dev/null +++ b/nemo/collections/tts/g2p/models/ja_jp_ipa.py @@ -0,0 +1,153 @@ +# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pathlib +from collections import defaultdict +from typing import Dict, List, Optional, Union + +from nemo.collections.common.tokenizers.text_to_speech.ipa_lexicon import ( + get_grapheme_character_set, + get_ipa_punctuation_list, +) +from nemo.collections.tts.g2p.models.base import BaseG2p +from nemo.collections.tts.g2p.utils import set_grapheme_case +from nemo.utils import logging + + +class JapaneseG2p(BaseG2p): + def __init__( + self, + phoneme_dict: Union[str, pathlib.Path, Dict[str, List[str]]], + phoneme_prefix: str = "", + ascii_letter_prefix: str = "#", + ascii_letter_case: str = "upper", + word_tokenize_func=None, + apply_to_oov_word=None, + mapping_file: Optional[str] = None, + word_segmenter: Optional[str] = None, + ): + """ + Japanese G2P module. This module first segments Japanese characters into words using Janome, then + these separated words are converted into phoneme sequences by looking them up in the 'phoneme_dict'. + Args: + phoneme_dict (str, Path, Dict): Path to ja_JP_wordtoipa.txt dict file or a dict object. + phoneme_prefix (str): Prepend a special symbol to any phonemes in order to distinguish phonemes from + graphemes because there may be overlaps between the two sets. It is suggested to choose a prefix that + is not used or preserved somewhere else. Default to "#". + ascii_letter_prefix (str): Prepend a special symbol to any ASCII letters. Default to "". + ascii_letter_case (str): Specify the case chosen from `"lower"`, `"upper"`, or `"mixed"`, and process the + cases of non-Chinese words. Default to `"upper"`. + word_tokenize_func: Function for tokenizing text to words. + It has to return List[Tuple[Union[str, List[str]], bool]] where every tuple denotes word representation + and flag whether to leave unchanged or not. + It is expected that unchangeable word representation will be represented as List[str], other cases are + represented as str. + It is useful to mark word as unchangeable which is already in phoneme representation. + apply_to_oov_word: Function that will be applied to out of phoneme_dict word. + word_segmenter: method that will be applied to segment utterances into words for better polyphone disambiguation. + """ + assert phoneme_dict is not None, "Please set the phoneme_dict path." + assert word_segmenter in [ + None, + "janome", + ], f"{word_segmenter} is not supported now. Please choose correct word_segmenter." 
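For reference, a short sketch of what the optional "janome" word_segmenter does before the dictionary lookups performed in __call__ below; it needs the optional janome package, and the sample sentence and printed surfaces are only illustrative:

from janome.tokenizer import Tokenizer

tokens = Tokenizer().tokenize("こんにちは、世界。")
for token in tokens:
    # JapaneseG2p.__call__ extracts the surface form the same way before looking it up in phoneme_dict.
    surface = str(token).split("\t")[0]
    print(surface)  # e.g. こんにちは / 、 / 世界 / 。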
+ + if phoneme_prefix is None: + phoneme_prefix = "" + if ascii_letter_prefix is None: + ascii_letter_prefix = "" + + # phonemes + phoneme_dict = ( + self._parse_ja_phoneme_dict(phoneme_dict, phoneme_prefix) + if isinstance(phoneme_dict, str) or isinstance(phoneme_dict, pathlib.Path) + else phoneme_dict + ) + self.phoneme_list = sorted({pron for prons in phoneme_dict.values() for pron in prons}) + + # ascii letters + self.ascii_letter_dict = { + x: ascii_letter_prefix + x for x in get_grapheme_character_set(locale="en-US", case=ascii_letter_case) + } + self.ascii_letter_list = sorted(self.ascii_letter_dict) + self.ascii_letter_case = ascii_letter_case + self.punctuation = get_ipa_punctuation_list('ja-JP') + + if apply_to_oov_word is None: + logging.warning( + "apply_to_oov_word=None, This means that some of words will remain unchanged " + "if they are not handled by any of the rules in self.parse_one_word(). " + "This may be intended if phonemes and chars are both valid inputs, otherwise, " + "you may see unexpected deletions in your input." + ) + + super().__init__( + phoneme_dict=phoneme_dict, + word_tokenize_func=word_tokenize_func, + apply_to_oov_word=apply_to_oov_word, + mapping_file=mapping_file, + ) + + if word_segmenter == "janome": + try: + from janome.tokenizer import Tokenizer + except ImportError as e: + logging.error(e) + + # Cut sentences into words to improve polyphone disambiguation + self.word_segmenter = Tokenizer().tokenize + else: + self.word_segmenter = lambda x: [x] + + @staticmethod + def _parse_ja_phoneme_dict( + phoneme_dict_path: Union[str, pathlib.Path], phoneme_prefix: str + ) -> Dict[str, List[str]]: + """Loads prondict dict file, and generates a set of all valid symbols.""" + g2p_dict = defaultdict(list) + with open(phoneme_dict_path, 'r') as file: + for line in file: + # skip empty lines and comment lines starting with `;;;`. + if line.startswith(";;;") or len(line.strip()) == 0: + continue + + word, pronunciation = line.rstrip().split(maxsplit=1) + + # add a prefix to distinguish phoneme symbols from non-phoneme symbols. + pronunciation_with_prefix = [phoneme_prefix + pron for pron in pronunciation] + g2p_dict[word] = pronunciation_with_prefix + + return g2p_dict + + def __call__(self, text: str) -> List[str]: + """ + This forward pass function translates Japanese characters into IPA phoneme sequences. + + For example, The text "こんにちは" would be converted as a list, + `['k', 'o', 'n', 'n', 'i', 't', 'ʃ', 'i', 'h', 'a']` + """ + text = set_grapheme_case(text, case=self.ascii_letter_case) + + words_list = self.word_segmenter(text) + phoneme_seq = [] + for token in words_list: + word = str(token).split("\t")[0] + if word in self.phoneme_dict.keys(): + phoneme_seq += self.phoneme_dict[word] + elif word in self.punctuation: + phoneme_seq += word + else: + logging.warning(f"{word} not found in the pronunciation dictionary. 
Returning graphemes instead.") + phoneme_seq += [c for c in word] + return phoneme_seq diff --git a/nemo/collections/tts/losses/audio_codec_loss.py b/nemo/collections/tts/losses/audio_codec_loss.py index 4454c46291a7..6db3e30595c6 100644 --- a/nemo/collections/tts/losses/audio_codec_loss.py +++ b/nemo/collections/tts/losses/audio_codec_loss.py @@ -19,7 +19,8 @@ from einops import rearrange from nemo.collections.asr.parts.preprocessing.features import FilterbankFeatures -from nemo.collections.tts.parts.utils.helpers import get_mask_from_lengths, mask_sequence_tensor +from nemo.collections.common.parts.utils import mask_sequence_tensor +from nemo.collections.tts.parts.utils.helpers import get_mask_from_lengths from nemo.core.classes import Loss, typecheck from nemo.core.neural_types import ( AudioSignal, @@ -312,7 +313,7 @@ def forward(self, audio_real, audio_gen, audio_len): # [B, 1] ref_pred = torch.sum(pred * target, dim=-1, keepdim=True) - ref_target = torch.sum(target ** 2, dim=-1, keepdim=True) + ref_target = torch.sum(target**2, dim=-1, keepdim=True) alpha = (ref_pred + self.epsilon) / (ref_target + self.epsilon) # [B, T] @@ -320,8 +321,8 @@ def forward(self, audio_real, audio_gen, audio_len): distortion = target_scaled - pred # [B] - target_scaled_power = torch.sum(target_scaled ** 2, dim=-1) - distortion_power = torch.sum(distortion ** 2, dim=-1) + target_scaled_power = torch.sum(target_scaled**2, dim=-1) + distortion_power = torch.sum(distortion**2, dim=-1) ratio = (target_scaled_power + self.epsilon) / (distortion_power + self.epsilon) si_sdr = 10 * torch.log10(ratio) @@ -505,7 +506,7 @@ def forward(self, disc_scores_real, disc_scores_gen): loss = 0.0 for disc_score_real, disc_score_gen in zip(disc_scores_real, disc_scores_gen): loss_real = torch.mean((1 - disc_score_real) ** 2) - loss_gen = torch.mean(disc_score_gen ** 2) + loss_gen = torch.mean(disc_score_gen**2) loss += (loss_real + loss_gen) / 2 loss /= len(disc_scores_real) diff --git a/nemo/collections/tts/losses/spectrogram_enhancer_losses.py b/nemo/collections/tts/losses/spectrogram_enhancer_losses.py index a77f42692b11..ff62fe80e9db 100644 --- a/nemo/collections/tts/losses/spectrogram_enhancer_losses.py +++ b/nemo/collections/tts/losses/spectrogram_enhancer_losses.py @@ -41,7 +41,7 @@ from einops import rearrange from torch.autograd import grad as torch_grad -from nemo.collections.tts.parts.utils.helpers import mask_sequence_tensor +from nemo.collections.common.parts.utils import mask_sequence_tensor class GradientPenaltyLoss(torch.nn.Module): diff --git a/nemo/collections/tts/models/audio_codec.py b/nemo/collections/tts/models/audio_codec.py index 04a6d2793f88..0c5e41157613 100644 --- a/nemo/collections/tts/models/audio_codec.py +++ b/nemo/collections/tts/models/audio_codec.py @@ -670,4 +670,18 @@ def list_available_models(cls) -> List[PretrainedModelInfo]: ) models.append(model) + model = PretrainedModelInfo( + pretrained_model_name="mel_codec_22khz_fullband_medium", + location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/mel_codec_22khz_fullband_medium/versions/v1/files/mel_codec_22khz_fullband_medium.nemo", + description="For details about this model please refer to the model card: https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/mel_codec_22khz_fullband_medium", + ) + models.append(model) + + model = PretrainedModelInfo( + pretrained_model_name="mel_codec_44khz_fullband_medium", + 
location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/mel_codec_44khz_fullband_medium/versions/v1/files/mel_codec_44khz_fullband_medium.nemo", + description="For details about this model please refer to the model card: https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/mel_codec_44khz_fullband_medium", + ) + models.append(model) + return models diff --git a/nemo/collections/tts/models/spectrogram_enhancer.py b/nemo/collections/tts/models/spectrogram_enhancer.py index 7115360e7125..65934d9a10ce 100644 --- a/nemo/collections/tts/models/spectrogram_enhancer.py +++ b/nemo/collections/tts/models/spectrogram_enhancer.py @@ -48,13 +48,14 @@ from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger from torch.utils.tensorboard.writer import SummaryWriter +from nemo.collections.common.parts.utils import mask_sequence_tensor from nemo.collections.tts.losses.spectrogram_enhancer_losses import ( ConsistencyLoss, GeneratorLoss, GradientPenaltyLoss, HingeLoss, ) -from nemo.collections.tts.parts.utils.helpers import mask_sequence_tensor, to_device_recursive +from nemo.collections.tts.parts.utils.helpers import to_device_recursive from nemo.core import Exportable, ModelPT, PretrainedModelInfo, typecheck from nemo.core.neural_types import LengthsType, MelSpectrogramType, NeuralType from nemo.core.neural_types.elements import BoolType @@ -128,7 +129,12 @@ def pad_spectrograms(self, spectrograms): } ) def forward( - self, *, input_spectrograms: torch.Tensor, lengths: torch.Tensor, mixing: bool = False, normalize: bool = True, + self, + *, + input_spectrograms: torch.Tensor, + lengths: torch.Tensor, + mixing: bool = False, + normalize: bool = True, ): """ Generator forward pass. Noise inputs will be generated. @@ -263,7 +269,10 @@ def training_step(self, batch, batch_idx, optimizer_idx): return g_loss + c_loss def configure_optimizers(self): - generator_opt = instantiate(self._cfg.generator_opt, params=self.generator.parameters(),) + generator_opt = instantiate( + self._cfg.generator_opt, + params=self.generator.parameters(), + ) discriminator_opt = instantiate(self._cfg.discriminator_opt, params=self.discriminator.parameters()) return [discriminator_opt, generator_opt], [] diff --git a/nemo/collections/tts/modules/audio_codec_modules.py b/nemo/collections/tts/modules/audio_codec_modules.py index 96029d9bd105..e9ed34732c36 100644 --- a/nemo/collections/tts/modules/audio_codec_modules.py +++ b/nemo/collections/tts/modules/audio_codec_modules.py @@ -23,7 +23,7 @@ from nemo.collections.asr.modules import AudioToMelSpectrogramPreprocessor from nemo.collections.asr.parts.utils.activations import Snake -from nemo.collections.tts.parts.utils.helpers import mask_sequence_tensor +from nemo.collections.common.parts.utils import mask_sequence_tensor from nemo.core.classes.common import typecheck from nemo.core.classes.module import NeuralModule from nemo.core.neural_types.elements import ( @@ -399,7 +399,9 @@ def encode(self, inputs: torch.Tensor, input_len: torch.Tensor) -> torch.Tensor: "indices": NeuralType(('D', 'B', 'T'), Index()), "input_len": NeuralType(tuple('B'), LengthsType()), }, - output_types={"dequantized": NeuralType(('B', 'D', 'T'), EncodedRepresentation()),}, + output_types={ + "dequantized": NeuralType(('B', 'D', 'T'), EncodedRepresentation()), + }, ) @abstractmethod def decode(self, indices: torch.Tensor, input_len: torch.Tensor) -> torch.Tensor: @@ -489,8 +491,7 @@ def round(inputs: torch.Tensor, input_len: torch.Tensor) -> torch.Tensor: return inputs + (inputs_rounded - 
inputs).detach() def compress(self, inputs: torch.Tensor, input_len: torch.Tensor) -> torch.Tensor: - """Apply compression to the input, to limit to values. - """ + """Apply compression to the input, to limit to values.""" output_scale = (self.num_levels - 1) / 2 # scale down a bit to avoid rounding issues output_scale = output_scale * (1 - self.eps) @@ -520,20 +521,17 @@ def inputs_to_codes(self, inputs: torch.Tensor, input_len: torch.Tensor) -> torc return codes def codes_to_nonnegative(self, codes: torch.Tensor) -> torch.Tensor: - """Convert values centered arouund zero to nonnegative values. - """ + """Convert values centered arouund zero to nonnegative values.""" scale = offset = self.num_levels // 2 return scale * codes + offset def nonnegative_to_codes(self, codes_nonnegative: torch.Tensor) -> torch.Tensor: - """Convert nonnegative values to values centered arouund zero. - """ + """Convert nonnegative values to values centered arouund zero.""" scale = offset = self.num_levels // 2 return (codes_nonnegative - offset) / scale def codes_to_indices(self, codes: torch.Tensor) -> torch.Tensor: - """Converts a code vector to a single index. - """ + """Converts a code vector to a single index.""" if codes.size(1) != self.dim: raise RuntimeError( f'Input code dimension {codes.size(1)} not matching the expected dimension {self.dim}, input codes shape {codes.shape}' @@ -575,8 +573,7 @@ def forward( output_types={"indices": NeuralType(('D', 'B', 'T'), Index())}, ) def encode(self, inputs: torch.Tensor, input_len: Optional[torch.Tensor] = None) -> torch.Tensor: - """Convert a continuous code vector to a single index. - """ + """Convert a continuous code vector to a single index.""" _, indices = self(inputs=inputs, input_len=input_len) return indices @@ -585,11 +582,12 @@ def encode(self, inputs: torch.Tensor, input_len: Optional[torch.Tensor] = None) "indices": NeuralType(('D', 'B', 'T'), Index()), "input_len": NeuralType(tuple('B'), LengthsType(), optional=True), }, - output_types={"dequantized": NeuralType(('B', 'D', 'T'), EncodedRepresentation()),}, + output_types={ + "dequantized": NeuralType(('B', 'D', 'T'), EncodedRepresentation()), + }, ) def decode(self, indices: torch.Tensor, input_len: Optional[torch.Tensor] = None) -> torch.Tensor: - """Convert a single index to a continuous code vector. - """ + """Convert a single index to a continuous code vector.""" if indices.size(0) > 1: # codebook dimension used for compatibility with RVQ raise ValueError( @@ -642,8 +640,7 @@ def __init__(self, num_groups: int, num_levels_per_group: List[int], **kwargs): @property def codebook_dim(self): - """Input vector dimension. - """ + """Input vector dimension.""" return self.codebook_dim_per_group * self.num_groups @property @@ -654,12 +651,11 @@ def codebook_size_per_group(self): @property def codebook_size(self): """Returns the size of the implicit codebook.""" - return self.codebook_size_per_group ** self.num_groups + return self.codebook_size_per_group**self.num_groups @typecheck() def forward(self, inputs, input_len): - """Quantize each group separately, then concatenate the results. 
- """ + """Quantize each group separately, then concatenate the results.""" inputs_grouped = inputs.chunk(self.num_groups, dim=1) dequantized, indices = [], [] @@ -685,8 +681,7 @@ def forward(self, inputs, input_len): output_types={"indices": NeuralType(('D', 'B', 'T'), Index())}, ) def encode(self, inputs: torch.Tensor, input_len: torch.Tensor) -> torch.Tensor: - """Input is split into groups, each group is encoded separately, then the results are concatenated. - """ + """Input is split into groups, each group is encoded separately, then the results are concatenated.""" inputs_grouped = inputs.chunk(self.num_groups, dim=1) indices = [] @@ -704,11 +699,12 @@ def encode(self, inputs: torch.Tensor, input_len: torch.Tensor) -> torch.Tensor: "indices": NeuralType(('D', 'B', 'T'), Index()), "input_len": NeuralType(tuple('B'), LengthsType()), }, - output_types={"dequantized": NeuralType(('B', 'D', 'T'), EncodedRepresentation()),}, + output_types={ + "dequantized": NeuralType(('B', 'D', 'T'), EncodedRepresentation()), + }, ) def decode(self, indices: torch.Tensor, input_len: torch.Tensor) -> torch.Tensor: - """Input indices are split into groups, each group is decoded separately, then the results are concatenated. - """ + """Input indices are split into groups, each group is decoded separately, then the results are concatenated.""" indices_grouped = indices.chunk(self.num_groups, dim=0) dequantized = [] diff --git a/nemo/collections/tts/modules/encodec_modules.py b/nemo/collections/tts/modules/encodec_modules.py index e93c7c799550..e9a1556ab700 100644 --- a/nemo/collections/tts/modules/encodec_modules.py +++ b/nemo/collections/tts/modules/encodec_modules.py @@ -43,6 +43,7 @@ from einops import rearrange, repeat from torch import Tensor +from nemo.collections.common.parts.utils import mask_sequence_tensor from nemo.collections.tts.losses.audio_codec_loss import MaskedMSELoss from nemo.collections.tts.modules.audio_codec_modules import ( CodecActivation, @@ -53,7 +54,6 @@ get_down_sample_padding, ) from nemo.collections.tts.parts.utils.distributed import broadcast_tensors -from nemo.collections.tts.parts.utils.helpers import mask_sequence_tensor from nemo.core.classes.common import typecheck from nemo.core.classes.module import NeuralModule from nemo.core.neural_types.elements import AudioSignal, EncodedRepresentation, Index, LengthsType, LossType, VoidType @@ -266,7 +266,10 @@ def __init__( out_channels = in_channels // 2 kernel_size = 2 * up_sample_rate up_sample_conv = ConvTranspose1dNorm( - in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=up_sample_rate, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=up_sample_rate, ) in_channels = out_channels self.up_sample_conv_layers.append(up_sample_conv) @@ -681,7 +684,10 @@ def encode(self, inputs, input_len): return indices @typecheck( - input_types={"indices": NeuralType(('B', 'T'), Index()), "input_len": NeuralType(tuple('B'), LengthsType()),}, + input_types={ + "indices": NeuralType(('B', 'T'), Index()), + "input_len": NeuralType(tuple('B'), LengthsType()), + }, output_types={"dequantized": NeuralType(('B', 'T', 'D'), EncodedRepresentation())}, ) def decode(self, indices, input_len): @@ -801,7 +807,9 @@ def encode(self, inputs: Tensor, input_len: Tensor) -> Tensor: "indices": NeuralType(('D', 'B', 'T'), Index()), "input_len": NeuralType(tuple('B'), LengthsType()), }, - output_types={"dequantized": NeuralType(('B', 'D', 'T'), EncodedRepresentation()),}, + output_types={ + 
"dequantized": NeuralType(('B', 'D', 'T'), EncodedRepresentation()), + }, ) def decode(self, indices: Tensor, input_len: Tensor) -> Tensor: # [B, T, D] @@ -852,8 +860,7 @@ def __init__(self, num_codebooks: int, num_groups: int, codebook_dim: int, **kwa @property def num_codebooks_per_group(self): - """Number of codebooks for each group. - """ + """Number of codebooks for each group.""" if self.num_codebooks % self.num_groups != 0: raise ValueError( f'num_codebooks ({self.num_codebooks}) must be divisible by num_groups ({self.num_groups})' @@ -863,8 +870,7 @@ def num_codebooks_per_group(self): @property def codebook_dim_per_group(self): - """Input vector dimension for each group. - """ + """Input vector dimension for each group.""" if self.codebook_dim % self.num_groups != 0: raise ValueError(f'codebook_dim ({self.codebook_dim}) must be divisible by num_groups ({self.num_groups})') @@ -881,8 +887,7 @@ def output_types(self): @typecheck() def forward(self, inputs, input_len): - """Quantize each group separately, then concatenate the results. - """ + """Quantize each group separately, then concatenate the results.""" inputs_grouped = inputs.chunk(self.num_groups, dim=1) dequantized, indices = [], [] @@ -910,8 +915,7 @@ def forward(self, inputs, input_len): output_types={"indices": NeuralType(('D', 'B', 'T'), Index())}, ) def encode(self, inputs: Tensor, input_len: Tensor) -> Tensor: - """Input is split into groups, each group is encoded separately, then the results are concatenated. - """ + """Input is split into groups, each group is encoded separately, then the results are concatenated.""" inputs_grouped = inputs.chunk(self.num_groups, dim=1) indices = [] @@ -929,11 +933,12 @@ def encode(self, inputs: Tensor, input_len: Tensor) -> Tensor: "indices": NeuralType(('D', 'B', 'T'), Index()), "input_len": NeuralType(tuple('B'), LengthsType()), }, - output_types={"dequantized": NeuralType(('B', 'D', 'T'), EncodedRepresentation()),}, + output_types={ + "dequantized": NeuralType(('B', 'D', 'T'), EncodedRepresentation()), + }, ) def decode(self, indices: Tensor, input_len: Tensor) -> Tensor: - """Input indices are split into groups, each group is decoded separately, then the results are concatenated. 
- """ + """Input indices are split into groups, each group is decoded separately, then the results are concatenated.""" indices_grouped = indices.chunk(self.num_groups, dim=0) dequantized = [] diff --git a/nemo/collections/tts/modules/spectrogram_enhancer.py b/nemo/collections/tts/modules/spectrogram_enhancer.py index 2cc88264a7d2..20866363d869 100644 --- a/nemo/collections/tts/modules/spectrogram_enhancer.py +++ b/nemo/collections/tts/modules/spectrogram_enhancer.py @@ -46,7 +46,7 @@ from einops import rearrange from kornia.filters import filter2d -from nemo.collections.tts.parts.utils.helpers import mask_sequence_tensor +from nemo.collections.common.parts.utils import mask_sequence_tensor class Blur(torch.nn.Module): @@ -99,7 +99,10 @@ def __init__(self, latent_dim, input_channel, upsample, channels=3): self.conv = Conv2DModulated(input_channel, out_filters, 1, demod=False) self.upsample = ( - torch.nn.Sequential(torch.nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False), Blur(),) + torch.nn.Sequential( + torch.nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False), + Blur(), + ) if upsample else None ) @@ -125,7 +128,15 @@ class Conv2DModulated(torch.nn.Module): """ def __init__( - self, in_chan, out_chan, kernel, demod=True, stride=1, dilation=1, eps=1e-8, **kwargs, + self, + in_chan, + out_chan, + kernel, + demod=True, + stride=1, + dilation=1, + eps=1e-8, + **kwargs, ): super().__init__() self.filters = out_chan @@ -148,7 +159,7 @@ def forward(self, x, y): weights = w2 * (w1 + 1) if self.demod: - d = torch.rsqrt((weights ** 2).sum(dim=(2, 3, 4), keepdim=True) + self.eps) + d = torch.rsqrt((weights**2).sum(dim=(2, 3, 4), keepdim=True) + self.eps) weights = weights * d x = x.reshape(1, -1, h, w) @@ -165,7 +176,13 @@ def forward(self, x, y): class GeneratorBlock(torch.nn.Module): def __init__( - self, latent_dim, input_channels, filters, upsample=True, upsample_rgb=True, channels=1, + self, + latent_dim, + input_channels, + filters, + upsample=True, + upsample_rgb=True, + channels=1, ): super().__init__() self.upsample = torch.nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False) if upsample else None @@ -257,7 +274,12 @@ def __init__( not_last = ind != (self.num_layers - 1) block = GeneratorBlock( - latent_dim, in_chan, out_chan, upsample=not_first, upsample_rgb=not_last, channels=channels, + latent_dim, + in_chan, + out_chan, + upsample=not_first, + upsample_rgb=not_last, + channels=channels, ) self.blocks.append(block) @@ -315,14 +337,18 @@ def forward(self, condition: torch.Tensor, lengths: torch.Tensor, ws: List[torch class Discriminator(torch.nn.Module): def __init__( - self, n_bands, network_capacity=16, channels=1, fmap_max=512, + self, + n_bands, + network_capacity=16, + channels=1, + fmap_max=512, ): super().__init__() num_layers = int(log2(n_bands) - 1) num_init_filters = channels blocks = [] - filters = [num_init_filters] + [(network_capacity * 4) * (2 ** i) for i in range(num_layers + 1)] + filters = [num_init_filters] + [(network_capacity * 4) * (2**i) for i in range(num_layers + 1)] set_fmap_max = partial(min, fmap_max) filters = list(map(set_fmap_max, filters)) diff --git a/nemo/collections/tts/parts/utils/helpers.py b/nemo/collections/tts/parts/utils/helpers.py index 08d31390107b..a4c65f9ed0e5 100644 --- a/nemo/collections/tts/parts/utils/helpers.py +++ b/nemo/collections/tts/parts/utils/helpers.py @@ -123,23 +123,26 @@ def binarize_attention(attn, in_len, out_len): def binarize_attention_parallel(attn, in_lens, out_lens): """For 
training purposes only. Binarizes attention with MAS. - These will no longer receive a gradient. + These will no longer receive a gradient. - Args: - attn: B x 1 x max_mel_len x max_text_len - """ + Args: + attn: B x 1 x max_mel_len x max_text_len + """ with torch.no_grad(): log_attn_cpu = torch.log(attn.data).cpu().numpy() attn_out = b_mas(log_attn_cpu, in_lens.cpu().numpy(), out_lens.cpu().numpy(), width=1) return torch.from_numpy(attn_out).to(attn.device) -def get_mask_from_lengths(lengths: Optional[torch.Tensor] = None, x: Optional[torch.Tensor] = None,) -> torch.Tensor: +def get_mask_from_lengths( + lengths: Optional[torch.Tensor] = None, + x: Optional[torch.Tensor] = None, +) -> torch.Tensor: """Constructs binary mask from a 1D torch tensor of input lengths Args: lengths: Optional[torch.tensor] (torch.tensor): 1D tensor with lengths - x: Optional[torch.tensor] = tensor to be used on, last dimension is for mask + x: Optional[torch.tensor] = tensor to be used on, last dimension is for mask Returns: mask (torch.tensor): num_sequences x max_length binary tensor """ @@ -168,7 +171,7 @@ def sort_tensor( context: tensor sorted by lens along dimension dim lens_sorted: lens tensor, sorted ids_sorted: reorder ids to be used to restore original order - + """ lens_sorted, ids_sorted = torch.sort(lens, descending=descending) context = torch.index_select(context, dim, ids_sorted) @@ -177,13 +180,13 @@ def sort_tensor( def unsort_tensor(ordered: torch.Tensor, indices: torch.Tensor, dim: Optional[int] = 0) -> torch.Tensor: """Reverses the result of sort_tensor function: - o, _, ids = sort_tensor(x,l) + o, _, ids = sort_tensor(x,l) assert unsort_tensor(o,ids) == x Args: ordered: context tensor, sorted by lengths indices: torch.tensor: 1D tensor with 're-order' indices returned by sort_tensor Returns: - ordered tensor in original order (before calling sort_tensor) + ordered tensor in original order (before calling sort_tensor) """ return torch.index_select(ordered, dim, indices.argsort(0)) @@ -294,7 +297,7 @@ def log_audio_to_tb( log_mel = spect.data.cpu().numpy().T mel = np.exp(log_mel) magnitude = np.dot(mel, filterbank) * griffin_lim_mag_scale - audio = griffin_lim(magnitude.T ** griffin_lim_power) + audio = griffin_lim(magnitude.T**griffin_lim_power) swriter.add_audio(name, audio / max(np.abs(audio)), step, sample_rate=sr) @@ -317,10 +320,16 @@ def tacotron2_log_to_tb_func( _, spec_target, mel_postnet, gate, gate_target, alignments = tensors if log_images and step % log_images_freq == 0: swriter.add_image( - f"{tag}_alignment", plot_alignment_to_numpy(alignments[0].data.cpu().numpy().T), step, dataformats="HWC", + f"{tag}_alignment", + plot_alignment_to_numpy(alignments[0].data.cpu().numpy().T), + step, + dataformats="HWC", ) swriter.add_image( - f"{tag}_mel_target", plot_spectrogram_to_numpy(spec_target[0].data.cpu().numpy()), step, dataformats="HWC", + f"{tag}_mel_target", + plot_spectrogram_to_numpy(spec_target[0].data.cpu().numpy()), + step, + dataformats="HWC", ) swriter.add_image( f"{tag}_mel_predicted", @@ -330,7 +339,10 @@ def tacotron2_log_to_tb_func( ) swriter.add_image( f"{tag}_gate", - plot_gate_outputs_to_numpy(gate_target[0].data.cpu().numpy(), torch.sigmoid(gate[0]).data.cpu().numpy(),), + plot_gate_outputs_to_numpy( + gate_target[0].data.cpu().numpy(), + torch.sigmoid(gate[0]).data.cpu().numpy(), + ), step, dataformats="HWC", ) @@ -340,13 +352,13 @@ def tacotron2_log_to_tb_func( log_mel = mel_postnet[0].data.cpu().numpy().T mel = np.exp(log_mel) magnitude = np.dot(mel, filterbank) 
* griffin_lim_mag_scale - audio = griffin_lim(magnitude.T ** griffin_lim_power) + audio = griffin_lim(magnitude.T**griffin_lim_power) swriter.add_audio(f"audio/{tag}_predicted", audio / max(np.abs(audio)), step, sample_rate=sr) log_mel = spec_target[0].data.cpu().numpy().T mel = np.exp(log_mel) magnitude = np.dot(mel, filterbank) * griffin_lim_mag_scale - audio = griffin_lim(magnitude.T ** griffin_lim_power) + audio = griffin_lim(magnitude.T**griffin_lim_power) swriter.add_audio(f"audio/{tag}_target", audio / max(np.abs(audio)), step, sample_rate=sr) @@ -373,16 +385,26 @@ def tacotron2_log_to_wandb_func( specs = [] gates = [] alignments += [ - wandb.Image(plot_alignment_to_numpy(alignments[0].data.cpu().numpy().T), caption=f"{tag}_alignment",) + wandb.Image( + plot_alignment_to_numpy(alignments[0].data.cpu().numpy().T), + caption=f"{tag}_alignment", + ) ] alignments += [ - wandb.Image(plot_spectrogram_to_numpy(spec_target[0].data.cpu().numpy()), caption=f"{tag}_mel_target",), - wandb.Image(plot_spectrogram_to_numpy(mel_postnet[0].data.cpu().numpy()), caption=f"{tag}_mel_predicted",), + wandb.Image( + plot_spectrogram_to_numpy(spec_target[0].data.cpu().numpy()), + caption=f"{tag}_mel_target", + ), + wandb.Image( + plot_spectrogram_to_numpy(mel_postnet[0].data.cpu().numpy()), + caption=f"{tag}_mel_predicted", + ), ] gates += [ wandb.Image( plot_gate_outputs_to_numpy( - gate_target[0].data.cpu().numpy(), torch.sigmoid(gate[0]).data.cpu().numpy(), + gate_target[0].data.cpu().numpy(), + torch.sigmoid(gate[0]).data.cpu().numpy(), ), caption=f"{tag}_gate", ) @@ -396,16 +418,24 @@ def tacotron2_log_to_wandb_func( log_mel = mel_postnet[0].data.cpu().numpy().T mel = np.exp(log_mel) magnitude = np.dot(mel, filterbank) * griffin_lim_mag_scale - audio_pred = griffin_lim(magnitude.T ** griffin_lim_power) + audio_pred = griffin_lim(magnitude.T**griffin_lim_power) log_mel = spec_target[0].data.cpu().numpy().T mel = np.exp(log_mel) magnitude = np.dot(mel, filterbank) * griffin_lim_mag_scale - audio_true = griffin_lim(magnitude.T ** griffin_lim_power) + audio_true = griffin_lim(magnitude.T**griffin_lim_power) audios += [ - wandb.Audio(audio_true / max(np.abs(audio_true)), caption=f"{tag}_wav_target", sample_rate=sr,), - wandb.Audio(audio_pred / max(np.abs(audio_pred)), caption=f"{tag}_wav_predicted", sample_rate=sr,), + wandb.Audio( + audio_true / max(np.abs(audio_true)), + caption=f"{tag}_wav_target", + sample_rate=sr, + ), + wandb.Audio( + audio_pred / max(np.abs(audio_pred)), + caption=f"{tag}_wav_predicted", + sample_rate=sr, + ), ] swriter.log({"audios": audios}) @@ -505,10 +535,22 @@ def create_plot(data, x_axis, y_axis, output_filepath=None): def plot_gate_outputs_to_numpy(gate_targets, gate_outputs): fig, ax = plt.subplots(figsize=(12, 3)) ax.scatter( - range(len(gate_targets)), gate_targets, alpha=0.5, color='green', marker='+', s=1, label='target', + range(len(gate_targets)), + gate_targets, + alpha=0.5, + color='green', + marker='+', + s=1, + label='target', ) ax.scatter( - range(len(gate_outputs)), gate_outputs, alpha=0.5, color='red', marker='.', s=1, label='predicted', + range(len(gate_outputs)), + gate_outputs, + alpha=0.5, + color='red', + marker='.', + s=1, + label='predicted', ) plt.xlabel("Frames (Green target, Red predicted)") @@ -530,24 +572,40 @@ def save_figure_to_numpy(fig): @rank_zero_only def waveglow_log_to_tb_func( - swriter, tensors, step, tag="train", n_fft=1024, hop_length=256, window="hann", mel_fb=None, + swriter, + tensors, + step, + tag="train", + n_fft=1024, + 
hop_length=256, + window="hann", + mel_fb=None, ): _, audio_pred, spec_target, mel_length = tensors mel_length = mel_length[0] spec_target = spec_target[0].data.cpu().numpy()[:, :mel_length] swriter.add_image( - f"{tag}_mel_target", plot_spectrogram_to_numpy(spec_target), step, dataformats="HWC", + f"{tag}_mel_target", + plot_spectrogram_to_numpy(spec_target), + step, + dataformats="HWC", ) if mel_fb is not None: mag, _ = librosa.core.magphase( librosa.core.stft( - np.nan_to_num(audio_pred[0].cpu().detach().numpy()), n_fft=n_fft, hop_length=hop_length, window=window, + np.nan_to_num(audio_pred[0].cpu().detach().numpy()), + n_fft=n_fft, + hop_length=hop_length, + window=window, ) ) mel_pred = np.matmul(mel_fb.cpu().numpy(), mag).squeeze() log_mel_pred = np.log(np.clip(mel_pred, a_min=1e-5, a_max=None)) swriter.add_image( - f"{tag}_mel_predicted", plot_spectrogram_to_numpy(log_mel_pred[:, :mel_length]), step, dataformats="HWC", + f"{tag}_mel_predicted", + plot_spectrogram_to_numpy(log_mel_pred[:, :mel_length]), + step, + dataformats="HWC", ) @@ -560,7 +618,12 @@ def remove(conv_list): def regulate_len( - durations, enc_out, pace=1.0, mel_max_len=None, group_size=1, dur_lens: torch.tensor = None, + durations, + enc_out, + pace=1.0, + mel_max_len=None, + group_size=1, + dur_lens: torch.tensor = None, ): """A function that takes predicted durations per encoded token, and repeats enc_out according to the duration. NOTE: durations.shape[1] == enc_out.shape[1] @@ -724,30 +787,6 @@ def to_device_recursive(e, device: torch.device): return e -def mask_sequence_tensor(tensor: torch.Tensor, lengths: torch.Tensor): - """ - For tensors containing sequences, zero out out-of-bound elements given lengths of every element in the batch. - - tensor: tensor of shape (B, D, L) or (B, D1, D2, L), - lengths: LongTensor of shape (B,) - """ - batch_size, *_, max_lengths = tensor.shape - - if len(tensor.shape) == 2: - mask = torch.ones(batch_size, max_lengths).cumsum(dim=-1).type_as(lengths) - mask = mask <= rearrange(lengths, "b -> b 1") - elif len(tensor.shape) == 3: - mask = torch.ones(batch_size, 1, max_lengths).cumsum(dim=-1).type_as(lengths) - mask = mask <= rearrange(lengths, "b -> b 1 1") - elif len(tensor.shape) == 4: - mask = torch.ones(batch_size, 1, 1, max_lengths).cumsum(dim=-1).type_as(lengths) - mask = mask <= rearrange(lengths, "b -> b 1 1 1") - else: - raise ValueError("Can only mask tensors of shape B x D x L and B x D1 x D2 x L") - - return tensor * mask - - @torch.jit.script def batch_from_ragged( text: torch.Tensor, @@ -786,13 +825,16 @@ def batch_from_ragged( def sample_tts_input( - export_config, device, max_batch=1, max_dim=127, + export_config, + device, + max_batch=1, + max_dim=127, ): """ - Generates input examples for tracing etc. - Returns: - A tuple of input examples. - """ + Generates input examples for tracing etc. + Returns: + A tuple of input examples. 
+ """ sz = (max_batch * max_dim,) if export_config["enable_ragged_batches"] else (max_batch, max_dim) inp = torch.randint(*export_config["emb_range"], sz, device=device, dtype=torch.int64) pitch = torch.randn(sz, device=device, dtype=torch.float32) * 0.5 diff --git a/nemo/collections/vision/models/megatron_vit_classification_models.py b/nemo/collections/vision/models/megatron_vit_classification_models.py index 46788d2c882c..5cffdd6d12a3 100644 --- a/nemo/collections/vision/models/megatron_vit_classification_models.py +++ b/nemo/collections/vision/models/megatron_vit_classification_models.py @@ -40,15 +40,6 @@ from nemo.core.classes.common import PretrainedModelInfo from nemo.utils import logging -try: - from apex.transformer.pipeline_parallel.utils import get_num_microbatches - - HAVE_APEX = True - -except (ImportError, ModuleNotFoundError): - - HAVE_APEX = False - try: from megatron.core import parallel_state from megatron.core.pipeline_parallel.schedules import get_forward_backward_func @@ -59,6 +50,13 @@ HAVE_MEGATRON_CORE = False +try: + from megatron.core.num_microbatches_calculator import get_num_microbatches + +except (ImportError, ModuleNotFoundError): + logging.warning("Megatron num_microbatches_calculator not found, using Apex version.") + from apex.transformer.pipeline_parallel.utils import get_num_microbatches + class VitClassificationModel(MegatronModule): """Vision Transformer Model.""" @@ -113,10 +111,6 @@ class MegatronVitClassificationModel(MegatronBaseModel): """Megatron Vision Transformer Model.""" def __init__(self, cfg: DictConfig, trainer: Trainer): - if not HAVE_APEX: - raise ImportError( - "Apex was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt." - ) if not HAVE_MEGATRON_CORE: raise ImportError( "megatron-core was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt." @@ -286,7 +280,10 @@ def fwd_bwd_step(self, dataloader_iter, forward_only): grad_sync_func = None param_sync_func = None if not forward_only and self.with_distributed_adam: - no_sync_func = partial(self._optimizer.no_sync, greedy_grad_copy=self.megatron_amp_O2,) + no_sync_func = partial( + self._optimizer.no_sync, + greedy_grad_copy=self.megatron_amp_O2, + ) grad_sync_func = self.reduce_overlap_gradients param_sync_func = self.sync_overlap_parameters @@ -357,12 +354,12 @@ def initialize_ub_func(self): def training_step(self, dataloader_iter): """ - Our dataloaders produce a micro-batch and then we fetch - a number of microbatches depending on the global batch size and model parallel size - from the dataloader to produce a list of microbatches. - Batch should be a list of microbatches and those microbatches should on CPU. - Microbatches are then moved to GPU during the pipeline. - The list of microbatches is then piped through the pipeline using Apex fwd/bwd functions. + Our dataloaders produce a micro-batch and then we fetch + a number of microbatches depending on the global batch size and model parallel size + from the dataloader to produce a list of microbatches. + Batch should be a list of microbatches and those microbatches should on CPU. + Microbatches are then moved to GPU during the pipeline. + The list of microbatches is then piped through the pipeline using Apex fwd/bwd functions. """ # Initialize userbuffer communicators. 
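For context on the get_num_microbatches import swap above (megatron-core first, Apex as a fallback), the quantity it reports is simply how many micro-batches make up one global batch. A standalone sketch of that arithmetic, ignoring ramp-up batch sizes:

def num_microbatches(global_batch_size: int, micro_batch_size: int, data_parallel_size: int) -> int:
    # One global batch is split across data-parallel ranks, and each rank's share is consumed
    # micro_batch_size samples at a time by the pipeline schedule.
    assert global_batch_size % (micro_batch_size * data_parallel_size) == 0
    return global_batch_size // (micro_batch_size * data_parallel_size)


print(num_microbatches(global_batch_size=256, micro_batch_size=4, data_parallel_size=8))  # -> 8
print(num_microbatches(global_batch_size=4, micro_batch_size=4, data_parallel_size=1))    # -> 1, matching the
# generation-time reconfigure_num_microbatches_calculator() calls earlier in this diff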
if self.initialize_ub: @@ -425,20 +422,20 @@ def training_step(self, dataloader_iter): return loss_mean def backward(self, *args, **kwargs): - """ LightningModule hook to do backward. - We want this to do nothing since we run backward in the fwd/bwd functions from apex. - No need to call it here. + """LightningModule hook to do backward. + We want this to do nothing since we run backward in the fwd/bwd functions from apex. + No need to call it here. """ pass def optimizer_zero_grad(self, *args, **kwargs): - """ LightningModule hook to zero grad. - We want this to do nothing as we are zeroing grads during the training_step. + """LightningModule hook to zero grad. + We want this to do nothing as we are zeroing grads during the training_step. """ pass def _append_sequence_parallel_module_grads(self, module, grads): - """ Helper method for allreduce_sequence_parallel_gradients""" + """Helper method for allreduce_sequence_parallel_gradients""" for param in module.parameters(): sequence_parallel_param = getattr(param, 'sequence_parallel', False) @@ -450,9 +447,9 @@ def _append_sequence_parallel_module_grads(self, module, grads): grads.append(grad.data) def allreduce_sequence_parallel_gradients(self): - """ All-reduce layernorm parameters across model parallel nodes when sequence parallelism is used. - Modified from megatron-lm: - https://gitlab-master.nvidia.com/ADLR/megatron-lm/-/blob/3f91f09bb2ab32f9904b47f46f19d2fc3f518ed8/megatron/training.py#L425 + """All-reduce layernorm parameters across model parallel nodes when sequence parallelism is used. + Modified from megatron-lm: + https://gitlab-master.nvidia.com/ADLR/megatron-lm/-/blob/3f91f09bb2ab32f9904b47f46f19d2fc3f518ed8/megatron/training.py#L425 """ grads = [] @@ -512,10 +509,10 @@ def fwd_output_only_func(batch, model): def validation_step(self, dataloader_iter): """ - Our dataloaders produce a micro-batch and then we fetch - a number of microbatches depending on the global batch size and model parallel size - from the dataloader to produce a list of microbatches. - The list of microbatches is then piped through the pipeline using megatron-core fwd/bwd functions. + Our dataloaders produce a micro-batch and then we fetch + a number of microbatches depending on the global batch size and model parallel size + from the dataloader to produce a list of microbatches. + The list of microbatches is then piped through the pipeline using megatron-core fwd/bwd functions. 
""" mode = 'test' if self.trainer.testing else 'val' @@ -525,8 +522,10 @@ def validation_step(self, dataloader_iter): loss, accuracy = self.fwd_bwd_step(dataloader_iter, True) - self.validation_step_outputs.append((loss, accuracy)) if mode == 'val' else self.test_step_outputs.append( - (loss, accuracy) + ( + self.validation_step_outputs.append((loss, accuracy)) + if mode == 'val' + else self.test_step_outputs.append((loss, accuracy)) ) return loss, accuracy @@ -569,7 +568,9 @@ def build_train_valid_test_datasets(self): raise ValueError("limit_val_batches must be an integer or float less than or equal to 1.0.") self._train_ds, self._validation_ds = build_train_valid_datasets( - model_cfg=self.cfg, data_path=self.cfg.data.data_path, image_size=(self.cfg.img_h, self.cfg.img_w), + model_cfg=self.cfg, + data_path=self.cfg.data.data_path, + image_size=(self.cfg.img_h, self.cfg.img_w), ) self._test_ds = None @@ -709,16 +710,16 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int] raise NotImplementedError def transfer_batch_to_device(self, batch: Any, device: torch.device, dataloader_idx: int) -> Any: - """ PTL hook: https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#transfer-batch-to-device - When using pipeline parallelism, we need the global batch to remain on the CPU, - since the memory overhead will be too high when using a large number of microbatches. - Microbatches are transferred from CPU to GPU inside the pipeline. + """PTL hook: https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#transfer-batch-to-device + When using pipeline parallelism, we need the global batch to remain on the CPU, + since the memory overhead will be too high when using a large number of microbatches. + Microbatches are transferred from CPU to GPU inside the pipeline. """ return batch def _validate_trainer(self): - """ Certain trainer configurations can break training. - Here we try to catch them and raise an error. + """Certain trainer configurations can break training. + Here we try to catch them and raise an error. """ if self.trainer.accumulate_grad_batches > 1: raise ValueError( diff --git a/nemo/core/classes/mixins/adapter_mixins.py b/nemo/core/classes/mixins/adapter_mixins.py index 05ac9b429d85..7b5d02c86bf7 100644 --- a/nemo/core/classes/mixins/adapter_mixins.py +++ b/nemo/core/classes/mixins/adapter_mixins.py @@ -391,6 +391,14 @@ def get_adapter_module(self, name: str): return self.adapter_layer[name] if name in self.adapter_layer else None return None + def get_adapter_cfg(self, name: str): + """Same logic as `get_adapter_module` but to get the config""" + _, name = self.resolve_adapter_module_name_(name) + + if hasattr(self, "adapter_cfg"): + return self.adapter_cfg[name] if name in self.adapter_cfg else None + return None + def set_accepted_adapter_types(self, adapter_types: List[Union[type, str]]) -> None: """ The module with this mixin can define a list of adapter names that it will accept. 
diff --git a/nemo/core/classes/mixins/hf_io_mixin.py b/nemo/core/classes/mixins/hf_io_mixin.py index b101cbabe749..543d6c6fccda 100644 --- a/nemo/core/classes/mixins/hf_io_mixin.py +++ b/nemo/core/classes/mixins/hf_io_mixin.py @@ -14,9 +14,9 @@ from abc import ABC from pathlib import Path -from typing import Dict, Iterable, List, Optional, Union +from typing import Any, Dict, Iterable, List, Optional, Union -from huggingface_hub import HfApi, ModelCard, ModelCardData, ModelFilter +from huggingface_hub import HfApi, ModelCard, ModelCardData from huggingface_hub import get_token as get_hf_token from huggingface_hub.hf_api import ModelInfo from huggingface_hub.utils import SoftTemporaryDirectory @@ -35,31 +35,35 @@ class HuggingFaceFileIO(ABC): """ @classmethod - def get_hf_model_filter(cls) -> ModelFilter: + def get_hf_model_filter(cls) -> Dict[str, Any]: """ Generates a filter for HuggingFace models. - Additionally includes default values of some metadata about results returned by the Hub. + Additionaly includes default values of some metadata about results returned by the Hub. Metadata: resolve_card_info: Bool flag, if set, returns the model card metadata. Default: False. limit_results: Optional int, limits the number of results returned. Returns: - A Hugging Face Hub ModelFilter object. + A dict representing the arguments passable to huggingface list_models(). """ - model_filter = ModelFilter(library='nemo') - - # Attach some additional info - model_filter.resolve_card_info = False - model_filter.limit_results = None + model_filter = dict( + author=None, + library='nemo', + language=None, + model_name=None, + task=None, + tags=None, + limit=None, + full=None, + cardData=False, + ) return model_filter @classmethod - def search_huggingface_models( - cls, model_filter: Optional[Union[ModelFilter, List[ModelFilter]]] = None - ) -> List['ModelInfo']: + def search_huggingface_models(cls, model_filter: Optional[Dict[str, Any]] = None) -> Iterable['ModelInfo']: """ Should list all pre-trained models available via Hugging Face Hub. @@ -75,16 +79,16 @@ def search_huggingface_models( # You can replace with any subclass of ModelPT. from nemo.core import ModelPT - # Get default ModelFilter + # Get default filter dict filt = .get_hf_model_filter() # Make any modifications to the filter as necessary - filt.language = [...] - filt.task = ... - filt.tags = [...] + filt['language'] = [...] + filt['task'] = ... + filt['tags'] = [...] - # Add any metadata to the filter as needed - filt.limit_results = 5 + # Add any metadata to the filter as needed (kwargs to list_models) + filt['limit'] = 5 # Obtain model info model_infos = .search_huggingface_models(model_filter=filt) @@ -96,10 +100,9 @@ def search_huggingface_models( model = ModelPT.from_pretrained(card.modelId) Args: - model_filter: Optional ModelFilter or List[ModelFilter] (from Hugging Face Hub) + model_filter: Optional Dictionary (for Hugging Face Hub kwargs) that filters the returned list of compatible model cards, and selects all results from each filter. Users can then use `model_card.modelId` in `from_pretrained()` to restore a NeMo Model. - If no ModelFilter is provided, uses the classes default filter as defined by `get_hf_model_filter()`. Returns: A list of ModelInfo entries. 
@@ -108,23 +111,6 @@ def search_huggingface_models( if model_filter is None: model_filter = cls.get_hf_model_filter() - # If single model filter, wrap into list - if not isinstance(model_filter, Iterable): - model_filter = [model_filter] - - # Inject `nemo` library filter - for mfilter in model_filter: - if isinstance(mfilter.library, str) and mfilter.library != 'nemo': - logging.warning(f"Model filter's `library` tag updated be `nemo`. Original value: {mfilter.library}") - mfilter.library = "nemo" - - elif isinstance(mfilter, Iterable) and 'nemo' not in mfilter.library: - logging.warning( - f"Model filter's `library` list updated to include `nemo`. Original value: {mfilter.library}" - ) - mfilter.library = list(mfilter) - mfilter.library.append('nemo') - # Check if api token exists, use if it does hf_token = get_hf_token() @@ -134,24 +120,11 @@ def search_huggingface_models( # Setup extra arguments for model filtering all_results = [] # type: List[ModelInfo] - for mfilter in model_filter: - cardData = None - limit = None - - if hasattr(mfilter, 'resolve_card_info') and mfilter.resolve_card_info is True: - cardData = True - - if hasattr(mfilter, 'limit_results') and mfilter.limit_results is not None: - limit = mfilter.limit_results - - results = api.list_models( - filter=mfilter, token=hf_token, sort="lastModified", direction=-1, cardData=cardData, limit=limit, - ) # type: Iterable[ModelInfo] - - for result in results: - all_results.append(result) + results = api.list_models( + token=hf_token, sort="lastModified", direction=-1, **model_filter + ) # type: Iterable[ModelInfo] - return all_results + return results def push_to_hf_hub( self, @@ -284,7 +257,10 @@ def _get_hf_model_card(self, template: str, template_kwargs: Optional[Dict[str, A HuggingFace ModelCard object that can be converted to a model card string. """ card_data = ModelCardData( - library_name='nemo', tags=['pytorch', 'NeMo'], license='cc-by-4.0', ignore_metadata_errors=True, + library_name='nemo', + tags=['pytorch', 'NeMo'], + license='cc-by-4.0', + ignore_metadata_errors=True, ) if 'card_data' not in template_kwargs: diff --git a/nemo/core/classes/module.py b/nemo/core/classes/module.py index 2d7bd0179447..ef80467c8c7a 100644 --- a/nemo/core/classes/module.py +++ b/nemo/core/classes/module.py @@ -18,6 +18,7 @@ from torch.nn import Module from nemo.core.classes.common import FileIO, Serialization, Typing +from nemo.utils import logging __all__ = ['NeuralModule'] @@ -54,39 +55,111 @@ def input_example(self, max_batch=None, max_dim=None): def freeze(self) -> None: r""" Freeze all params for inference. + + This method sets `requires_grad` to False for all parameters of the module. + It also stores the original `requires_grad` state of each parameter in a dictionary, + so that `unfreeze()` can restore the original state if `partial=True` is set in `unfreeze()`. """ - for param in self.parameters(): + grad_map = {} + + for pname, param in self.named_parameters(): + # Store the original grad state + grad_map[pname] = param.requires_grad + # Freeze the parameter param.requires_grad = False + # Store the frozen grad map + if not hasattr(self, '_frozen_grad_map'): + self._frozen_grad_map = grad_map + else: + self._frozen_grad_map.update(grad_map) + self.eval() - def unfreeze(self) -> None: + def unfreeze(self, partial: bool = False) -> None: """ Unfreeze all parameters for training. + + Allows for either total unfreeze or partial unfreeze (if the module was explicitly frozen previously with `freeze()`). 
+ The `partial` argument is used to determine whether to unfreeze all parameters or only the parameters that were + previously unfrozen prior `freeze()`. + + Example: + Consider a model that has an encoder and a decoder module. Assume we want the encoder to be frozen always. + + ```python + model.encoder.freeze() # Freezes all parameters in the encoder explicitly + ``` + + During inference, all parameters of the model should be frozen - we do this by calling the model's freeze method. + This step records that the encoder module parameters were already frozen, and so if partial unfreeze is called, + we should keep the encoder parameters frozen. + + ```python + model.freeze() # Freezes all parameters in the model; encoder remains frozen + ``` + + Now, during fine-tuning, we want to unfreeze the decoder but keep the encoder frozen. We can do this by calling + `unfreeze(partial=True)`. + + ```python + model.unfreeze(partial=True) # Unfreezes only the decoder; encoder remains frozen + ``` + + Args: + partial: If True, only unfreeze parameters that were previously frozen. If the parameter was already frozen + when calling `freeze()`, it will remain frozen after calling `unfreeze(partial=True)`. """ - for param in self.parameters(): - param.requires_grad = True + if partial and not hasattr(self, '_frozen_grad_map'): + raise ValueError("Cannot unfreeze partially without first freezing the module with `freeze()`") + + for pname, param in self.named_parameters(): + if not partial: + # Unfreeze all parameters + param.requires_grad = True + else: + # Unfreeze only parameters that were previously frozen + + # Check if the parameter was frozen + if pname in self._frozen_grad_map: + param.requires_grad = self._frozen_grad_map[pname] + else: + # Log a warning if the parameter was not found in the frozen grad map + logging.warning( + f"Parameter {pname} not found in list of previously frozen parameters. " + f"Unfreezing this parameter." + ) + param.requires_grad = True + + # Clean up the frozen grad map + if hasattr(self, '_frozen_grad_map'): + delattr(self, '_frozen_grad_map') self.train() @contextmanager def as_frozen(self): """ - Context manager which temporarily freezes a module, yields control and finally unfreezes the module. + Context manager which temporarily freezes a module, yields control and finally unfreezes the module partially + to return to original state. + + Allows for either total unfreeze or partial unfreeze (if the module was explicitly frozen previously with `freeze()`). + The `partial` argument is used to determine whether to unfreeze all parameters or only the parameters that were + previously unfrozen prior `freeze()`. 
+ + Example: + with model.as_frozen(): # by default, partial = True + # Do something with the model + pass + + # Model's parameters are now back to original state of requires_grad """ training_mode = self.training - grad_map = {} - for pname, param in self.named_parameters(): - grad_map[pname] = param.requires_grad - self.freeze() try: yield finally: - self.unfreeze() - - for pname, param in self.named_parameters(): - param.requires_grad = grad_map[pname] + self.unfreeze(partial=True) if training_mode: self.train() diff --git a/nemo/core/connectors/save_restore_connector.py b/nemo/core/connectors/save_restore_connector.py index 23b38510bb00..cd9971a9c383 100644 --- a/nemo/core/connectors/save_restore_connector.py +++ b/nemo/core/connectors/save_restore_connector.py @@ -19,7 +19,8 @@ import tarfile import tempfile import uuid -from typing import Optional, Set, Union +from contextlib import contextmanager +from typing import Callable, Generator, Optional, Set, Union import torch from omegaconf import DictConfig, OmegaConf @@ -141,9 +142,11 @@ def load_config_and_state_dict( else: # Extract the nemo file into the temporary directory - self._unpack_nemo_file( - path2file=restore_path, out_folder=tmpdir, extract_config_only=return_config is True - ) + filter_fn = None + if return_config: + filter_fn = lambda name: '.yaml' in name + members = self._filtered_tar_info(restore_path, filter_fn=filter_fn) + self._unpack_nemo_file(path2file=restore_path, out_folder=tmpdir, members=members) # Change current working directory to os.chdir(tmpdir) @@ -485,6 +488,29 @@ def _handle_artifacts(self, model, nemo_file_folder): # TODO: see cases when this can occur, and if we can fix them logging.warning("Model contains registered artifacts, but no restoration paths found") if len(tarfile_artifacts) > 0 and len(restoration_paths) > 0: + + def check_artifact_and_query_basename_match(query_path: str) -> bool: + for _, artiitem in tarfile_artifacts: + # Get basename and copy it to nemo_file_folder + if 'nemo:' in artiitem.path: + artifact_base_name = artiitem.path.split('nemo:')[1] + else: + artifact_base_name = os.path.basename(artiitem.path) + + if artifact_base_name == os.path.basename(query_path): + return True + return False + + artifact_rel_paths = {} + for path in restoration_paths: + if self.model_extracted_dir: + artifact_rel_paths[path] = self._filtered_recursive_walk( + path, filter_fn=check_artifact_and_query_basename_match + ) + else: + artifact_rel_paths[path] = self._filtered_tar_info( + path, filter_fn=check_artifact_and_query_basename_match + ) # Need to step into nemo archive to extract file # Get path where the command is executed - the artifacts will be "retrieved" there # (original .nemo behavior) @@ -493,13 +519,16 @@ def _handle_artifacts(self, model, nemo_file_folder): # TemporaryDirectory context must always be outer to try-catch chdir otherwise it crashes on Windows with tempfile.TemporaryDirectory() as archive_dir: try: - # unpack all restorations paths (nemo checkpoints) + # unpack artifacts from all restorations paths (nemo checkpoints) # in nemo checkpoints all resources contain hash in name, so there should be no collisions for path in restoration_paths: if self.model_extracted_dir: - shutil.copytree(src=path, dst=archive_dir, dirs_exist_ok=True) + for rel_path in artifact_rel_paths[path]: + shutil.copy2(src=rel_path, dst=archive_dir) else: - self._unpack_nemo_file(path2file=path, out_folder=archive_dir) + self._unpack_nemo_file( + path2file=path, out_folder=archive_dir, 
members=artifact_rel_paths[path] + ) os.chdir(archive_dir) for conf_path, artiitem in tarfile_artifacts: # Get basename and copy it to nemo_file_folder @@ -586,7 +615,36 @@ def _safe_extract(tar, out_folder: str, members=None): logging.warning(f"Skipping potentially unsafe member: {member.name}") @staticmethod - def _unpack_nemo_file(path2file: str, out_folder: str, extract_config_only: bool = False) -> str: + def _filtered_tar_info(tar_path: str, filter_fn: Optional[Callable[[str], bool]] = None) -> list[tarfile.TarInfo]: + """ + Returns the members of the tarball filtered by a function + """ + with SaveRestoreConnector._tar_open(tar_path) as tar: + members = tar.getmembers() + if filter_fn is None: + return members + + return [x for x in members if filter_fn(x.name)] + + @staticmethod + def _filtered_recursive_walk(path: str, filter_fn: Optional[Callable[[str], bool]] = None) -> list[str]: + """ + Returns the result of recursive walking a path and filtering each element + """ + if not os.path.isdir(path): + raise NotADirectoryError(f"Expected {path=} to be a directory") + + filtered_rel_paths = [] + for root, _, files in os.walk(path): + for f in files: + full_rel_path = os.path.join(root, f) + if filter_fn is None or filter_fn(full_rel_path): + filtered_rel_paths.append(full_rel_path) + return filtered_rel_paths + + @staticmethod + @contextmanager + def _tar_open(path2file: str) -> Generator[tarfile.TarFile, None, None]: if not os.path.exists(path2file): raise FileNotFoundError(f"{path2file} does not exist") @@ -599,13 +657,20 @@ def _unpack_nemo_file(path2file: str, out_folder: str, extract_config_only: bool except tarfile.ReadError: # can be older checkpoint => try compressed tar tar_header = "r:gz" + tar = tarfile.open(path2file, tar_header) - if not extract_config_only: - SaveRestoreConnector._safe_extract(tar, out_folder) - else: - members = [x for x in tar.getmembers() if ".yaml" in x.name] - SaveRestoreConnector._safe_extract(tar, out_folder, members) - tar.close() + try: + yield tar + finally: + tar.close() + + @staticmethod + def _unpack_nemo_file(path2file: str, out_folder: str, members: Optional[list[str]] = None) -> str: + with SaveRestoreConnector._tar_open(path2file) as tar: + if members is None: + SaveRestoreConnector._safe_extract(tar, out_folder) + else: + SaveRestoreConnector._safe_extract(tar, out_folder, members) return out_folder @staticmethod diff --git a/nemo/deploy/multimodal/query_multimodal.py b/nemo/deploy/multimodal/query_multimodal.py index 9f747ff6d306..1c01c6861048 100644 --- a/nemo/deploy/multimodal/query_multimodal.py +++ b/nemo/deploy/multimodal/query_multimodal.py @@ -13,7 +13,6 @@ # limitations under the License. 
import numpy as np -from decord import VideoReader from PIL import Image from nemo.deploy.utils import str_list2numpy @@ -24,6 +23,13 @@ except Exception: use_pytriton = False +try: + from decord import VideoReader +except Exception: + import logging + + logging.warning("The package `decord` was not installed in this environment.") + class NemoQueryMultimodal: """ @@ -56,12 +62,31 @@ def setup_media(self, input_media): vr = VideoReader(input_media) frames = [f.asnumpy() for f in vr] return np.array(frames) - elif self.model_type == "neva": + elif self.model_type == "lita" or self.model_type == "vita": + vr = VideoReader(input_media) + frames = [f.asnumpy() for f in vr] + subsample_len = self.frame_len(frames) + sub_frames = self.get_subsampled_frames(frames, subsample_len) + return np.array(sub_frames) + elif self.model_type == "neva" or self.model_type == "vila": media = Image.open(input_media).convert('RGB') return np.expand_dims(np.array(media), axis=0) else: raise RuntimeError(f"Invalid model type {self.model_type}") + def frame_len(self, frames): + max_frames = 256 + if len(frames) <= max_frames: + return len(frames) + else: + subsample = int(np.ceil(float(len(frames)) / max_frames)) + return int(np.round(float(len(frames)) / subsample)) + + def get_subsampled_frames(self, frames, subsample_len): + idx = np.round(np.linspace(0, len(frames) - 1, subsample_len)).astype(int) + sub_frames = [frames[i] for i in idx] + return sub_frames + def query( self, input_text, diff --git a/nemo/deploy/nlp/__init__.py b/nemo/deploy/nlp/__init__.py index a2110931c6df..5ebbe6816664 100644 --- a/nemo/deploy/nlp/__init__.py +++ b/nemo/deploy/nlp/__init__.py @@ -15,7 +15,7 @@ use_query_llm = True try: - from nemo.deploy.nlp.query_llm import NemoQueryLLM + from nemo.deploy.nlp.query_llm import NemoQueryLLM, NemoQueryLLMPyTorch except Exception: use_query_llm = False diff --git a/nemo/deploy/nlp/megatronllm_deployable.py b/nemo/deploy/nlp/megatronllm_deployable.py index c27bbbd0102b..1fe029f9fade 100644 --- a/nemo/deploy/nlp/megatronllm_deployable.py +++ b/nemo/deploy/nlp/megatronllm_deployable.py @@ -15,6 +15,7 @@ import logging from enum import IntEnum, auto from pathlib import Path +from typing import List import numpy as np import torch @@ -129,6 +130,12 @@ def _load_from_nemo_checkpoint(self, nemo_checkpoint_filepath: str, num_devices: nemo_checkpoint_filepath, trainer=trainer, return_config=True ) # transformer_engine should always be true according to EricH, but GPT-2B model will fail if it is enabled + if not custom_config.transformer_engine: + LOGGER.warning( + "MegatronLLMDeployable expects model config transformer_engine=True, but this model has it =False. " + "Overriding it to =True, but this may break certain checkpoints converted on older Nemo versions. " + "If your model breaks, please try re-converting the checkpoint on the current Nemo version." 
+ ) custom_config.transformer_engine = True # using multi-gpu for tensor parallelism directly for now, could do pipeline parallel instead or a combination custom_config.tensor_model_parallel_size = num_devices @@ -233,9 +240,7 @@ def _length_params_from_triton_inputs(**inputs: np.ndarray): length_params[length_param_field] = inputs.pop(length_param_field)[0][0] return length_params - @batch - def triton_infer_fn(self, **inputs: np.ndarray): - """Triton server inference function that actually runs the model""" + def generate(self, inputs: List[str], length_params: LengthParam, sampling_params: SamplingParam): if torch.distributed.is_initialized(): distributed_rank = torch.distributed.get_rank() if distributed_rank != 0: @@ -245,13 +250,16 @@ def triton_infer_fn(self, **inputs: np.ndarray): signal_value = ServerSync.SIGNAL.to_long_tensor() torch.distributed.broadcast(signal_value, 0) + return self.model.generate(inputs=inputs, length_params=length_params, sampling_params=sampling_params) + + @batch + def triton_infer_fn(self, **inputs: np.ndarray): + """Triton server inference function that actually runs the model""" input_strings = str_ndarray2list(inputs.pop("prompts")) sampling_params = self._sampling_params_from_triton_inputs(**inputs) length_params = self._length_params_from_triton_inputs(**inputs) - model_output = self.model.generate( - inputs=input_strings, length_params=length_params, sampling_params=sampling_params - ) + model_output = self.generate(input_strings, length_params, sampling_params) ''' model_output['sentences'] will be a list of strings (one per prompt) other fields will either be a list of lists (tokens, for example) diff --git a/nemo/deploy/nlp/query_llm.py b/nemo/deploy/nlp/query_llm.py index 940a927c7a54..7e873db6b5b1 100644 --- a/nemo/deploy/nlp/query_llm.py +++ b/nemo/deploy/nlp/query_llm.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import time from abc import ABC, abstractmethod import numpy as np @@ -30,23 +31,99 @@ def __init__(self, url, model_name): self.url = url self.model_name = model_name - @abstractmethod + +class NemoQueryLLMPyTorch(NemoQueryLLMBase): + """ + Sends a query to Triton for LLM inference + + Example: + from nemo.deploy import NemoTritonQueryLLMPyTorch + + nq = NemoTritonQueryLLMPyTorch(url="localhost", model_name="GPT-2B") + + prompts = ["hello, testing GPT inference", "another GPT inference test?"] + output = nq.query_llm( + prompts=prompts, + max_length=100, + top_k=1, + top_p=0.0, + temperature=0.0, + ) + print("prompts: ", prompts) + """ + + def __init__(self, url, model_name): + super().__init__( + url=url, + model_name=model_name, + ) + + # these arguments are explicitly defined in order to make it clear to user what they can pass + # names and optionality should exactly match the get_triton_input() results for MegatronGPTDeployable def query_llm( self, prompts, - stop_words_list=None, - bad_words_list=None, - no_repeat_ngram_size=None, - max_output_len=512, - top_k=1, - top_p=0.0, - temperature=1.0, - random_seed=None, - task_id=None, - lora_uids=None, + use_greedy: bool = None, + temperature: float = None, + top_k: int = None, + top_p: float = None, + repetition_penalty: float = None, + add_BOS: bool = None, + all_probs: bool = None, + compute_logprob: bool = None, + end_strings=None, + min_length: int = None, + max_length: int = None, init_timeout=60.0, ): - pass + """ + Query the Triton server synchronously and return a list of responses. 
+ + Args: + prompts (List(str)): list of sentences. + use_greedy (bool): use greedy sampling, effectively the same as top_k=1 + temperature (float): A parameter of the softmax function, which is the last layer in the network. + top_k (int): limits us to a certain number (K) of the top tokens to consider. + top_p (float): limits us to the top tokens within a certain probability mass (p). + repetition_penalty (float): penalty applied to repeated sequences, 1.0 means no penalty. + add_BOS (bool): whether or not to add a BOS (beginning of sentence) token. + all_probs (bool): when using compute_logprob, returns probabilities for all tokens in vocabulary. + compute_logprob (bool): get back probabilities of all tokens in the sequence. + end_strings (List(str)): list of strings which will terminate generation when they appear in the output. + min_length (int): min generated tokens. + max_length (int): max generated tokens. + init_timeout (flat): timeout for the connection. + """ + prompts = str_list2numpy(prompts) + inputs = { + "prompts": prompts, + } + if use_greedy is not None: + inputs["use_greedy"] = np.full(prompts.shape, use_greedy, dtype=np.bool_) + if temperature is not None: + inputs["temperature"] = np.full(prompts.shape, temperature, dtype=np.single) + if top_k is not None: + inputs["top_k"] = np.full(prompts.shape, top_k, dtype=np.int_) + if top_p is not None: + inputs["top_p"] = np.full(prompts.shape, top_p, dtype=np.single) + if repetition_penalty is not None: + inputs["repetition_penalty"] = np.full(prompts.shape, repetition_penalty, dtype=np.single) + if add_BOS is not None: + inputs["add_BOS"] = np.full(prompts.shape, add_BOS, dtype=np.bool_) + if all_probs is not None: + inputs["all_probs"] = np.full(prompts.shape, all_probs, dtype=np.bool_) + if compute_logprob is not None: + inputs["compute_logprob"] = np.full(prompts.shape, compute_logprob, dtype=np.bool_) + if end_strings is not None: + inputs["end_strings"] = str_list2numpy(end_strings) + if min_length is not None: + inputs["min_length"] = np.full(prompts.shape, min_length, dtype=np.int_) + if max_length is not None: + inputs["max_length"] = np.full(prompts.shape, max_length, dtype=np.int_) + + with ModelClient(self.url, self.model_name, init_timeout_s=init_timeout) as client: + result_dict = client.infer_batch(**inputs) + return result_dict class NemoQueryLLM(NemoQueryLLMBase): @@ -96,6 +173,7 @@ def query_llm( compute_logprob: bool = None, end_strings=None, init_timeout=60.0, + openai_format_response: bool = False, ): """ Query the Triton server synchronously and return a list of responses. @@ -183,7 +261,17 @@ def query_llm( return "Unknown output keyword." 
sentences = np.char.decode(output.astype("bytes"), "utf-8") - return sentences + if openai_format_response: + openai_response = { + "id": f"cmpl-{int(time.time())}", + "object": "text_completion", + "created": int(time.time()), + "model": self.model_name, + "choices": [{"text": str(sentences)}], + } + return openai_response + else: + return sentences else: return result_dict["outputs"] diff --git a/nemo/deploy/service/config.json b/nemo/deploy/service/config.json deleted file mode 100644 index d3b3440dd97b..000000000000 --- a/nemo/deploy/service/config.json +++ /dev/null @@ -1,5 +0,0 @@ -{ - "triton_service_port": 8000, - "triton_service_ip": "0.0.0.0", - "triton_request_timeout": 60 - } \ No newline at end of file diff --git a/nemo/deploy/service/rest_model_api.py b/nemo/deploy/service/rest_model_api.py index 5c49370fd45f..fbc774883faa 100644 --- a/nemo/deploy/service/rest_model_api.py +++ b/nemo/deploy/service/rest_model_api.py @@ -12,8 +12,9 @@ import json import os from pathlib import Path +import requests -from fastapi import FastAPI +from fastapi import FastAPI, HTTPException from pydantic import BaseModel from pydantic_settings import BaseSettings @@ -33,6 +34,7 @@ def __init__(self): self._triton_service_port = config_json["triton_service_port"] self._triton_service_ip = config_json["triton_service_ip"] self._triton_request_timeout = config_json["triton_request_timeout"] + self._openai_format_response = config_json["openai_format_response"] except Exception as error: print("An exception occurred:", error) return @@ -49,6 +51,14 @@ def triton_service_ip(self): def triton_request_timeout(self): return self._triton_request_timeout + @property + def openai_format_response(self): + """ + Retuns the response from Triton server in OpenAI compatible formar if set to True, + default set in config.json is false. + """ + return self._openai_format_response + app = FastAPI() triton_settings = TritonSettings() @@ -66,6 +76,23 @@ class CompletionRequest(BaseModel): frequency_penalty: float = 1.0 +@app.get("/triton_health") +async def check_triton_health(): + """ + This method exposes endpoint "/triton_health" which can be used to verify if Triton server is accessible while running the REST or FastAPI application. + Verify by running: curl http://service_http_address:service_port/triton_health and the returned status should inform if the server is accessible. 
+ """ + triton_url = f"triton_settings.triton_service_ip:str(triton_settings.triton_service_port)/v2/health/ready" + try: + response = requests.get(triton_url, timeout=5) + if response.status_code == 200: + return {"status": "Triton server is reachable and ready"} + else: + raise HTTPException(status_code=503, detail="Triton server is not ready") + except requests.RequestException as e: + raise HTTPException(status_code=503, detail=f"Cannot reach Triton server: {str(e)}") + + @app.post("/v1/completions/") def completions_v1(request: CompletionRequest): try: @@ -78,10 +105,14 @@ def completions_v1(request: CompletionRequest): top_p=request.top_p, temperature=request.temperature, init_timeout=triton_settings.triton_request_timeout, + openai_format_response=triton_settings.openai_format_response, ) - return { - "output": output[0][0], - } + if triton_settings.openai_format_response: + return output + else: + return { + "output": output[0][0], + } except Exception as error: print("An exception occurred:", error) return {"error": "An exception occurred"} diff --git a/nemo/export/multimodal/build.py b/nemo/export/multimodal/build.py index b21e5383b57f..03afec176325 100644 --- a/nemo/export/multimodal/build.py +++ b/nemo/export/multimodal/build.py @@ -17,6 +17,7 @@ import shutil import tarfile import tempfile +from pathlib import Path from time import time import tensorrt as trt @@ -37,7 +38,7 @@ def build_trtllm_engine( llm_checkpoint_path: str = None, model_type: str = "neva", llm_model_type: str = "llama", - tensor_parallel_size: int = 1, + tensor_parallelism_size: int = 1, max_input_len: int = 256, max_output_len: int = 256, max_batch_size: int = 1, @@ -45,10 +46,11 @@ def build_trtllm_engine( dtype: str = "bfloat16", ): trt_llm_exporter = TensorRTLLM(model_dir=model_dir, load_model=False) + visual_checkpoint_model = ['neva', 'lita', 'vila', 'vita'] trt_llm_exporter.export( - nemo_checkpoint_path=visual_checkpoint_path if model_type == "neva" else llm_checkpoint_path, + nemo_checkpoint_path=visual_checkpoint_path if model_type in visual_checkpoint_model else llm_checkpoint_path, model_type=llm_model_type, - tensor_parallel_size=tensor_parallel_size, + tensor_parallelism_size=tensor_parallelism_size, max_input_len=max_input_len, max_output_len=max_output_len, max_batch_size=max_batch_size, @@ -75,12 +77,24 @@ def export_visual_wrapper_onnx( def build_trt_engine( - model_type, input_sizes, output_dir, max_batch_size, dtype=torch.bfloat16, image_size=None, num_frames=None + model_type, + input_sizes, + output_dir, + vision_max_batch_size, + dtype=torch.bfloat16, + image_size=None, + num_frames=None, + nemo_config=None, ): part_name = 'visual_encoder' onnx_file = '%s/onnx/%s.onnx' % (output_dir, part_name) engine_file = '%s/%s.engine' % (output_dir, part_name) config_file = '%s/%s' % (output_dir, "config.json") + nemo_config_file = '%s/%s' % (output_dir, "nemo_config.yaml") + + with open(nemo_config_file, 'w') as f: + yaml.dump(nemo_config, f) + logger.log(trt.Logger.INFO, "Building TRT engine for %s" % part_name) builder = trt.Builder(logger) @@ -110,8 +124,8 @@ def build_trt_engine( nBS = -1 nMinBS = 1 - nOptBS = max(nMinBS, int(max_batch_size / 2)) - nMaxBS = max_batch_size + nOptBS = max(nMinBS, int(vision_max_batch_size / 2)) + nMaxBS = vision_max_batch_size inputT = network.get_input(0) @@ -145,17 +159,41 @@ def build_trt_engine( def build_neva_engine( + model_type: str, model_dir: str, visual_checkpoint_path: str, - max_batch_size: int = 1, + vision_max_batch_size: int = 1, ): device = 
torch.device("cuda") if torch.cuda.is_available() else "cpu" # extract NeMo checkpoint with tempfile.TemporaryDirectory() as temp: - mp0_weights, nemo_config, _ = load_nemo_model(visual_checkpoint_path, temp) + temp_path = Path(temp) + mp0_weights, nemo_config, _ = load_nemo_model(visual_checkpoint_path, temp_path) vision_config = nemo_config["mm_cfg"]["vision_encoder"] + class DownSampleBlock(torch.nn.Module): + def forward(self, x): + vit_embeds = x + h = w = int(vit_embeds.shape[1] ** 0.5) + vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1) + vit_embeds = self.flat_square(vit_embeds) + vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1]) + return vit_embeds + + def flat_square(self, x): + n, w, h, c = x.size() + if w % 2 == 1: + x = torch.cat([x, torch.zeros((n, 1, h, c), dtype=x.dtype).to(x.device)], dim=1).contiguous() + n, w, h, c = x.size() + if h % 2 == 1: + x = torch.cat([x, torch.zeros((n, w, 1, c), dtype=x.dtype).to(x.device)], dim=2).contiguous() + n, w, h, c = x.size() + x = x.view(n, w, int(h / 2), int(c * 2)) + x = x.permute(0, 2, 1, 3).contiguous() + x = x.view(n, int(h / 2), int(w / 2), int(c * 4)) + return x + class VisionEncoderWrapper(torch.nn.Module): def __init__(self, encoder, connector): @@ -166,7 +204,6 @@ def __init__(self, encoder, connector): def forward(self, images): vision_x = self.encoder(pixel_values=images, output_hidden_states=True) vision_x = vision_x.hidden_states[-2] - vision_x = vision_x[:, 1:] vision_x = self.connector(vision_x) return vision_x @@ -178,44 +215,82 @@ def forward(self, images): dtype = hf_config.torch_dtype # connector - assert nemo_config["mm_cfg"]["mm_mlp_adapter_type"] == "mlp2x_gelu" - vision_connector = torch.nn.Sequential( - torch.nn.Linear(vision_config["hidden_size"], nemo_config["hidden_size"], bias=True), - torch.nn.GELU(), - torch.nn.Linear(nemo_config["hidden_size"], nemo_config["hidden_size"], bias=True), - ).to(dtype=dtype) - - key_prefix = "model.embedding.word_embeddings.adapter_layer.mm_projector_adapter.mm_projector" - for layer in range(0, 3, 2): - vision_connector[layer].load_state_dict( + if nemo_config["mm_cfg"]["mm_mlp_adapter_type"] == "mlp2x_gelu": + vision_connector = torch.nn.Sequential( + torch.nn.Linear(vision_config["hidden_size"], nemo_config["hidden_size"], bias=True), + torch.nn.GELU(), + torch.nn.Linear(nemo_config["hidden_size"], nemo_config["hidden_size"], bias=True), + ).to(dtype=dtype) + + key_prefix = "model.embedding.word_embeddings.adapter_layer.mm_projector_adapter.mm_projector" + for layer in range(0, 3, 2): + vision_connector[layer].load_state_dict( + { + 'weight': mp0_weights[f"{key_prefix}.{layer}.weight"].to(dtype), + 'bias': mp0_weights[f"{key_prefix}.{layer}.bias"].to(dtype), + } + ) + elif nemo_config["mm_cfg"]["mm_mlp_adapter_type"] == "linear": + vision_connector = torch.nn.Linear(vision_config["hidden_size"], nemo_config["hidden_size"], bias=True) + key_prefix = "model.embedding.word_embeddings.adapter_layer.mm_projector_adapter.mm_projector" + vision_connector.load_state_dict( { - 'weight': mp0_weights[f"{key_prefix}.{layer}.weight"].to(dtype), - 'bias': mp0_weights[f"{key_prefix}.{layer}.bias"].to(dtype), + 'weight': mp0_weights[f"{key_prefix}.weight"].to(dtype), + 'bias': mp0_weights[f"{key_prefix}.bias"].to(dtype), } ) + elif nemo_config["mm_cfg"]["mm_mlp_adapter_type"] == "mlp_downsample": + vision_connector = torch.nn.Sequential( + DownSampleBlock(), + torch.nn.LayerNorm(vision_config["hidden_size"] * 4), + 
torch.nn.Linear(vision_config["hidden_size"] * 4, nemo_config["hidden_size"], bias=True), + torch.nn.GELU(), + torch.nn.Linear(nemo_config["hidden_size"], nemo_config["hidden_size"], bias=True), + ).to(dtype=dtype) + key_prefix = "model.embedding.word_embeddings.adapter_layer.mm_projector_adapter.mm_projector" + for layer in [1, 2, 4]: + vision_connector[layer].load_state_dict( + { + 'weight': mp0_weights[f"{key_prefix}.{layer}.weight"].to(dtype), + 'bias': mp0_weights[f"{key_prefix}.{layer}.bias"].to(dtype), + } + ) + + else: + raise ValueError(f"Unknown projector type: {nemo_config['mm_cfg']['mm_mlp_adapter_type']}") # export the whole wrapper + lita_num_frames = None wrapper = VisionEncoderWrapper(vision_encoder, vision_connector).to(device, dtype) - image_size = hf_config.vision_config.image_size + if model_type == "lita" or model_type == "vila": + image_size = hf_config.image_size + if model_type == "lita": + lita_num_frames = nemo_config['mm_cfg']['lita']['sample_frames'] + else: + image_size = hf_config.vision_config.image_size + if model_type == "vita": + lita_num_frames = nemo_config['mm_cfg']['lita']['sample_frames'] dummy_image = torch.empty( 1, 3, image_size, image_size, dtype=dtype, device=device ) # dummy image shape [B, C, H, W] export_visual_wrapper_onnx(wrapper, dummy_image, model_dir) build_trt_engine( - "neva", + model_type, [3, image_size, image_size], model_dir, - max_batch_size, + vision_max_batch_size, dtype, image_size=image_size, + num_frames=lita_num_frames if model_type == "lita" or model_type == 'vita' else None, + nemo_config=nemo_config, ) def build_video_neva_engine( model_dir: str, visual_checkpoint_path: str, - max_batch_size: int = 1, + vision_max_batch_size: int = 1, ): device = torch.device("cuda") if torch.cuda.is_available() else "cpu" # extract NeMo checkpoint @@ -279,7 +354,7 @@ def forward(self, images): "video-neva", [num_frames, 3, image_size, image_size], # [num_frames, 3, H, W] model_dir, - max_batch_size, + vision_max_batch_size, dtype, image_size=image_size, num_frames=num_frames, @@ -290,11 +365,12 @@ def build_visual_engine( model_dir: str, visual_checkpoint_path: str, model_type: str = "neva", - max_batch_size: int = 1, + vision_max_batch_size: int = 1, ): - if model_type == "neva": - build_neva_engine(model_dir, visual_checkpoint_path, max_batch_size) + model_list = ['neva', 'lita', 'vila', 'vita'] + if model_type in model_list: + build_neva_engine(model_type, model_dir, visual_checkpoint_path, vision_max_batch_size) elif model_type == "video-neva": - build_video_neva_engine(model_dir, visual_checkpoint_path, max_batch_size) + build_video_neva_engine(model_dir, visual_checkpoint_path, vision_max_batch_size) else: raise RuntimeError(f"Invalid model type {model_type}") diff --git a/nemo/export/multimodal/run.py b/nemo/export/multimodal/run.py index f94c2e3f3944..149df995c77a 100644 --- a/nemo/export/multimodal/run.py +++ b/nemo/export/multimodal/run.py @@ -16,17 +16,27 @@ import json import os +try: + import decord +except Exception: + import logging + + logging.warning("The package `decord` was not installed in this environment.") + +import einops import numpy as np import tensorrt as trt import tensorrt_llm import tensorrt_llm.profiler as profiler import torch +import yaml from PIL import Image from tensorrt_llm import logger from tensorrt_llm._utils import str_dtype_to_trt from tensorrt_llm.runtime import ModelRunner, Session, TensorInfo +from torch.nn import functional as F from torchvision import transforms -from transformers import 
CLIPImageProcessor +from transformers import AutoProcessor, CLIPImageProcessor def trt_dtype_to_torch(dtype): @@ -67,6 +77,8 @@ def __init__(self, visual_engine_dir, llm_engine_dir): self.init_image_encoder(visual_engine_dir) self.init_tokenizer(llm_engine_dir) self.init_llm(llm_engine_dir) + if self.model_type == 'lita' or self.model_type == 'vila' or self.model_type == 'vita': + self.init_vision_preprocessor(visual_engine_dir) def init_tokenizer(self, llm_engine_dir): if os.path.exists(os.path.join(llm_engine_dir, 'huggingface_tokenizer')): @@ -74,6 +86,11 @@ def init_tokenizer(self, llm_engine_dir): self.tokenizer = AutoTokenizer.from_pretrained(os.path.join(llm_engine_dir, 'huggingface_tokenizer')) self.tokenizer.pad_token = self.tokenizer.eos_token + if self.model_type == 'vita': + self.tokenizer.im_start_id = self.tokenizer.convert_tokens_to_ids("") + self.tokenizer.im_end_id = self.tokenizer.convert_tokens_to_ids("") + self.tokenizer.vid_start_id = self.tokenizer.convert_tokens_to_ids("") + self.tokenizer.vid_end_id = self.tokenizer.convert_tokens_to_ids("") else: from sentencepiece import SentencePieceProcessor @@ -115,6 +132,12 @@ def batch_decode(self, x, **kwargs): self.tokenizer.padding_side = "right" + if self.model_type == 'lita': + self.tokenizer.im_start_id = sp.piece_to_id("") + self.tokenizer.im_end_id = sp.piece_to_id("") + self.tokenizer.vid_start_id = sp.piece_to_id("") + self.tokenizer.vid_end_id = sp.piece_to_id("") + def init_image_encoder(self, visual_engine_dir): vision_encoder_path = os.path.join(visual_engine_dir, 'visual_encoder.engine') logger.info(f'Loading engine from {vision_encoder_path}') @@ -123,6 +146,25 @@ def init_image_encoder(self, visual_engine_dir): logger.info(f'Creating session from engine {vision_encoder_path}') self.visual_encoder_session = Session.from_serialized_engine(engine_buffer) + def init_vision_preprocessor(self, visual_encoder_dir): + with open(os.path.join(visual_encoder_dir, 'nemo_config.yaml'), 'r') as f: + self.nemo_config = yaml.safe_load(f) + + vision_config = self.nemo_config["mm_cfg"]["vision_encoder"] + + if self.model_type == 'lita': + self.image_processor = AutoProcessor.from_pretrained( + vision_config["from_pretrained"], torch_dtype=torch.bfloat16, trust_remote_code=True + ) + elif self.model_type == 'vila' or self.model_type == 'vita': + from transformers import SiglipImageProcessor + + self.image_processor = SiglipImageProcessor.from_pretrained( + vision_config["from_pretrained"], torch_dtype=torch.bfloat16, trust_remote_code=True + ) + else: + raise ValueError(f"Invalid model type: {self.model_type}") + def init_llm(self, llm_engine_dir): self.model = ModelRunner.from_dir( llm_engine_dir, rank=tensorrt_llm.mpi_rank(), debug_mode=False, stream=self.stream @@ -137,25 +179,25 @@ def video_preprocess(self, video_path): vr = VideoReader(video_path) num_frames = self.num_frames if num_frames == -1: - frames = [Image.fromarray(frame.asnumpy()[:, :, ::-1]).convert('RGB') for frame in vr] + frames = [Image.fromarray(frame.asnumpy()).convert('RGB') for frame in vr] else: # equally sliced frames into self.num_frames frames # if self.num_frames is greater than the number of frames in the video, we will repeat the last frame num_frames = min(num_frames, len(vr)) indices = np.linspace(0, len(vr) - 1, num=num_frames, dtype=int) - frames = [Image.fromarray(vr[idx].asnumpy()[:, :, ::-1]).convert('RGB') for idx in indices] + frames = [Image.fromarray(vr[idx].asnumpy()).convert('RGB') for idx in indices] if len(frames) < num_frames: 
frames += [frames[-1]] * (num_frames - len(frames)) elif isinstance(video_path, np.ndarray): num_frames = self.num_frames if num_frames == -1: - frames = [Image.fromarray(frame[:, :, ::-1]).convert('RGB') for frame in video_path] + frames = [Image.fromarray(frame).convert('RGB') for frame in video_path] else: # equally sliced frames into self.num_frames frames # if self.num_frames is greater than the number of frames in the video, we will repeat the last frame num_frames = min(num_frames, video_path.shape[0]) indices = np.linspace(0, video_path.shape[0] - 1, num=num_frames, dtype=int) - frames = [Image.fromarray(video_path[idx][:, :, ::-1]).convert('RGB') for idx in indices] + frames = [Image.fromarray(video_path[idx]).convert('RGB') for idx in indices] if len(frames) < num_frames: frames += [frames[-1]] * (num_frames - len(frames)) else: @@ -169,25 +211,105 @@ def video_preprocess(self, video_path): ) # [num_frames, 3, H, W] return media_tensors.unsqueeze(0) # [1, num_frames, 3, H, W] + def insert_tokens_by_index(self, input_ids, num_frames): + im_start_id = self.tokenizer.im_start_id + im_end_id = self.tokenizer.im_end_id + vid_start_id = self.tokenizer.vid_start_id + vid_end_id = self.tokenizer.vid_end_id + + image_token_indices = (input_ids == 0).nonzero(as_tuple=False).squeeze().tolist() + input_ids = input_ids.squeeze().tolist() + offset = 0 + + # Insert the image tokens and corresponding start/end tokens + for i in range(num_frames): + idx = image_token_indices[1] + offset + input_ids.insert(idx + 1, im_end_id) + input_ids.insert(idx + 1, 0) + input_ids.insert(idx + 1, im_start_id) + offset += 3 + + # Insert the video start and end tokens around the video token + vid_idx = image_token_indices[1] + offset + input_ids.insert(vid_idx + 1, vid_end_id) + input_ids.insert(vid_idx + 1, 0) + input_ids.insert(vid_idx + 1, vid_start_id) + + input_ids.pop(image_token_indices[1]) + input_ids = torch.tensor(input_ids, dtype=torch.long).unsqueeze(0) + + return input_ids + def preprocess(self, warmup, pre_prompt, post_prompt, image, attention_mask, batch_size): if not warmup: profiler.start("Vision") - visual_features, visual_atts = self.get_visual_features(image, attention_mask) - if not warmup: profiler.stop("Vision") - pre_input_ids = self.tokenizer(pre_prompt, return_tensors="pt", padding=True).input_ids - if post_prompt[0] is not None: - post_input_ids = self.tokenizer(post_prompt, return_tensors="pt", padding=True).input_ids - if self.model_type == 'video-neva': - length = pre_input_ids.shape[1] + post_input_ids.shape[1] + visual_atts.shape[2] * visual_atts.shape[1] + if self.model_type == 'vila': + visual_features, visual_atts = self.get_visual_features(image, attention_mask) + input_ids = self.tokenizer_image_token(batch_size, pre_prompt[0] + post_prompt[0], self.tokenizer) + batch_split_prompts = self.split_prompt_by_images(input_ids) + first_batch_split_prompts = batch_split_prompts[0] + # compute prompt length + visual length + length = sum([ids.shape[1] for ids in first_batch_split_prompts]) + if batch_size == 1 and len(image) > 1: + # mode 1: multiple image as a whole, flatten visual dims + length += visual_atts.shape[0] * visual_atts.shape[1] + else: + # mode 2: multiple images individually (replicate prompt for each image) + length += visual_atts.shape[1] + + input_lengths = torch.IntTensor([length] * batch_size).to(torch.int32) + input_ids, ptuning_args = self.setup_fake_prompts_vila( + batch_size, visual_features, first_batch_split_prompts, input_lengths + ) + return input_ids, 
input_lengths, ptuning_args, visual_features + + elif self.model_type == 'lita' or self.model_type == 'vita': + visual_input = [] + for i, img in enumerate(image): + visual_features, visual_atts = self.get_visual_features(img, attention_mask) + visual_features = visual_features.unsqueeze(0) + im_tokens, vid_tokens, num_sample_frames = self.preprocess_lita_visual(visual_features, self.nemo_config) + visual_input.extend([im_tokens, vid_tokens]) + + input_ids = self.tokenizer_image_token(batch_size, pre_prompt[0] + post_prompt[0], self.tokenizer) + input_ids = self.insert_tokens_by_index(input_ids, num_sample_frames) + batch_splits = self.split_prompt_by_images(input_ids) + first_batch_split_prompts = batch_splits[0] + length = sum([ids.shape[1] for ids in first_batch_split_prompts]) + + # Update visual atts shape to match im_tokens shape and vid_tokens shape + im_tokens = im_tokens.view(1, -1, im_tokens.shape[-1]) + visual_features = torch.cat([im_tokens, vid_tokens], dim=1) + visual_atts = torch.ones(visual_features.size()[:-1], dtype=torch.long).to(image.device) + + if batch_size == 1: + length += visual_atts.shape[0] * visual_atts.shape[1] else: - length = pre_input_ids.shape[1] + post_input_ids.shape[1] + visual_atts.shape[1] + raise ValueError("Batch size greater than 1 is not supported for LITA and VITA models") + + input_lengths = torch.IntTensor([length] * batch_size).to(torch.int32) + input_ids, ptuning_args = self.setup_fake_prompts_vila( + batch_size, visual_input, first_batch_split_prompts, input_lengths + ) + return input_ids, input_lengths, ptuning_args, visual_features else: - post_input_ids = None - length = pre_input_ids.shape[1] + visual_atts.shape[1] + visual_features, visual_atts = self.get_visual_features(image, attention_mask) + pre_input_ids = self.tokenizer(pre_prompt, return_tensors="pt", padding=True).input_ids + if post_prompt[0] is not None: + post_input_ids = self.tokenizer(post_prompt, return_tensors="pt", padding=True).input_ids + if self.model_type == 'video-neva': + length = ( + pre_input_ids.shape[1] + post_input_ids.shape[1] + visual_atts.shape[2] * visual_atts.shape[1] + ) + else: + length = pre_input_ids.shape[1] + post_input_ids.shape[1] + visual_atts.shape[1] + else: + post_input_ids = None + length = pre_input_ids.shape[1] + visual_atts.shape[1] input_lengths = torch.IntTensor([length] * batch_size).to(torch.int32) @@ -197,6 +319,48 @@ def preprocess(self, warmup, pre_prompt, post_prompt, image, attention_mask, bat return input_ids, input_lengths, ptuning_args, visual_features + @staticmethod + def tokenizer_image_token(batch_size, prompt, tokenizer, image_token_index=-200): + prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("")] + + def insert_separator(X, sep): + return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1] + + input_ids = [] + offset = 0 + if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id: + offset = 1 + input_ids.append(prompt_chunks[0][0]) + + for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)): + input_ids.extend(x[offset:]) + + input_ids = torch.tensor(input_ids, dtype=torch.long) + input_ids[input_ids == image_token_index] = 0 + input_ids = input_ids.unsqueeze(0).expand(batch_size, -1) + + return input_ids + + def split_prompt_by_images(self, tensor): + batch_splits = [] + for batch in tensor: + # Find indices where value is zero () + zero_indices = (batch == 0).nonzero(as_tuple=False).squeeze(0) + # Add starting 
point for slicing + start_idx = 0 + splits = [] + for idx in zero_indices: + if start_idx != idx: # Ensure not slicing zero-length tensors + splits.append(batch[start_idx:idx].unsqueeze(0)) + start_idx = idx + 1 # Move start index past the zero + if start_idx < len(batch): # Handle last segment if it's not zero-ending + splits.append(batch[start_idx:].unsqueeze(0)) + # Remove empty tensors resulting from consecutive zeros + splits = [split for split in splits if split.numel() > 0] + batch_splits.append(splits) + + return batch_splits + def generate( self, pre_prompt, @@ -313,8 +477,104 @@ def setup_fake_prompts(self, visual_features, pre_input_ids, post_input_ids, inp return input_ids, ptuning_args + def setup_fake_prompts_vila(self, batch_size, visual_features, split_input_ids, input_lengths): + + if self.model_type == 'lita' or self.model_type == 'vita': + squeeze_img_tokens = visual_features[0].squeeze(0) + reshape_img_tokens = [t.unsqueeze(0) for t in squeeze_img_tokens] + visual_features = reshape_img_tokens + [visual_features[1]] + + fake_prompt_counter = self.model_config.vocab_size + if batch_size == 1: + # only check for multi-image inference (mode 1) + assert len(visual_features) <= len( + split_input_ids + ), "Unexpected number of visual features. Please check # in prompt and the #image files." + + input_ids = [] + if batch_size == 1: + input_ids = [split_input_ids[0]] + + if self.model_type == 'vila': + # mode 1: multiple image as a whole, concat all prompts together,
...
+                for idx, visual_feature in enumerate(visual_features):
+                    fake_prompt_id = torch.arange(fake_prompt_counter, fake_prompt_counter + visual_feature.shape[0])
+                    fake_prompt_counter += visual_feature.shape[0]
+                    fake_prompt_id = fake_prompt_id.unsqueeze(0)
+                    input_ids.append(fake_prompt_id)
+
+                    # in case no post prompt
+                    if len(split_input_ids) > idx + 1:
+                        input_ids.append(split_input_ids[idx + 1])
+            elif self.model_type == 'lita' or self.model_type == 'vita':
+                for idx, visual_f in enumerate(visual_features):
+                    fake_prompt_id = torch.arange(fake_prompt_counter, fake_prompt_counter + visual_f.shape[1])
+                    fake_prompt_id = fake_prompt_id.reshape(visual_f.shape[1])
+                    fake_prompt_counter += visual_f.shape[1]
+                    fake_prompt_id = fake_prompt_id.unsqueeze(0)
+                    input_ids.append(fake_prompt_id)
+
+                    # in case no post prompt
+                    if len(split_input_ids) > idx + 1:
+                        input_ids.append(split_input_ids[idx + 1])
+
+        elif batch_size > 1 and self.model_type == 'vila':
+            # mode 2: each image has its own individual prompt

+            for idx, visual_feature in enumerate(visual_features):
+                input_ids.append(split_input_ids[0])
+                fake_prompt_id = torch.arange(fake_prompt_counter, fake_prompt_counter + visual_feature.shape[0])
+                fake_prompt_counter += visual_feature.shape[0]
+                fake_prompt_id = fake_prompt_id.unsqueeze(0)
+                input_ids.append(fake_prompt_id)
+                if len(split_input_ids) > 1:
+                    input_ids.append(split_input_ids[1])
+
+        input_ids = torch.cat(input_ids, dim=1).contiguous().to(torch.int32)
+        input_ids = input_ids.reshape(batch_size, -1)
+        ptuning_args = self.ptuning_setup(visual_features, input_ids, input_lengths)
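+        # Illustrative layout of the packed sequence (single image, vocab size V, N visual tokens):
+        #   input_ids = [prompt chunk 0][V, V+1, ..., V+N-1][prompt chunk 1]
+        # The out-of-vocabulary ids index rows of the prompt table assembled by
+        # ptuning_setup, so the engine substitutes visual embeddings at those positions.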
+        return input_ids, ptuning_args
+
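+    # Rough summary of the LITA/VITA token preprocessing below: encoder features of
+    # shape [b, t, s, d] (batch, frames, spatial tokens, hidden) are reduced to a small
+    # set of image-frame tokens plus temporally pooled video tokens, which
+    # insert_tokens_by_index later wraps with the image/video start and end markers.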
+    def preprocess_lita_visual(self, visual_features, config):
+
+        b, t, s, d = visual_features.shape
+
+        num_frames = t
+        if (
+            'visual_token_format' in config['mm_cfg']['lita']
+            and config['mm_cfg']['lita']['visual_token_format'] == 'im_vid_start_end'
+        ):
+            num_image_frames = min(num_frames, config['mm_cfg']['lita']['sample_frames'])
+            idx = np.round(np.linspace(0, num_frames - 1, num_image_frames)).astype(int)
+
+            # Image and video features
+            im_features = visual_features[:, idx, ...]
+
+            vid_features = einops.reduce(visual_features, 'b t s d -> b t d', 'mean')
+            return im_features, vid_features, num_image_frames
+
+        elif (
+            'lita_video_arch' in config['mm_cfg']['lita']
+            and config['mm_cfg']['lita']['lita_video_arch'] == 'temporal_spatial_pool'
+        ):
+            pool_size = 2
+            selected_frames = np.round(np.linspace(0, visual_features.shape[1] - 1, pool_size * pool_size)).astype(int)
+            s_tokens = visual_features[:, selected_frames, ...]
+            s_tokens = einops.rearrange(s_tokens, 'b t (h w) d -> (b t) d h w', h=16, w=16)
+            s_tokens = F.avg_pool2d(s_tokens, kernel_size=pool_size)
+            s_tokens = einops.rearrange(s_tokens, '(b t) d h w -> b (t h w) d', b=b)
+
+            t_tokens = einops.reduce(visual_features, 'b t s d -> b t d', 'mean')
+
+            return t_tokens, s_tokens, pool_size**2
+
+        else:
+            raise ValueError(f'Invalid visual token format: {config["mm_cfg"]["lita"]["visual_token_format"]}')
+
     def ptuning_setup(self, prompt_table, input_ids, input_lengths):
         hidden_size = self.model_config.hidden_size * self.runtime_mapping.tp_size
+
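+        # LITA/VITA pass the visual features as a list of chunks (image tokens followed
+        # by video tokens); concatenate them along the token dimension so a single
+        # prompt table covers all visual positions.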
+        if self.model_type == 'lita' or self.model_type == 'vita':
+            prompt_table = torch.cat(prompt_table, dim=1)
         if prompt_table is not None:
             task_vocab_size = torch.tensor(
                 [prompt_table.shape[1]],
@@ -338,8 +598,109 @@ def ptuning_setup(self, prompt_table, input_ids, input_lengths):
 
         return [prompt_table, tasks, task_vocab_size]
 
+    def expand2square_pt(self, images, background_color):
+        height, width = images.shape[-2:]
+        b = len(images)
+        background_color = torch.Tensor(background_color)
+        if width == height:
+            return images
+        elif width > height:
+            result = einops.repeat(background_color, 'c -> b c h w', b=b, h=width, w=width).clone()
+            paste_start = (width - height) // 2
+            paste_end = paste_start + height
+            result[:, :, paste_start:paste_end, :] = images
+            return result
+        else:
+            result = einops.repeat(background_color, 'c -> b c h w', b=b, h=height, w=height).clone()
+            paste_start = (height - width) // 2
+            paste_end = paste_start + width
+            result[:, :, :, paste_start:paste_end] = images
+            return result
+
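+    # Sketch of the decode path below (assumes the optional `decord` dependency is
+    # installed): frames are read with decord.VideoReader, uniformly subsampled when
+    # num_frames is given, and then normalized by preprocess_frames().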
+    def load_video(self, config, video_path, processor, num_frames=None):
+        frames = None
+        if isinstance(video_path, str):
+            decord.bridge.set_bridge('torch')
+            video_reader = decord.VideoReader(uri=video_path)
+            if num_frames is not None:
+                idx = np.round(np.linspace(0, len(video_reader) - 1, num_frames)).astype(int)
+                frames = video_reader.get_batch(idx)
+            else:
+                frames = torch.cat([torch.tensor(f.asnumpy()) for f in video_reader])
+        elif isinstance(video_path, np.ndarray):
+            frames = torch.tensor(video_path, dtype=torch.float32)
+
+        return self.preprocess_frames(frames, config, processor)
+
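+    # preprocess_frames expects raw frames laid out as [t, h, w, c]; it pads them to a
+    # square when image_aspect_ratio == 'pad' in the NeMo config and returns the HF
+    # processor's pixel_values tensor.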
+    def preprocess_frames(self, frames, config, processor):
+        frames = einops.rearrange(frames, 't h w c -> t c h w')
+        if config['data']['image_aspect_ratio'] == 'pad':
+            frames = self.expand2square_pt(frames, tuple(int(x * 255) for x in processor.image_mean))
+        processed_frames = processor.preprocess(frames, return_tensors='pt')['pixel_values']
+        return processed_frames
+
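+    # Worked example (illustrative numbers): with visual_token_format == 'im_vid_start_end',
+    # data.num_frames == 256 and a 1000-frame clip, subsample = ceil(1000 / 256) = 4 and
+    # round(1000 / 4) = 250 frames are kept; other configs fall back to the fixed
+    # mm_cfg.lita.sample_frames value.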
+    def get_num_sample_frames(self, config, vid_len):
+        if (
+            'visual_token_format' in config['mm_cfg']['lita']
+            and config['mm_cfg']['lita']['visual_token_format'] == 'im_vid_start_end'
+        ):
+            max_frames = config['data']['num_frames']
+            if vid_len <= max_frames:
+                return vid_len
+            else:
+                subsample = int(np.ceil(float(vid_len) / max_frames))
+                return int(np.round(float(vid_len) / subsample))
+        else:
+            return config['mm_cfg']['lita']['sample_frames']
+
+    def process_lita_video(self, nemo_config, video_path, image_processor):
+        image = None
+        if isinstance(video_path, str):
+            vid_len = len(decord.VideoReader(video_path))
+            num_sample_frames = self.get_num_sample_frames(nemo_config, vid_len)
+            image = (
+                self.load_video(nemo_config, video_path, image_processor, num_sample_frames)
+                .unsqueeze(0)
+                .to(self.device, dtype=torch.bfloat16)
+            )
+        elif isinstance(video_path, np.ndarray):
+            image = (
+                self.load_video(nemo_config, video_path, image_processor)
+                .unsqueeze(0)
+                .to(self.device, dtype=torch.bfloat16)
+            )
+        return image
+
+    def process_image(self, image_file, image_processor, nemo_config, image_folder):
+        if isinstance(image_file, str):
+            if image_folder is not None:
+                image = Image.open(os.path.join(image_folder, image_file)).convert("RGB")
+            else:
+                image = Image.open(image_file).convert("RGB")
+        else:
+            # image is stored in bytearray
+            image = image_file
+
+        crop_size = nemo_config['mm_cfg']['vision_encoder']['crop_size']
+        crop_size = tuple(crop_size)
+        image = image.resize(crop_size)
+        if nemo_config['data']['image_aspect_ratio'] == 'pad':
+            image = self.expand2square_pt(image, tuple(int(x * 255) for x in image_processor.image_mean))
+            image = image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
+        else:
+            image = image_processor.preprocess(image, return_tensors="pt")["pixel_values"][0]
+        return image
+
+    def process_vila_img(self, images):
+        new_images = [self.process_image(image, self.image_processor, self.nemo_config, None) for image in images]
+
+        if all(x.shape == new_images[0].shape for x in new_images):
+            new_images = torch.stack(new_images, dim=0)
+        return new_images
+
     def setup_inputs(self, input_text, raw_image, batch_size):
         attention_mask = None
+        image = None
 
         if self.model_type == "neva":
             image_size = self.image_size
@@ -370,21 +731,42 @@ def setup_inputs(self, input_text, raw_image, batch_size):
                 f"\n{input_text}\nAssistant\nquality:4,toxicity:0,humor:0,creativity:0,helpfulness:4,correctness:4,coherence:4,complexity:4,verbosity:4\n"
                 ""
             )
+        elif self.model_type in ['vila', 'lita', 'vita']:
+            if self.model_type == "vila" or self.model_type == "lita":
+                pre_prompt = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: "
+                if input_text is None:
+                    input_text = "\n Please elaborate what you see in the images?"
+                post_prompt = input_text + " ASSISTANT:"
+
+            elif self.model_type == "vita":
+                # llama3 prompt template
+                pre_prompt = """<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nYou are a helpful language and vision assistant. "
+                                "You are able to understand the visual content that the user provides, "
+                                "and assist the user with a variety of tasks using natural language. <|start_header_id|>user<|end_header_id|>\n\n"""
+                if input_text is None:
+                    input_text = "\n Please elaborate what you see in the images?"
+                post_prompt = input_text + "<|start_header_id|>assistant<|end_header_id|>\n\n"
+
         else:
             raise RuntimeError(f"Invalid model type {self.model_type}")
 
+        if self.model_type == 'lita' or self.model_type == 'vita':
+            image = self.process_lita_video(self.nemo_config, raw_image, self.image_processor)
+
+        if self.model_type == 'vila':
+            raw_image = [raw_image] * batch_size
+            image = self.process_vila_img(raw_image)
+
         # Repeat inputs to match batch size
         pre_prompt = [pre_prompt] * batch_size
         post_prompt = [post_prompt] * batch_size
-        if image.dim() == 5:
-            image = image.expand(batch_size, -1, -1, -1, -1).contiguous()
-        else:
-            image = image.expand(batch_size, -1, -1, -1).contiguous()
+        if self.model_type not in ['vila', 'lita', 'vita']:
+            if image.dim() == 5:
+                image = image.expand(batch_size, -1, -1, -1, -1).contiguous()
+            else:
+                image = image.expand(batch_size, -1, -1, -1).contiguous()
         image = image.to(self.device)
 
-        # Generate decoder_input_ids for enc-dec models
-        # Custom prompts can be added as:
-        # decoder_input_ids = model.tokenizer(decoder_prompt).input_ids
         decoder_input_ids = None
 
         return input_text, pre_prompt, post_prompt, image, decoder_input_ids, attention_mask
@@ -473,9 +855,10 @@ def print_result(self, input_text, output_text, batch_size, num_beams, run_profi
         logger.info("---------------------------------------------------------")
 
     def load_test_media(self, input_media):
-        if self.model_type == "video-neva":
+        media_model = ["video-neva", "lita", "vita"]
+        if self.model_type in media_model:
             media = input_media
-        elif self.model_type == "neva":
+        elif self.model_type == "neva" or self.model_type == "vila":
             media = Image.open(input_media).convert('RGB')
         else:
             raise RuntimeError(f"Invalid model type {self.model_type}")
diff --git a/nemo/export/quantize/quantizer.py b/nemo/export/quantize/quantizer.py
index e645ed8971c3..590cf50c804c 100644
--- a/nemo/export/quantize/quantizer.py
+++ b/nemo/export/quantize/quantizer.py
@@ -225,7 +225,8 @@ def export(self, model: MegatronGPTModel):
         assert self.export_config is not None, "Export config is not set"
         torch_dtype = torch_dtype_from_precision(self.export_config.dtype)
 
-        self._sample_output(model)
+        if self.export_config.get("sample_output", True):
+            self._sample_output(model)
 
         if model.cfg.megatron_amp_O2:
             model.model = unwrap_model(model.model, Float16Module)
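A tiny sketch (illustrative only; assumes a DictConfig-style export config) of the opt-out introduced above: sample generation still runs by default and is skipped only when sample_output is explicitly set to False.

```python
from omegaconf import OmegaConf

# Hypothetical export configs; only the sample_output key matters here.
default_cfg = OmegaConf.create({"dtype": "bf16"})
ci_cfg = OmegaConf.create({"dtype": "bf16", "sample_output": False})

for name, cfg in [("default", default_cfg), ("ci", ci_cfg)]:
    if cfg.get("sample_output", True):
        print(f"{name}: running sample generation before export")
    else:
        print(f"{name}: skipping sample generation")
```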
diff --git a/nemo/export/tensorrt_llm.py b/nemo/export/tensorrt_llm.py
index b4299dfd8945..3c73da1c0731 100644
--- a/nemo/export/tensorrt_llm.py
+++ b/nemo/export/tensorrt_llm.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import gc
 import json
 import logging
 import os
@@ -23,9 +24,11 @@
 from typing import List, Optional
 
 import numpy as np
+import safetensors
 import tensorrt_llm
 import torch
 import wrapt
+from tensorrt_llm._utils import numpy_to_torch
 
 from nemo.deploy import ITritonDeployable
 from nemo.export.tarutils import TarPath, unpack_tarball
@@ -47,9 +50,11 @@
 use_deploy = True
 try:
     from nemo.deploy.utils import cast_output, str_ndarray2list
-except Exception:
+except Exception as e:
     use_deploy = False
 
+LOGGER = logging.getLogger("NeMo")
+
 
 @wrapt.decorator
 def noop_decorator(func):
@@ -67,8 +72,6 @@ def wrapper(*args, **kwargs):
 except Exception:
     use_pytriton = False
 
-LOGGER = logging.getLogger("NeMo")
-
 
 class TensorRTLLM(ITritonDeployable):
     """
@@ -95,6 +98,8 @@ def __init__(
         lora_ckpt_list: List[str] = None,
         load_model: bool = True,
         use_python_runtime: bool = True,
+        enable_chunked_context: Optional[bool] = None,
+        max_tokens_in_paged_kv_cache: Optional[int] = None,
     ):
         """
         Args:
@@ -104,9 +109,19 @@ def __init__(
             use_python_runtime (bool): whether to use python or c++ runtime.
         """
 
+        if use_python_runtime:
+            if enable_chunked_context is not None or max_tokens_in_paged_kv_cache is not None:
+                raise Exception(
+                    "enable_chunked_context and max_tokens_in_paged_kv_cache options "
+                    "work only with the TensorRT-LLM C++ runtime. Please set "
+                    "use_python_runtime=False to use these options."
+                )
+
         self.model_dir = model_dir
         self.lora_ckpt_list = lora_ckpt_list
         self.use_python_runtime = use_python_runtime
+        self.enable_chunked_context = enable_chunked_context if enable_chunked_context is not None else False
+        self.max_tokens_in_paged_kv_cache = max_tokens_in_paged_kv_cache
         self.model = None
         self.tokenizer = None
         self.n_gpus = None
@@ -125,16 +140,16 @@ def export(
         nemo_checkpoint_path: str,
         model_type: Optional[str] = None,
         delete_existing_files: bool = True,
-        n_gpus: int = None,
+        n_gpus: Optional[int] = None,
         tensor_parallelism_size: int = 1,
         pipeline_parallelism_size: int = 1,
-        gpus_per_node: int = None,
+        gpus_per_node: Optional[int] = None,
         max_input_len: int = 256,
         max_output_len: int = 256,
         max_input_token: Optional[int] = None,
         max_output_token: Optional[int] = None,
         max_batch_size: int = 8,
-        max_prompt_embedding_table_size=None,
+        max_prompt_embedding_table_size: Optional[int] = None,
         use_parallel_embedding: bool = False,
         use_embedding_sharing: bool = False,
         paged_kv_cache: bool = True,
@@ -146,8 +161,12 @@ def export(
         use_lora_plugin: str = None,
         lora_target_modules: List[str] = None,
         max_lora_rank: int = 64,
-        max_num_tokens: int = None,
-        opt_num_tokens: int = None,
+        max_num_tokens: Optional[int] = None,
+        opt_num_tokens: Optional[int] = None,
+        max_seq_len: Optional[int] = None,
+        multiple_profiles: bool = False,
+        gpt_attention_plugin: str = "auto",
+        gemm_plugin: str = "auto",
     ):
         """
         Exports nemo checkpoints to TensorRT-LLM.
@@ -179,6 +198,10 @@ def export(
             max_lora_rank (int): maximum lora rank.
             max_num_tokens (int):
             opt_num_tokens (int):
+            max_seq_len (int): maximum total sequence length (input plus output); defaults to max_input_len + max_output_len.
+            multiple_profiles (bool): enables the multiple-profiles feature of TensorRT-LLM. Default = False
+            gpt_attention_plugin (str): enables the GPT attention plugin. Default = "auto"
+            gemm_plugin (str): enables the GEMM plugin. Default = "auto"
         """
 
         if n_gpus is not None:
@@ -229,6 +252,17 @@ def export(
             )
             max_output_len = max_output_token
 
+        if max_seq_len is None:
+            max_seq_len = max_input_len + max_output_len
+
+        if max_batch_size < 4:
+            warnings.warn(
+                "TensorRT LLM may hit a runtime issue with batch size is smaller than 4 on some models."
+                " Force set to 4",
+                stacklevel=2,
+            )
+            max_batch_size = 4
+
         if tensorrt_llm.mpi_rank() == 0:
             tmp_dir = tempfile.TemporaryDirectory()
             nemo_export_dir = Path(tmp_dir.name)
@@ -246,6 +280,7 @@ def export(
                     engine_dir=self.model_dir,
                     max_input_len=max_input_len,
                     max_output_len=max_output_len,
+                    max_seq_len=max_seq_len,
                     max_batch_size=max_batch_size,
                     max_prompt_embedding_table_size=max_prompt_embedding_table_size,
                     tensor_parallel_size=tensor_parallelism_size,
@@ -259,6 +294,7 @@ def export(
                     max_lora_rank=max_lora_rank,
                     max_num_tokens=max_num_tokens,
                     opt_num_tokens=opt_num_tokens,
+                    multiple_profiles=multiple_profiles,
                 )
             else:
                 if model_type is None:
@@ -310,6 +346,10 @@ def export(
                         paged_context_fmha=paged_context_fmha,
                         max_num_tokens=max_num_tokens,
                         opt_num_tokens=opt_num_tokens,
+                        max_seq_len=max_seq_len,
+                        multiple_profiles=multiple_profiles,
+                        gpt_attention_plugin=gpt_attention_plugin,
+                        gemm_plugin=gemm_plugin,
                     )
 
             tokenizer_path = os.path.join(nemo_export_dir, "tokenizer.model")
@@ -330,6 +370,84 @@ def export(
         if load_model:
             self._load()
 
+    def convert_to_safe_tensors(
+        self,
+        nemo_checkpoint_path: str,
+        model_type: Optional[str] = None,
+        delete_existing_files: bool = True,
+        tensor_parallelism_size: int = 1,
+        pipeline_parallelism_size: int = 1,
+        gpus_per_node: Optional[int] = None,
+        use_parallel_embedding: bool = False,
+        use_embedding_sharing: bool = False,
+        dtype: str = "bfloat16",
+    ):
+        gpus_per_node = tensor_parallelism_size if gpus_per_node is None else gpus_per_node
+
+        if Path(self.model_dir).exists():
+            if delete_existing_files and len(os.listdir(self.model_dir)) > 0:
+                for files in os.listdir(self.model_dir):
+                    path = os.path.join(self.model_dir, files)
+                    try:
+                        shutil.rmtree(path)
+                    except OSError:
+                        os.remove(path)
+
+                if len(os.listdir(self.model_dir)) > 0:
+                    raise Exception("Couldn't delete all files.")
+            elif len(os.listdir(self.model_dir)) > 0:
+                raise Exception("There are files in this folder. Try setting delete_existing_files=True.")
+        else:
+            Path(self.model_dir).mkdir(parents=True, exist_ok=True)
+
+        if model_type == "gpt" or model_type == "starcoder":
+            model_type = "gptnext"
+
+        if model_type == "mixtral":
+            model_type = "llama"
+
+        if tensorrt_llm.mpi_rank() == 0:
+            tmp_dir = tempfile.TemporaryDirectory()
+            nemo_export_dir = Path(tmp_dir.name)
+
+            model, model_configs, self.tokenizer = load_nemo_model(nemo_checkpoint_path, nemo_export_dir)
+            weights_dicts, model_configs = model_to_trtllm_ckpt(
+                model=model,
+                nemo_model_config=model_configs,
+                nemo_export_dir=nemo_export_dir,
+                decoder_type=model_type,
+                dtype=dtype,
+                tensor_parallel_size=tensor_parallelism_size,
+                pipeline_parallel_size=pipeline_parallelism_size,
+                gpus_per_node=gpus_per_node,
+                use_parallel_embedding=use_parallel_embedding,
+                use_embedding_sharing=use_embedding_sharing,
+            )
+
+            for weight_dict, model_config in zip(weights_dicts, model_configs):
+                rank = model_config.mapping.tp_rank
+                for k, v in weight_dict.items():
+                    weight_dict[k] = numpy_to_torch(v)
+
+                safetensors.torch.save_file(weight_dict, os.path.join(self.model_dir, f'rank{rank}.safetensors'))
+
+            model_configs[0].to_json_file(os.path.join(self.model_dir, 'config.json'))
+
+            tokenizer_path = os.path.join(nemo_export_dir, "tokenizer.model")
+            if os.path.exists(tokenizer_path):
+                shutil.copy(tokenizer_path, self.model_dir)
+            else:
+                self.tokenizer.save_pretrained(os.path.join(self.model_dir, 'huggingface_tokenizer'))
+
+            nemo_model_config = os.path.join(nemo_export_dir, "model_config.yaml")
+            if os.path.exists(nemo_model_config):
+                shutil.copy(nemo_model_config, self.model_dir)
+
+            tmp_dir.cleanup()
+
+        if tensorrt_llm.mpi_world_size() > 1:
+            tensorrt_llm.mpi_barrier()
+
     def build(
         self,
         model,
@@ -402,6 +520,8 @@ def refit(self, model, model_config):
             tokenizer_vocab_size=self.tokenizer.vocab_size,
         )
         load_distributed(self.model_dir, self.mp_rank, self.gpus_per_node)
+        gc.collect()
+        torch.cuda.empty_cache()
         refit(weights_dict)
 
     def forward(
@@ -838,6 +958,8 @@ def _load(self):
                         engine_dir=self.model_dir,
                         lora_ckpt_list=self.lora_ckpt_list,
                         use_python_runtime=self.use_python_runtime,
+                        enable_chunked_context=self.enable_chunked_context,
+                        max_tokens_in_paged_kv_cache=self.max_tokens_in_paged_kv_cache,
                     )
                     self._load_prompt_tables()
                 except Exception as error:
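A usage sketch for the options added to TensorRTLLM above (paths, sizes and model type are placeholders): the chunked-context and paged-KV-cache limits only apply with the C++ runtime, so use_python_runtime is disabled.

```python
from nemo.export.tensorrt_llm import TensorRTLLM

exporter = TensorRTLLM(
    model_dir="/tmp/trt_llm_engine",       # placeholder engine directory
    load_model=False,                      # no engine built yet
    use_python_runtime=False,              # required for the two options below
    enable_chunked_context=True,
    max_tokens_in_paged_kv_cache=8192,
)
exporter.export(
    nemo_checkpoint_path="/models/llama_ci.nemo",  # placeholder checkpoint path
    model_type="llama",
    tensor_parallelism_size=1,
    max_input_len=256,
    max_output_len=256,
    max_batch_size=8,                      # values below 4 are bumped to 4 with a warning
    multiple_profiles=True,
)
```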
diff --git a/nemo/export/tensorrt_mm_exporter.py b/nemo/export/tensorrt_mm_exporter.py
index 13bc82b39334..b0536a55f95f 100644
--- a/nemo/export/tensorrt_mm_exporter.py
+++ b/nemo/export/tensorrt_mm_exporter.py
@@ -91,6 +91,7 @@ def export(
         max_input_len: int = 4096,
         max_output_len: int = 256,
         max_batch_size: int = 1,
+        vision_max_batch_size: int = 1,
         max_multimodal_len: int = 3072,
         dtype: str = "bfloat16",
         delete_existing_files: bool = True,
@@ -119,7 +120,7 @@ def export(
             llm_checkpoint_path=llm_checkpoint_path,
             model_type=model_type,
             llm_model_type=llm_model_type,
-            tensor_parallel_size=tensor_parallel_size,
+            tensor_parallelism_size=tensor_parallel_size,
             max_input_len=max_input_len,
             max_output_len=max_output_len,
             max_batch_size=max_batch_size,
@@ -128,7 +129,7 @@ def export(
         )
 
         visual_dir = os.path.join(self.model_dir, "visual_engine")
-        build_visual_engine(visual_dir, visual_checkpoint_path, model_type, max_batch_size)
+        build_visual_engine(visual_dir, visual_checkpoint_path, model_type, vision_max_batch_size)
 
         if load_model:
             self._load()
@@ -192,9 +193,10 @@ def triton_infer_fn(self, **inputs: np.ndarray):
                 )
 
             infer_input = {"input_text": str_ndarray2list(inputs.pop("input_text")[0])}
-            if self.runner.model_type == "neva":
+            video_model_list = ["video-neva", "lita", "vita"]
+            if self.runner.model_type == "neva" or self.runner.model_type == "vila":
                 infer_input["input_image"] = ndarray2img(inputs.pop("input_media")[0])[0]
-            elif self.runner.model_type == "video-neva":
+            elif self.runner.model_type in video_model_list:
                 infer_input["input_image"] = inputs.pop("input_media")[0]
             if "batch_size" in inputs:
                 infer_input["batch_size"] = inputs.pop("batch_size")[0][0]
diff --git a/nemo/export/trt_llm/converter/model_converter.py b/nemo/export/trt_llm/converter/model_converter.py
old mode 100644
new mode 100755
index 2a78f6833782..60d50316e9ed
--- a/nemo/export/trt_llm/converter/model_converter.py
+++ b/nemo/export/trt_llm/converter/model_converter.py
@@ -22,6 +22,8 @@
 from tensorrt_llm._utils import pad_vocab_size
 from tensorrt_llm.functional import non_gated_version
 from tensorrt_llm.layers import MoeConfig
+from tensorrt_llm.models.gpt.config import GPTConfig
+from tensorrt_llm.models.llama.config import LLaMAConfig
 from tensorrt_llm.models.modeling_utils import PretrainedConfig
 
 from nemo.export.trt_llm.converter.model_to_trt_llm_ckpt import (
@@ -33,6 +35,15 @@
 LOGGER = logging.getLogger("NeMo")
 
 
+def get_config(decoder_type, config):
+    if decoder_type == "llama":
+        return LLaMAConfig(**config)
+    elif decoder_type == "gpt" or decoder_type == "gptnext":
+        return GPTConfig(**config)
+    else:
+        return PretrainedConfig(**config)
+
+
 def prompt_convert(prompt_config, prompt_weights):
     if "task_templates" in prompt_config:
         prompt_templates = prompt_config["task_templates"]
@@ -156,11 +167,13 @@ def model_to_trtllm_ckpt(
         'rotary_pct': nemo_model_config.get('rotary_percentage', 1.0),
         'rotary_base': nemo_model_config.get('rotary_base', 10000),
         'moe_num_experts': nemo_model_config.get('num_moe_experts', 0),
-        'moe_top_k': nemo_model_config.get('moe_router_topk'),
+        'moe_top_k': nemo_model_config.get('moe_router_topk', 0),
         'moe_normalization_mode': nemo_model_config.get(
             'moe_renorm_mode', MoeConfig.ExpertScaleNormalizationMode.RENORMALIZE
         ),
-        'moe_tp_mode': nemo_model_config.get('moe_tp_mode', MoeConfig.ParallelismMode.TENSOR_PARALLEL),
+        'moe_tp_mode': nemo_model_config.get(
+            'moe_tp_mode', 2
+        ),  # 2 == MoeConfig.ParallelismMode.TENSOR_PARALLEL; the literal avoids referencing the enum here
         'logits_dtype': 'float32',
         'world_size': world_size,
         'tp_size': tensor_parallel_size,
@@ -179,7 +192,7 @@ def model_to_trtllm_ckpt(
 
     if use_distributed_convert:
         config["gpus_per_node"] = gpus_per_node
-        model_configs.append(PretrainedConfig(**config))
+        model_configs.append(get_config(decoder_type, config))
         model_configs[0].mapping = tensorrt_llm.Mapping(
             world_size=world_size,
             rank=model_parallel_rank,
@@ -258,7 +271,7 @@ def model_to_trtllm_ckpt(
                 weights_dict_local["transformer.ln_f.bias"] = ln_f_bias
 
         config["gpus_per_node"] = gpus_per_node
-        model_config = PretrainedConfig(**config)
+        model_config = get_config(decoder_type, config)
         model_config.mapping = mapping
         model_configs.append(model_config)
         weights_dicts.append(weights_dict_local)
diff --git a/nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py b/nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py
index 0345f979b8c2..db8a66308047 100644
--- a/nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py
+++ b/nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py
@@ -59,7 +59,7 @@ def get_layer_prefix(layer_names, is_mcore):
         if 'self_attention' in layer_name:
             transformer_layer_prefix = layer_name.split('layers')[0]
             break
-    assert transformer_layer_prefix is not None, "Cannot extract transformer layer prefix from {layer_name}"
+    assert transformer_layer_prefix is not None, f"Cannot extract transformer layer prefix from {layer_name}"
     if is_mcore:
         model_prefix = transformer_layer_prefix.split('decoder')[0]
     else:
diff --git a/nemo/export/trt_llm/converter/utils.py b/nemo/export/trt_llm/converter/utils.py
old mode 100644
new mode 100755
index 3768ff4b2844..eab17167cbd5
--- a/nemo/export/trt_llm/converter/utils.py
+++ b/nemo/export/trt_llm/converter/utils.py
@@ -26,7 +26,7 @@
 DECODER_MODEL_TYPE = {
     "gptj": 'GPTForCausalLM',
     "gptnext": 'GPTForCausalLM',
-    "llama": 'LLaMAForCausalLM',
+    "llama": 'LlamaForCausalLM',
     "gemma": 'GemmaForCausalLM',
     "falcon": 'FalconForCausalLM',
 }
diff --git a/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py b/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py
index 1d473f497f51..479d93498475 100644
--- a/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py
+++ b/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py
@@ -244,9 +244,7 @@ def load_nemo_model(nemo_ckpt: Union[str, Path], nemo_export_dir: Union[str, Pat
                 tokenizer_config["model"] = os.path.join(nemo_export_dir, "tokenizer.model")
                 tokenizer = build_tokenizer(tokenizer_config)
         else:
-            raise Exception(
-                "Not a supported nemo file format. " "Only distributed mcore nemo checkpoints are support."
-            )
+            raise Exception("Not a supported NeMo file format: only distributed MCore NeMo checkpoints are supported.")
     finally:
         if isinstance(nemo_dir, TarPath):
             nemo_dir.tarobject.close()
diff --git a/nemo/export/trt_llm/qnemo/qnemo_to_tensorrt_llm.py b/nemo/export/trt_llm/qnemo/qnemo_to_tensorrt_llm.py
index 630330381e56..2a8a9a91e46d 100644
--- a/nemo/export/trt_llm/qnemo/qnemo_to_tensorrt_llm.py
+++ b/nemo/export/trt_llm/qnemo/qnemo_to_tensorrt_llm.py
@@ -12,13 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 import glob
 import os
+import subprocess
 import warnings
 from typing import List, Optional
 
-from modelopt.deploy.llm import build_tensorrt_llm
+from tensorrt_llm.models import PretrainedConfig
 
 from nemo.export.trt_llm.qnemo.utils import CONFIG_NAME, WEIGHTS_NAME
 
@@ -28,50 +28,97 @@ def qnemo_to_tensorrt_llm(
     engine_dir: str,
     max_input_len: int,
     max_output_len: int,
+    max_seq_len: Optional[int],
     max_batch_size: int,
     max_prompt_embedding_table_size: int,
-    tensor_parallel_size: int = None,
-    pipeline_parallel_size: int = None,
+    tensor_parallel_size: Optional[int] = None,
+    pipeline_parallel_size: Optional[int] = None,
     use_parallel_embedding: bool = False,
     paged_kv_cache: bool = True,
     remove_input_padding: bool = True,
     enable_multi_block_mode: bool = False,
-    use_lora_plugin: str = None,
+    use_lora_plugin: Optional[str] = None,
     lora_target_modules: Optional[List[str]] = None,
     max_lora_rank: int = 64,
-    max_num_tokens: int = None,
-    opt_num_tokens: int = None,
+    max_num_tokens: Optional[int] = None,
+    opt_num_tokens: Optional[int] = None,
+    max_beam_width: int = 1,
+    multiple_profiles: bool = False,
 ):
-    """Build TensorRT-LLM engine with ModelOpt build_tensorrt_llm function."""
+    """Build TensorRT-LLM engine with trtllm-build command in a subprocess."""
     assert not lora_target_modules, f"LoRA is not supported for quantized checkpoints, got {lora_target_modules}"
 
     warnings.warn(
-        "Note that setting tensor_parallel_size and pipeline_parallel_size parameters"
-        " for quantized models should be done on calibration step with nemo.export.quantize module."
+        "Note that setting tensor_parallel_size, pipeline_parallel_size and use_parallel_embedding "
+        " parameters for quantized models is done on calibration step with nemo.export.quantize module."
         " These parameters are ignored when building and running TensorRT-LLM engine below.",
         UserWarning,
         stacklevel=3,
     )
 
-    warnings.warn(
-        "Also use_parallel_embedding, paged_kv_cache, remove_input_padding, enable_multi_block_mode, max_num_tokens"
-        " and opt_num_tokens parameters are set by ModelOpt build_tensorrt_llm function in the optimal way and are"
-        " ignored on engine build step.",
-        UserWarning,
-        stacklevel=3,
-    )
-
     num_build_workers = len(glob.glob(os.path.join(nemo_checkpoint_path, WEIGHTS_NAME.format("*"))))
     assert num_build_workers, f"No TensorRT-LLM weight files found in {nemo_checkpoint_path}"
 
-    build_tensorrt_llm(
-        pretrained_config=os.path.join(nemo_checkpoint_path, CONFIG_NAME),
-        engine_dir=engine_dir,
-        max_input_len=max_input_len,
-        max_output_len=max_output_len,
-        max_batch_size=max_batch_size,
-        max_beam_width=1,
-        num_build_workers=num_build_workers,
-        enable_sparsity=False,
-        max_prompt_embedding_table_size=max_prompt_embedding_table_size,
-    )
+    config = PretrainedConfig.from_json_file(os.path.join(nemo_checkpoint_path, CONFIG_NAME))
+
+    log_level = "warning"
+
+    quant_algo = config.quantization.quant_algo
+
+    use_fused_mlp = quant_algo in [
+        "FP8",
+        None,
+    ] and config.hidden_act in ["silu", "swiglu", "fast-swiglu", "gelu", "geglu"]
+
+    use_qdq = quant_algo in ["FP8", "W8A8_SQ_PER_CHANNEL"]
+
+    builder_opt = 4 if "RecurrentGemma" not in config.architecture else 0
+
+    speculative_decoding_mode = "medusa" if "Medusa" in config.architecture else None
+
+    build_cmd = "trtllm-build "
+    build_cmd += f"--checkpoint_dir {nemo_checkpoint_path} "
+    build_cmd += f"--log_level {log_level} "
+    build_cmd += f"--output_dir {engine_dir} "
+    build_cmd += f"--workers {num_build_workers} "
+    build_cmd += f"--max_batch_size {max_batch_size} "
+    build_cmd += f"--max_input_len {max_input_len} "
+    build_cmd += f"--max_output_len {max_output_len} "
+    build_cmd += f"--max_beam_width {max_beam_width} "
+    build_cmd += f"--tp_size {config.mapping.tp_size} "
+    build_cmd += f"--pp_size {config.mapping.pp_size} "
+    build_cmd += f"--max_prompt_embedding_table_size {max_prompt_embedding_table_size} "
+    build_cmd += f"--builder_opt {builder_opt} "
+    build_cmd += f"--gpt_attention_plugin {config.dtype} "
+    build_cmd += f"--nccl_plugin {config.dtype} "
+    build_cmd += f"--paged_kv_cache {'enable' if paged_kv_cache else 'disable'} "
+    build_cmd += f"--remove_input_padding {'enable' if remove_input_padding else 'disable'} "
+    build_cmd += f"--multi_block_mode {'enable' if enable_multi_block_mode else 'disable'} "
+    build_cmd += f"--multiple_profiles {'enable' if multiple_profiles else 'disable'} "
+
+    if use_fused_mlp:
+        build_cmd += "--use_fused_mlp " if "RecurrentGemma" not in config.architecture else ""
+
+    if not use_qdq:
+        build_cmd += f"--gemm_plugin {config.dtype} "
+
+    if max_seq_len:
+        build_cmd += f"--max_seq_len {max_seq_len} "
+
+    if max_num_tokens:
+        build_cmd += f"--max_num_tokens {max_num_tokens} "
+    else:
+        build_cmd += f"--max_num_tokens {max_batch_size * max_input_len} "
+
+    if opt_num_tokens is not None:
+        build_cmd += f"--opt_num_tokens {opt_num_tokens} "
+
+    if speculative_decoding_mode:
+        build_cmd += f"--speculative_decoding_mode {speculative_decoding_mode} "
+
+    build_cmd = build_cmd.replace("--", "\\\n  --")  # Separate parameters line by line
+
+    print("trtllm-build command:")
+    print(build_cmd)
+
+    subprocess.run(build_cmd, shell=True, check=True)
diff --git a/nemo/export/trt_llm/tensorrt_llm_build.py b/nemo/export/trt_llm/tensorrt_llm_build.py
old mode 100644
new mode 100755
index b329de2a3b18..1544fdf032d8
--- a/nemo/export/trt_llm/tensorrt_llm_build.py
+++ b/nemo/export/trt_llm/tensorrt_llm_build.py
@@ -45,41 +45,48 @@ def build_and_save_engine(
     paged_kv_cache: bool = True,
     remove_input_padding: bool = True,
     paged_context_fmha: bool = False,
-    custom_all_reduce: bool = True,
+    use_custom_all_reduce: bool = True,
     use_refit: bool = False,
     max_num_tokens: int = None,
+    max_seq_len: int = None,
     opt_num_tokens: int = None,
     max_beam_width: int = 1,
     tokens_per_block: int = 128,
+    multiple_profiles: bool = False,
+    gpt_attention_plugin: str = "auto",
+    gemm_plugin: str = "auto",
 ):
+    architecture = "LLaMAForCausalLM" if model_config.architecture == "LlamaForCausalLM" else model_config.architecture
     try:
-        model_cls = getattr(tensorrt_llm.models, model_config.architecture)
+        model_cls = getattr(tensorrt_llm.models, architecture)
     except:
         raise AttributeError(f"Could not find TRTLLM model type: {model_type}!")
 
     logger.set_level("info")
-    str_dtype = model_config.dtype
     plugin_config = PluginConfig()
-    plugin_config.set_gpt_attention_plugin(dtype=str_dtype)
-    plugin_config.set_gemm_plugin(dtype=str_dtype)
-    plugin_config.use_custom_all_reduce = custom_all_reduce
-    plugin_config.set_plugin("multi_block_mode", enable_multi_block_mode)
+    plugin_config.gpt_attention_plugin = gpt_attention_plugin
+    plugin_config.gemm_plugin = gemm_plugin
+    plugin_config.set_nccl_plugin(use_custom_all_reduce=use_custom_all_reduce)
+    plugin_config.multi_block_mode = enable_multi_block_mode
     if paged_kv_cache:
         plugin_config.enable_paged_kv_cache(tokens_per_block=tokens_per_block)
     else:
         plugin_config.paged_kv_cache = False
     plugin_config.remove_input_padding = remove_input_padding
     plugin_config.use_paged_context_fmha = paged_context_fmha
+    plugin_config.multiple_profiles = multiple_profiles
 
     max_num_tokens, opt_num_tokens = check_max_num_tokens(
         max_num_tokens=max_num_tokens,
         opt_num_tokens=opt_num_tokens,
+        max_seq_len=max_seq_len,
         max_batch_size=max_batch_size,
         max_input_len=max_input_len,
         max_beam_width=max_beam_width,
         remove_input_padding=remove_input_padding,
         enable_context_fmha=plugin_config.context_fmha,
         tokens_per_block=tokens_per_block,
+        multiple_profiles=multiple_profiles,
     )
 
     build_dict = {
@@ -87,6 +94,7 @@ def build_and_save_engine(
         'max_output_len': max_output_len,
         'max_batch_size': max_batch_size,
         'max_beam_width': max_beam_width,
+        'max_seq_len': max_seq_len,
         'max_num_tokens': max_num_tokens,
         'opt_num_tokens': opt_num_tokens,
         'max_prompt_embedding_table_size': max_prompt_embedding_table_size,
@@ -95,11 +103,13 @@ def build_and_save_engine(
         'strongly_typed': False,
         'builder_opt': None,
         'use_refit': use_refit,
+        'multiple_profiles': multiple_profiles,
     }
     build_config = BuildConfig.from_dict(build_dict, plugin_config=plugin_config)
 
     if use_lora_plugin is not None:
-        build_config.plugin_config.set_lora_plugin(use_lora_plugin)
+        # build_config.plugin_config.set_lora_plugin(use_lora_plugin)
+        # build_config.plugin_config._lora_plugin = use_lora_plugin
         lora_config = LoraConfig(
             lora_dir=lora_ckpt_list,
             lora_ckpt_source='nemo',
diff --git a/nemo/export/trt_llm/tensorrt_llm_run.py b/nemo/export/trt_llm/tensorrt_llm_run.py
index dbbf40cc3cf1..14ad0be699bb 100644
--- a/nemo/export/trt_llm/tensorrt_llm_run.py
+++ b/nemo/export/trt_llm/tensorrt_llm_run.py
@@ -26,15 +26,26 @@
 import tensorrt_llm
 import torch
 from mpi4py.futures import MPIPoolExecutor
-from tensorrt_llm.bindings import GptJsonConfig, GptSession, GptSessionConfig, KvCacheConfig, WorldConfig
 from tensorrt_llm.lora_manager import LoraManager
 from tensorrt_llm.quantization import QuantMode
 from tensorrt_llm.runtime import ModelConfig, ModelRunner, ModelRunnerCpp, SamplingConfig
-from tensorrt_llm.runtime.model_runner_cpp import ModelRunnerCppGptSession
+
 from transformers import PreTrainedTokenizer
 
 LOGGER = logging.getLogger("NeMo")
 
+use_trtllm_bindings = True
+try:
+    from tensorrt_llm.bindings import GptJsonConfig, GptSession, GptSessionConfig, KvCacheConfig, WorldConfig
+except Exception as e:
+    use_trtllm_bindings = False
+
+use_cpp_gpt_session = True
+try:
+    from tensorrt_llm.runtime.model_runner_cpp import ModelRunnerCppGptSession
+except Exception as e:
+    use_cpp_gpt_session = False
+
 
 @dataclass
 class TensorrtLLMHostContext:
@@ -131,6 +142,8 @@ def _load(
     lora_ckpt_list=None,
     num_beams=1,
     use_python_runtime: bool = True,
+    enable_chunked_context: bool = False,
+    max_tokens_in_paged_kv_cache: int = None,
 ):
     """The impl of `load` API for on a single GPU worker."""
     try:
@@ -145,7 +158,7 @@ def _load(
 
         max_batch_size = config["build_config"]["max_batch_size"]
         max_input_len = config["build_config"]["max_input_len"]
-        max_output_len = config["build_config"]["max_output_len"]
+        # max_output_len = config["build_config"]["max_output_len"]
         max_beam_width = config["build_config"]["max_beam_width"]
 
         runtime_rank = tensorrt_llm.mpi_rank()
@@ -166,8 +179,10 @@ def _load(
                 rank=runtime_rank,
                 max_batch_size=max_batch_size,
                 max_input_len=max_input_len,
-                max_output_len=max_output_len,
+                # max_output_len=max_output_len,
                 max_beam_width=max_beam_width,
+                enable_chunked_context=enable_chunked_context,
+                max_tokens_in_paged_kv_cache=max_tokens_in_paged_kv_cache,
                 debug_mode=False,
             )
 
@@ -279,6 +294,8 @@ def load(
     lora_ckpt_list: List[str] = None,
     num_beams: int = 1,
     use_python_runtime: bool = True,
+    enable_chunked_context: bool = False,
+    max_tokens_in_paged_kv_cache: int = None,
 ) -> TensorrtLLMHostContext:
     """Loaded the compiled LLM model and run it.
 
@@ -290,17 +307,42 @@ def load(
         config = json.load(f)
     world_size = config["pretrained_config"]["mapping"]["world_size"]
     if world_size == 1:
-        _load(tokenizer, engine_dir, lora_ckpt_list, num_beams, use_python_runtime)
+        _load(
+            tokenizer,
+            engine_dir,
+            lora_ckpt_list,
+            num_beams,
+            use_python_runtime,
+            enable_chunked_context,
+            max_tokens_in_paged_kv_cache,
+        )
         executor = None
     elif tensorrt_llm.mpi_world_size() > 1:
-        _load(tokenizer, engine_dir, lora_ckpt_list, num_beams, use_python_runtime)
+        _load(
+            tokenizer,
+            engine_dir,
+            lora_ckpt_list,
+            num_beams,
+            use_python_runtime,
+            enable_chunked_context,
+            max_tokens_in_paged_kv_cache,
+        )
         executor = None
         tensorrt_llm.mpi_barrier()
     else:
         executor = MPIPoolExecutor(max_workers=world_size)
         futures = []
         for _ in range(world_size):
-            future = executor.submit(_load, tokenizer, engine_dir, lora_ckpt_list, num_beams, use_python_runtime)
+            future = executor.submit(
+                _load,
+                tokenizer,
+                engine_dir,
+                lora_ckpt_list,
+                num_beams,
+                use_python_runtime,
+                enable_chunked_context,
+                max_tokens_in_paged_kv_cache,
+            )
             futures.append(future)
         for future in futures:
             future.result()
diff --git a/nemo/export/vllm/model_loader.py b/nemo/export/vllm/model_loader.py
index e7f3f1d1569f..2acdc127d674 100644
--- a/nemo/export/vllm/model_loader.py
+++ b/nemo/export/vllm/model_loader.py
@@ -22,7 +22,7 @@
 import tensorstore  # needed to register 'bfloat16' dtype with numpy for zarr compatibility
 import torch
 import zarr
-from vllm.config import CacheConfig, DeviceConfig, LoRAConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig
+from vllm.config import CacheConfig, DeviceConfig, LoRAConfig, MultiModalConfig, ParallelConfig, SchedulerConfig
 from vllm.model_executor.model_loader.loader import BaseModelLoader, _initialize_model
 from vllm.model_executor.model_loader.utils import set_default_torch_dtype
 
@@ -74,7 +74,7 @@ def load_model(
         model_config: NemoModelConfig,
         device_config: DeviceConfig,
         lora_config: Optional[LoRAConfig],
-        vision_language_config: Optional[VisionLanguageConfig],
+        multimodal_config: Optional[MultiModalConfig],
         parallel_config: ParallelConfig,
         scheduler_config: SchedulerConfig,
         cache_config: CacheConfig,
@@ -88,9 +88,7 @@ def load_model(
 
         with set_default_torch_dtype(model_config.dtype):
             with torch.device(device_config.device):
-                model = _initialize_model(
-                    model_config, self.load_config, lora_config, vision_language_config, cache_config
-                )
+                model = _initialize_model(model_config, self.load_config, lora_config, multimodal_config, cache_config)
 
             weights_iterator = model_config.model_converter.convert_weights(model_config.nemo_model_config, state_dict)
 
diff --git a/nemo/export/vllm_exporter.py b/nemo/export/vllm_exporter.py
index f3dd6c8a248b..de06ea830e07 100644
--- a/nemo/export/vllm_exporter.py
+++ b/nemo/export/vllm_exporter.py
@@ -240,9 +240,11 @@ def export(
             device_config=device_config,
             load_config=load_config,
             lora_config=None,
-            vision_language_config=None,
+            multimodal_config=None,
             speculative_config=None,
             decoding_config=None,
+            observability_config=None,
+            prompt_adapter_config=None,
             executor_class=executor_class,
             log_stats=log_stats,
         )
diff --git a/nemo/lightning/_strategy_lib.py b/nemo/lightning/_strategy_lib.py
index e6452de16512..a5a395a51108 100644
--- a/nemo/lightning/_strategy_lib.py
+++ b/nemo/lightning/_strategy_lib.py
@@ -61,12 +61,14 @@ def init_parallel_ranks(
         global_rank=init_global_rank,
         local_rank=init_local_rank,
         tensor_model_parallel_size=parallel_config.tensor_model_parallel_size,
+        expert_model_parallel_size=parallel_config.expert_model_parallel_size,
         pipeline_model_parallel_size=parallel_config.pipeline_model_parallel_size,
         virtual_pipeline_model_parallel_size=parallel_config.virtual_pipeline_model_parallel_size,
+        context_parallel_size=parallel_config.context_parallel_size,
         seed=seed,
         pipeline_model_parallel_split_rank=getattr(parallel_config, "pipeline_model_parallel_split_rank", None),
         use_fp8=fp8,
-        init_mpi_proc_group=getattr(parallel_config, "ub_tp_comm_overlap", False),
+        init_mpi_proc_group=getattr(parallel_config, "tp_comm_overlap", False),
         # apex_transformer_log_level=self.cfg.get('apex_transformer_log_level', 30),
     )
 
@@ -92,6 +94,8 @@ def init_model_parallel(model: Optional[nn.Module] = None) -> None:
                 pipeline_model_parallel_size=app_state.pipeline_model_parallel_size,
                 virtual_pipeline_model_parallel_size=app_state.virtual_pipeline_model_parallel_size,
                 pipeline_model_parallel_split_rank=app_state.pipeline_model_parallel_split_rank,
+                context_parallel_size=app_state.context_parallel_size,
+                expert_model_parallel_size=app_state.expert_model_parallel_size,
             )
 
             # assert that fake tp and pp rank match after model parallel init
@@ -123,19 +127,19 @@ def set_model_parallel_attributes(model, parallelism):
     # Right now mcore sub-classes ModelParellelConfig, we should remove that
     # Given Lightning's structure it would be better if parallelism is a different object
     # Since then it can be passed to the Strategy
-
+    # Note: Importing nemo.lightning.pytorch.strategies creates an import cycle.
     from megatron.core.transformer.transformer_config import TransformerConfig
 
+    assert (
+        type(parallelism).__name__ == 'ParallelismConfig'
+    ), f"Expected parallelism config to be of type ParallelismConfig, but got {type(parallelism)}"
     has_mcore_config = isinstance(getattr(model, "config", None), TransformerConfig)
     if has_mcore_config and hasattr(model, "configure_model"):
         config: TransformerConfig = model.config
-        config.tensor_model_parallel_size = parallelism.tensor_model_parallel_size
-        config.pipeline_model_parallel_size = parallelism.pipeline_model_parallel_size
-        config.virtual_pipeline_model_parallel_size = parallelism.virtual_pipeline_model_parallel_size
-        config.context_parallel_size = parallelism.context_parallel_size
-        config.expert_model_parallel_size = parallelism.expert_model_parallel_size
-        config.moe_extended_tp = parallelism.moe_extended_tp
-        config.sequence_parallel = parallelism.sequence_parallel
+        for attr_name in filter(lambda x: not x.startswith('__'), dir(parallelism)):
+            if not hasattr(config, attr_name):
+                continue
+            setattr(config, attr_name, getattr(parallelism, attr_name))
 
         return config
 
@@ -515,4 +519,7 @@ def load_model_state_dict(megatron_parallel, checkpoint: Mapping[str, Any], stri
             elif count > n_nesting:
                 to_remove = "module." * (count - n_nesting)
                 _state_dict[key[len(to_remove) :]] = value
+            else:
+                _state_dict[key] = value
+
         module.load_state_dict(_state_dict, strict=strict)
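For clarity, a self-contained illustration (hypothetical dataclasses, not the real NeMo/Megatron configs) of the attribute-copy pattern now used in set_model_parallel_attributes: every public field of the parallelism config that also exists on the transformer config is copied over, so new parallelism knobs no longer need to be listed by hand.

```python
from dataclasses import dataclass

@dataclass
class ParallelismConfig:            # hypothetical stand-in
    tensor_model_parallel_size: int = 2
    context_parallel_size: int = 1
    unknown_knob: int = 7           # ignored: the target config has no such attribute

@dataclass
class TransformerConfig:            # hypothetical stand-in
    tensor_model_parallel_size: int = 1
    context_parallel_size: int = 1

parallelism, config = ParallelismConfig(), TransformerConfig()
for attr_name in filter(lambda x: not x.startswith('__'), dir(parallelism)):
    if hasattr(config, attr_name):
        setattr(config, attr_name, getattr(parallelism, attr_name))

print(config)  # TransformerConfig(tensor_model_parallel_size=2, context_parallel_size=1)
```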
diff --git a/nemo/lightning/base.py b/nemo/lightning/base.py
index 128ecb661efd..0684cbeee2da 100644
--- a/nemo/lightning/base.py
+++ b/nemo/lightning/base.py
@@ -45,8 +45,11 @@ def teardown(trainer: Trainer, model: Optional[nn.Module] = None) -> None:
     trainer._teardown()  # noqa: SLF001
     if model is not None:
         for obj in gc.get_objects():
-            if torch.is_tensor(obj) and obj.is_cuda:
-                del obj
+            try:
+                if torch.is_tensor(obj) and obj.is_cuda:
+                    del obj
+            except:
+                pass
 
     gc.collect()
     torch.cuda.empty_cache()
diff --git a/nemo/lightning/data.py b/nemo/lightning/data.py
index d83f5ba3b728..a26354fee8be 100644
--- a/nemo/lightning/data.py
+++ b/nemo/lightning/data.py
@@ -8,6 +8,7 @@
 from torch.utils.data import DataLoader, Dataset
 
 
+## TODO: remove? unused
 def create_dataloader(
     dataset: "Dataset", drop_last: bool = True, pad_samples_to_global_batch_size=False, **kwargs
 ) -> DataLoader:
@@ -46,36 +47,73 @@ def setup_microbatch_calculator(
     from nemo.lightning._strategy_lib import NEMO_MEGATRON_MODEL_PARALLEL_APPSTATE_OVERRIDE
     from nemo.utils import AppState
 
+    try:
+        from megatron.core.num_microbatches_calculator import (
+            ConstantNumMicroBatchesCalculator,
+            get_current_global_batch_size,
+            get_micro_batch_size,
+            get_num_microbatches,
+            init_num_microbatches_calculator,
+        )
+
+        MCORE_MB_CALCULATOR = True
+
+    except (ImportError, ModuleNotFoundError):
+        logging.warning("Megatron num_microbatches_calculator not found, using Apex version.")
+        from apex.transformer.microbatches import ConstantNumMicroBatches as ConstantNumMicroBatchesCalculator
+        from apex.transformer.pipeline_parallel.utils import (
+            get_current_global_batch_size,
+            get_micro_batch_size,
+            get_num_microbatches,
+        )
+        from apex.transformer.pipeline_parallel.utils import (
+            setup_microbatch_calculator as init_num_microbatches_calculator,
+        )
+
+        MCORE_MB_CALCULATOR = False
+
     app_state = AppState()
 
     if os.environ.get(NEMO_MEGATRON_MODEL_PARALLEL_APPSTATE_OVERRIDE, "false").lower() == "true":
         init_global_rank = app_state.global_rank
     else:
         init_global_rank = global_rank
-
-    from apex.transformer.microbatches import ConstantNumMicroBatches
-    from apex.transformer.pipeline_parallel.utils import (
-        _GLOBAL_NUM_MICROBATCHES_CALCULATOR,
-        setup_microbatch_calculator,
-    )
-
-    if _GLOBAL_NUM_MICROBATCHES_CALCULATOR is None:
-        setup_microbatch_calculator(
-            rank=init_global_rank,
-            global_batch_size=global_batch_size,
-            micro_batch_size=micro_batch_size,
-            data_parallel_size=app_state.data_parallel_size,
-            rampup_batch_size=rampup_batch_size,
-        )
+    if MCORE_MB_CALCULATOR:
+        from megatron.core.num_microbatches_calculator import _GLOBAL_NUM_MICROBATCHES_CALCULATOR
+
+        if _GLOBAL_NUM_MICROBATCHES_CALCULATOR is None:
+            init_num_microbatches_calculator(
+                rank=init_global_rank,
+                global_batch_size=global_batch_size,
+                micro_batch_size=micro_batch_size,
+                data_parallel_size=app_state.data_parallel_size,
+                rampup_batch_size=rampup_batch_size,
+            )
+        else:
+            if isinstance(_GLOBAL_NUM_MICROBATCHES_CALCULATOR, ConstantNumMicroBatchesCalculator):
+                assert get_current_global_batch_size() == global_batch_size
+                assert get_micro_batch_size() == micro_batch_size
+                assert get_num_microbatches() == global_batch_size // (micro_batch_size * app_state.data_parallel_size)
+            else:
+                raise Exception("Microbatch calculator already initialized.")
     else:
-        if isinstance(_GLOBAL_NUM_MICROBATCHES_CALCULATOR, ConstantNumMicroBatches):
-            assert _GLOBAL_NUM_MICROBATCHES_CALCULATOR.current_global_batch_size == global_batch_size
-            assert _GLOBAL_NUM_MICROBATCHES_CALCULATOR.micro_batch_size == micro_batch_size
-            assert _GLOBAL_NUM_MICROBATCHES_CALCULATOR.num_micro_batches == global_batch_size // (
-                micro_batch_size * app_state.data_parallel_size
+        from apex.transformer.pipeline_parallel.utils import _GLOBAL_NUM_MICROBATCHES_CALCULATOR
+
+        if _GLOBAL_NUM_MICROBATCHES_CALCULATOR is None:
+            init_num_microbatches_calculator(
+                rank=init_global_rank,
+                global_batch_size=global_batch_size,
+                micro_batch_size=micro_batch_size,
+                data_parallel_size=app_state.data_parallel_size,
+                rampup_batch_size=rampup_batch_size,
             )
         else:
-            raise Exception("Microbatch calculator already initialized.")
+            if isinstance(_GLOBAL_NUM_MICROBATCHES_CALCULATOR, ConstantNumMicroBatchesCalculator):
+                assert get_current_global_batch_size() == global_batch_size
+                assert get_micro_batch_size() == micro_batch_size
+                assert get_num_microbatches() == global_batch_size // (micro_batch_size * app_state.data_parallel_size)
+            else:
+                raise Exception("Microbatch calculator already initialized.")
 
 
 def add_megatron_sampler(
@@ -85,10 +123,13 @@ def add_megatron_sampler(
     rampup_batch_size: Optional[List[int]] = None,
     consumed_samples: int = 0,
     dataloader_type: Literal["single", "cyclic"] = "single",
+    drop_last: bool = True,
+    pad_samples_to_global_batch_size: bool = False,
     # data_sharding: bool = False
 ) -> DataLoader:
     from megatron.core import parallel_state
 
+    ## TODO: expose drop_last and pad_samples_to_global_batch_size args
     if dataloader_type == 'single':
         batch_sampler = MegatronPretrainingSampler(
             total_samples=len(dataloader.dataset),
@@ -98,8 +139,8 @@ def add_megatron_sampler(
             rampup_batch_size=rampup_batch_size,
             data_parallel_rank=parallel_state.get_data_parallel_rank(),
             data_parallel_size=parallel_state.get_data_parallel_world_size(),
-            drop_last=getattr(dataloader, "_drop_last", False),
-            pad_samples_to_global_batch_size=getattr(dataloader, "_pad_samples_to_global_batch_size", False),
+            drop_last=drop_last,
+            pad_samples_to_global_batch_size=pad_samples_to_global_batch_size,
         )
     elif dataloader_type == 'cyclic':
         batch_sampler = MegatronPretrainingRandomSampler(
@@ -108,7 +149,7 @@ def add_megatron_sampler(
             micro_batch_size=micro_batch_size,
             data_parallel_rank=parallel_state.get_data_parallel_rank(),
             data_parallel_size=parallel_state.get_data_parallel_world_size(),
-            pad_samples_to_global_batch_size=getattr(dataloader, "_pad_samples_to_global_batch_size", False),
+            drop_last=drop_last,
             # data_sharding=data_sharding
         )
     else:
@@ -124,6 +165,14 @@ def add_megatron_sampler(
     )
 
 
+class WrappedDataLoader(DataLoader):
+    """Wrapper around torch DataLoader which stores the dataloader mode"""
+
+    def __init__(self, mode="train", **dataloader_kwargs):
+        super().__init__(**dataloader_kwargs)
+        self.mode = mode
+
+
 # TODO: Replace this with megatron.core.data.data_samplers after we upgrade
 class BaseMegatronSampler:
     def __init__(
@@ -141,8 +190,6 @@ def __init__(
         # Sanity checks.
         if total_samples <= 0:
             raise RuntimeError(f"no sample to consume: {total_samples}")
-        if consumed_samples >= total_samples:
-            raise RuntimeError(f"no samples left to consume: {consumed_samples}, {total_samples}")
         if micro_batch_size <= 0:
             raise RuntimeError(f"micro_batch_size size must be greater than 0, but {micro_batch_size}")
         if data_parallel_size <= 0:
@@ -197,6 +244,32 @@ def __iter__(self): ...
 
 
 class MegatronPretrainingSampler(BaseMegatronSampler):
+    def __init__(
+        self,
+        total_samples: int,
+        consumed_samples: int,
+        micro_batch_size: int,
+        data_parallel_rank: int,
+        data_parallel_size: int,
+        drop_last: bool = True,
+        global_batch_size: Optional[int] = None,
+        rampup_batch_size: Optional[list] = None,
+        pad_samples_to_global_batch_size: Optional[bool] = False,
+    ):
+        super().__init__(
+            total_samples=total_samples,
+            consumed_samples=consumed_samples,
+            micro_batch_size=micro_batch_size,
+            data_parallel_rank=data_parallel_rank,
+            data_parallel_size=data_parallel_size,
+            drop_last=drop_last,
+            global_batch_size=global_batch_size,
+            rampup_batch_size=rampup_batch_size,
+            pad_samples_to_global_batch_size=pad_samples_to_global_batch_size,
+        )
+        if consumed_samples >= total_samples:
+            raise RuntimeError(f"no samples left to consume: {consumed_samples}, {total_samples}")
+
     def get_start_end_idx(self):
         start_idx = self.data_parallel_rank * self.micro_batch_size
         end_idx = start_idx + self.micro_batch_size
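A short usage sketch for the WrappedDataLoader added above (dataset and batch size are arbitrary): the mode tag is what lets downstream code tell train, validation and test loaders apart.

```python
import torch
from torch.utils.data import TensorDataset
from nemo.lightning.data import WrappedDataLoader

ds = TensorDataset(torch.arange(16))
train_dl = WrappedDataLoader(mode="train", dataset=ds, batch_size=4)
val_dl = WrappedDataLoader(mode="validation", dataset=ds, batch_size=4)
print(train_dl.mode, val_dl.mode)  # train validation
```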
diff --git a/nemo/lightning/fabric/strategies.py b/nemo/lightning/fabric/strategies.py
index a662386a9119..5c2b634ea282 100644
--- a/nemo/lightning/fabric/strategies.py
+++ b/nemo/lightning/fabric/strategies.py
@@ -23,7 +23,6 @@
 from lightning_fabric.plugins.precision import Precision
 from lightning_fabric.strategies import DDPStrategy
 from lightning_fabric.strategies.strategy import _validate_keys_for_strict_loading
-from lightning_fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_1
 from lightning_fabric.utilities.types import _PATH, _Stateful
 from megatron.core.distributed import DistributedDataParallelConfig
 from pytorch_lightning.loops.fetchers import _DataFetcher
@@ -208,7 +207,7 @@ def module_init_context(self, empty_init: Optional[bool] = None) -> ContextManag
         precision_init_ctx = self.precision.module_init_context()
         module_sharded_ctx = self.megatron_context()
         stack = ExitStack()
-        if _TORCH_GREATER_EQUAL_2_1 and empty_init:
+        if empty_init:
             # Materialization happens in `setup`. When modules get wrapped by FSDP, the sequence of operations is:
             # 1) materialize module 2) call `reset_parameters()` 3) shard the module.
             # These operations are applied to each submodule 'bottom up' in the module hierarchy.
diff --git a/nemo/lightning/io/api.py b/nemo/lightning/io/api.py
index cc594b562cff..4d31f020c44a 100644
--- a/nemo/lightning/io/api.py
+++ b/nemo/lightning/io/api.py
@@ -1,11 +1,13 @@
+import json
 from pathlib import Path
+from pydoc import locate
 from typing import Any, Callable, Optional, Type, TypeVar
 
 import fiddle as fdl
 import pytorch_lightning as pl
 from fiddle._src.experimental import serialization
 
-from nemo.lightning.io.mixin import ConnectorMixin, ConnT, ModelConnector
+from nemo.lightning.io.mixin import ConnectorMixin, ConnT, ModelConnector, track_io
 from nemo.lightning.io.pl import TrainerContext
 
 CkptType = TypeVar("CkptType")
@@ -41,6 +43,14 @@ def load(path: Path, output_type: Type[CkptType] = Any) -> CkptType:
     if not _path.is_file():
         raise FileNotFoundError(f"No such file: '{_path}'")
 
+    ## add IO functionality to custom objects present in the json file
+    with open(_path) as f:
+        j = json.load(f)
+        for obj, val in j["objects"].items():
+            clss = ".".join([val["type"]["module"], val["type"]["name"]])
+            if not serialization.find_node_traverser(locate(clss)):
+                track_io(locate(clss))
+
     with open(_path, "rb") as f:
         config = serialization.load_json(f.read())
 
diff --git a/nemo/lightning/io/connector.py b/nemo/lightning/io/connector.py
index 500d0203cfd4..1cd45648ea20 100644
--- a/nemo/lightning/io/connector.py
+++ b/nemo/lightning/io/connector.py
@@ -2,7 +2,7 @@
 import logging
 import os
 import shutil
-from pathlib import Path, PosixPath, WindowsPath
+from pathlib import Path, PosixPath, PurePath, WindowsPath
 from typing import Generic, Optional, Tuple, TypeVar
 
 import pytorch_lightning as pl
@@ -139,7 +139,7 @@ def nemo_setup(self, model: pl.LightningModule, trainer: Optional[pl.Trainer] =
         from nemo.lightning import MegatronStrategy, Trainer
 
         _trainer = trainer or Trainer(
-            devices=1, accelerator="cpu", strategy=MegatronStrategy(store_optimizer_states=False, ddp="pytorch")
+            devices=1, accelerator="cpu", strategy=MegatronStrategy(store_optimizer_states=False)
         )
 
         _trainer.strategy.connect(model)
@@ -160,12 +160,9 @@ def nemo_save(self, output_path: Path, trainer: pl.Trainer) -> None:
             output_path (Path): The path where the model checkpoint will be saved.
             trainer (pl.Trainer): The trainer with the strategy to save the model.
         """
-        _setup_kwargs = {}
-        setup_signature = inspect.signature(trainer.strategy.setup)
-        if 'setup_optimizers' in setup_signature.parameters:
-            _setup_kwargs["setup_optimizers"] = False
-
-        trainer.strategy.setup(trainer, **_setup_kwargs)
+        trainer.strategy._setup_optimizers = False
+        trainer.strategy._init_model_parallel = False
+        trainer.strategy.setup(trainer)
         trainer.save_checkpoint(output_path)
 
     def nemo_load(
@@ -215,6 +212,10 @@ def local_path(self, base_path: Optional[Path] = None) -> Path:
 
             _base = Path(NEMO_MODELS_CACHE)
 
+        # If the user supplied `hf:///path/to/downloaded/my-model/`
+        # then extract the last dir-name (i.e. my-model) and append it to _base
+        if str(self).startswith('/'):
+            return _base / PurePath((str(self))).name
         return _base / str(self).replace("://", "/")
 
     def on_import_ckpt(self, model: pl.LightningModule):
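A hedged sketch of the two `local_path` branches above, using plain strings and made-up paths; `str(self)` is assumed to be the model identifier without the scheme prefix.

```python
from pathlib import Path, PurePath

base = Path("/root/.cache/nemo/models")          # stand-in for NEMO_MODELS_CACHE

# Case 1: user passed an already-downloaded local directory, e.g. hf:///data/ckpts/my-model/
local_src = "/data/ckpts/my-model/"
print(base / PurePath(local_src).name)           # /root/.cache/nemo/models/my-model

# Case 2: a remote identifier such as hf://org/model keeps its structure under the cache
remote_src = "hf://org/model"
print(base / remote_src.replace("://", "/"))     # /root/.cache/nemo/models/hf/org/model
```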
diff --git a/nemo/lightning/io/fdl_torch.py b/nemo/lightning/io/fdl_torch.py
index c74e48e1c411..aa46341a105f 100644
--- a/nemo/lightning/io/fdl_torch.py
+++ b/nemo/lightning/io/fdl_torch.py
@@ -5,7 +5,9 @@
 """
 
 import types
+from functools import partial
 
+import fiddle as fdl
 import libcst as cst
 import torch
 import torch.nn as nn
@@ -110,6 +112,8 @@ def enable():
     def _modified_serialize(self, value, current_path, all_paths=None):
         if isinstance(value, types.BuiltinFunctionType):
             return self._pyref(value, current_path)
+        if isinstance(value, partial):
+            value = fdl.Partial(value.func, *value.args, **value.keywords)
         return self._original_serialize(value, current_path, all_paths)
 
     serialization.Serialization._original_serialize = serialization.Serialization._serialize
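The hook above rewrites `functools.partial` objects into `fdl.Partial` before serialization. The decomposition it relies on is plain stdlib behaviour:

```python
from functools import partial

p = partial(int, "ff", base=16)
print(p.func, p.args, p.keywords)   # <class 'int'> ('ff',) {'base': 16}
# The serializer hook re-packs exactly these three pieces as
# fdl.Partial(p.func, *p.args, **p.keywords) so fiddle can round-trip the callable.
```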
diff --git a/nemo/lightning/io/pl.py b/nemo/lightning/io/pl.py
index 2cadc56e59b4..d0749fbeead7 100644
--- a/nemo/lightning/io/pl.py
+++ b/nemo/lightning/io/pl.py
@@ -5,6 +5,7 @@
 
 import pytorch_lightning as pl
 import torch
+from lightning_fabric.plugins import CheckpointIO
 from lightning_fabric.plugins.io.checkpoint_io import CheckpointIO
 from lightning_fabric.utilities.cloud_io import get_filesystem
 from lightning_fabric.utilities.types import _PATH
@@ -12,10 +13,6 @@
     get_default_load_sharded_strategy,
     get_default_save_sharded_strategy,
 )
-
-# from nemo.utils.callbacks.torch_dist_async import TorchDistAsyncSaveShardedStrategy
-from megatron.core.dist_checkpointing.strategies import tensorstore
-from megatron.core.dist_checkpointing.strategies.async_utils import AsyncCallsQueue, AsyncRequest
 from megatron.core.dist_checkpointing.strategies.base import SaveShardedStrategy
 from megatron.core.dist_checkpointing.strategies.fully_parallel import (
     FullyParallelLoadStrategyWrapper,
@@ -28,7 +25,12 @@
 
 from nemo.lightning.io.capture import IOProtocol
 from nemo.lightning.io.mixin import IOMixin
-from nemo.utils.callbacks.dist_ckpt_io import AsyncCompatibleCheckpointIO
+
+try:
+    from nemo.utils.callbacks.dist_ckpt_io import AsyncCompatibleCheckpointIO
+except ImportError:
+    AsyncCompatibleCheckpointIO = CheckpointIO
+
 
 log = logging.getLogger(__name__)
 
@@ -77,6 +79,7 @@ def __init__(
         torch_dist_multiproc: Optional[int] = None,
         assume_constant_structure: bool = False,
         parallel_save: bool = True,
+        parallel_save_within_dp: bool = False,
         parallel_load: bool = False,
     ):
         self.save_ckpt_format = save_ckpt_format
@@ -85,6 +88,7 @@ def __init__(
         self.torch_dist_multiproc = torch_dist_multiproc
         self.assume_constant_structure = assume_constant_structure
         self.parallel_save = parallel_save
+        self.parallel_save_within_dp = parallel_save_within_dp
         self.parallel_load = parallel_load
 
         self._save_sharded_strategy = None
@@ -161,7 +165,9 @@ def load_checkpoint(
             raise ValueError(f"Distributed checkpoints should be a directory. Found: {path}.")
 
         if self.save_ckpt_format == 'zarr' and self.load_directly_on_device:
-            sharded_strategy = tensorstore.TensorStoreLoadShardedStrategy(load_directly_on_device=True)
+            from megatron.core.dist_checkpointing.strategies.tensorstore import TensorStoreLoadShardedStrategy
+
+            sharded_strategy = TensorStoreLoadShardedStrategy(load_directly_on_device=True)
         else:
             sharded_strategy = None
 
@@ -216,8 +222,11 @@ def _determine_dist_ckpt_save_strategy(self):
             save_strategy.use_cached_ckpt_structure = self.assume_constant_structure
 
         if self.parallel_save:
+            parallelization_group = (
+                get_data_parallel_group(with_context_parallel=True) if self.parallel_save_within_dp else None
+            )
             save_strategy = FullyParallelSaveStrategyWrapper(
-                save_strategy, get_data_parallel_group(with_context_parallel=True), self.assume_constant_structure
+                save_strategy, parallelization_group, self.assume_constant_structure
             )
 
         logging.info(f'Using {save_strategy} dist-ckpt save strategy.')
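A hedged usage sketch of the new `parallel_save_within_dp` flag; the class name `MegatronCheckpointIO` and the import path are assumed from this file and may differ.

```python
# Assumed import path / class name -- adjust to the actual definitions in nemo/lightning/io/pl.py
from nemo.lightning.io.pl import MegatronCheckpointIO

checkpoint_io = MegatronCheckpointIO(
    save_ckpt_format="torch_dist",
    parallel_save=True,
    parallel_save_within_dp=True,   # parallelize the save within the DP (+CP) group only
    parallel_load=False,
)
```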
diff --git a/nemo/lightning/io/state.py b/nemo/lightning/io/state.py
index b69fed9d0f4f..9fd81a960358 100644
--- a/nemo/lightning/io/state.py
+++ b/nemo/lightning/io/state.py
@@ -4,6 +4,7 @@
 from typing import Any, Callable, Dict, Generic, List, Optional, Tuple, TypeVar, Union, overload
 
 import numpy as np
+import torch
 from torch import nn
 
 SourceModuleT = TypeVar("SourceModuleT", bound=nn.Module)
@@ -19,11 +20,12 @@ class TransformCTX:
     target_state: dict
 
 
+@torch.no_grad
 def apply_transforms(
     source: nn.Module,
     target: TargetModuleT,
     mapping: Dict[str, str],
-    transforms: Optional[List[Callable[[TransformCTX], TransformCTX]]] = None,
+    transforms: Optional[List[Callable[[TransformCTX], TransformCTX]]] = [],
 ) -> TargetModuleT:
     """
     Applies a series of transformations to adapt the state dictionary of a source module to
@@ -101,9 +103,8 @@ def scale_weights(ctx):
     for key, val in mapping.items():
         ctx = StateDictTransform(key, val)(ctx)
 
-    if transforms:
-        for transform in transforms:
-            ctx = transform(ctx)
+    for transform in transforms:
+        ctx = transform(ctx)
 
     _params: Dict[str, nn.Parameter] = {}
     for name, param in _target.named_parameters():
@@ -144,9 +145,9 @@ def scale_weights(ctx):
 
         _module.register_buffer(_key, val)
 
-    keys = [name for name in list(target_state.keys()) if not name.endswith("_extra_state")]
+    keys = list(filter(lambda x: x is not None and not x.endswith("_extra_state"), target_state.keys()))
     if len(keys) != 0:
-        raise RuntimeError(f"Additional keys: {target_state.keys()} in checkpoint but not in model.")
+        raise RuntimeError(f"Additional keys: {keys} in checkpoint but not in model.")
 
     # TODO: Is this correct?
     # for key in target.state_dict():
@@ -165,7 +166,7 @@ def scale_weights(ctx):
 
 
 def _default_transform(inp):
-    return inp.float()
+    return inp
 
 
 class StateDictTransform(Generic[F]):
@@ -324,7 +325,7 @@ def _match_keys(keys: List[str], pattern: str) -> np.ndarray:
     regex_pattern = re.compile("^" + pattern.replace("*", "(.*)") + "$")
     wildcard_matches = [[] for _ in range(pattern.count("*"))]
 
-    for key in keys:
+    for key in filter(lambda x: x is not None, keys):
         match = regex_pattern.match(key)
         if match:
             for i, group in enumerate(match.groups()):
@@ -342,7 +343,7 @@ def _match_keys(keys: List[str], pattern: str) -> np.ndarray:
     output_array = np.empty(shape, dtype=object)
 
     # Populate the array with the keys, now that we have the correct shape and ordering
-    for key in keys:
+    for key in filter(lambda x: x is not None, keys):
         match = regex_pattern.match(key)
         if match:
             # Convert match groups to indices based on their position in wildcard_matches
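The wildcard matching used by `_match_keys` above boils down to turning each `*` into a capture group; a tiny standalone illustration with made-up parameter names:

```python
import re

pattern = "decoder.layers.*.mlp.linear_fc1.weight"
regex = re.compile("^" + pattern.replace("*", "(.*)") + "$")

keys = [
    "decoder.layers.0.mlp.linear_fc1.weight",
    "decoder.layers.1.mlp.linear_fc1.weight",
    "decoder.layers.0.self_attention.linear_qkv.weight",
]
for key in keys:
    match = regex.match(key)
    if match:
        print(key, "-> wildcard groups:", match.groups())
# only the two linear_fc1 keys match, capturing ('0',) and ('1',) as the layer indices
```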
diff --git a/nemo/lightning/megatron_parallel.py b/nemo/lightning/megatron_parallel.py
index 2f2308717004..5b2ac7ff1adb 100644
--- a/nemo/lightning/megatron_parallel.py
+++ b/nemo/lightning/megatron_parallel.py
@@ -12,7 +12,6 @@
     Iterable,
     Iterator,
     List,
-    Mapping,
     Optional,
     Protocol,
     Sequence,
@@ -25,9 +24,11 @@
 
 import torch
 import torch.distributed
+from megatron.core import parallel_state
 from megatron.core.distributed import DistributedDataParallel as McoreDDP
 from megatron.core.distributed import DistributedDataParallelConfig
 from megatron.core.transformer.transformer_config import TransformerConfig
+from pytorch_lightning.utilities import move_data_to_device
 from torch import Tensor, nn
 from typing_extensions import override
 
@@ -43,15 +44,35 @@ def convert_output(self, output: torch.Tensor) -> torch.Tensor: ...
 
 
 def default_data_step(dataloader_iter: Iterator[DataT]) -> DataT:
+    """
+    Moves the data to a device.
+
+    In this case we unpack the dataloader iterator. There may be a wrapper around the dataloader
+    iterator, added here: https://github.com/NVIDIA/NeMo/blob/main/nemo/lightning/fabric/strategies.py#L441.
+
+    This will not subset the data for you when using context parallel, so please override this function
+    if you want to use context parallel.
+
+    Examples:
+        If the dataloader_iter returns: [Tuple[, , ]] -> move to device
+        If the dataloader_iter returns: [, ] -> move to device
+
+    Returns:
+        DataT: The data moved to the device.
+    """
+    if parallel_state.get_context_parallel_world_size() > 1:
+        raise ValueError(
+            "Default data step is being used in a context parallel environment. "
+            "Please define your own data step that appropriately slices the data for context parallel."
+        )
+
     batch = next(dataloader_iter)
 
+    # If it's wrapped in a tuple, unpack it.
     if isinstance(batch, tuple) and len(batch) == 3:
         batch = batch[0]
 
-    if isinstance(batch, dict):
-        batch = {k: v.cuda(non_blocking=True) for k, v in batch.items()}
-
-    return batch
+    return move_data_to_device(batch, torch.cuda.current_device())
 
 
 def default_forward_step(model: nn.Module, batch, *args, **kwargs) -> torch.Tensor:
@@ -95,6 +116,12 @@ class MegatronParallel(nn.ModuleList, Generic[ModelT]):
             forward pass of a model.
         loss_reduction (Optional[Callable[[nn.Module], MegatronLossReduction]]): An optional
             function that defines how the loss is reduced.
+        vp_size (Optional[int]): Virtual pipeline parallel size.
+        ddp_config (Optional[DistributedDataParallelConfig]): An instance of Megatron core's
+            DistributedDataParallelConfig which controls the Megatron DDP configuration.
+        cpu (bool): Whether the model should reside on CPU.
+        convert_module_fn (Optional[Callable[[ModelT], nn.Module]]): An optional function applied to
+            each module of the pipeline after initialization.
 
     Examples
     --------
@@ -129,8 +156,8 @@ def __init__(
         cpu: bool = False,
         convert_module_fn: Optional[Callable[[ModelT], nn.Module]] = None,
     ) -> None:
-        from apex.transformer.tensor_parallel.layers import set_defaults_if_not_set_tensor_model_parallel_attributes
         from megatron.core import parallel_state
+        from megatron.core.tensor_parallel import set_defaults_if_not_set_tensor_model_parallel_attributes
 
         _pipeline: List[nn.Module]
         if isinstance(pipeline, nn.ModuleList):
@@ -152,67 +179,15 @@ def __init__(
                         _model.configure_model()
                     _pipeline.append(_model)
 
-        if convert_module_fn:
-            for i in range(len(_pipeline)):
-                _pipeline[i] = convert_module_fn(_pipeline[i])
-
-        if isinstance(ddp_config, DistributedDataParallelConfig):
-            for model_chunk_idx, model_chunk in enumerate(_pipeline):
-                module = model_chunk.module
-
-                ddp = DDP(
-                    module.config,
-                    ddp_config,
-                    module,
-                    data_parallel_group=parallel_state.get_data_parallel_group(with_context_parallel=True),
-                    expert_data_parallel_group=parallel_state.get_data_modulo_expert_parallel_group(),
-                    # Turn off bucketing for model_chunk 2 onwards, since communication for these
-                    # model chunks is overlapped with compute anyway.
-                    disable_bucketing=(model_chunk_idx > 0),
-                )
-                model_chunk.module = ddp
-                model_chunk.buffers = ddp.buffers  # We need to do this explicitly since this is a attr pytorch uses
-                model_chunk.__class__.__getattr__ = getattr_proxy  # type: ignore
-
-            # param_sync_func is set in nemo.lightning.pytorch.optim.megatron
-            no_sync_func, grad_sync_func = extract_ddp_funcs(ddp_config, _pipeline)
-            for module in _pipeline:
-                module.config.no_sync_func = no_sync_func
-                module.config.grad_sync_func = grad_sync_func
-
-        for i, model_module in enumerate(_pipeline):
-            if not cpu:
-                model_module.cuda(torch.cuda.current_device())
-
-            for param in model_module.parameters():
-                set_defaults_if_not_set_tensor_model_parallel_attributes(param)
-
-            if hasattr(model_module, "configure_model"):
-                if not hasattr(model_module, "set_input_tensor"):
-                    if hasattr(model_module.module, "set_input_tensor"):
-                        model_module.set_input_tensor = model_module.module.set_input_tensor
-                    else:
-                        # TODO: What to do here?
-                        pass
-
-            # Print number of parameters.
-            if parallel_state.model_parallel_is_initialized() and parallel_state.get_data_parallel_rank() == 0:
-                from nemo.utils import logging
-
-                msg = (
-                    f" > number of parameters on (tensor, pipeline) model parallel rank "
-                    f"({parallel_state.get_tensor_model_parallel_rank()}, {parallel_state.get_pipeline_model_parallel_rank()}): "
-                    f"{_calc_number_of_params(_pipeline)}"
-                )
-                logging.info(msg)
-
         super().__init__(_pipeline)
         self.precision_plugin = precision_plugin
+        self._cpu = cpu
         self.callbacks = callbacks or CallbackConnector()
         self.data_step = data_step or default_data_step
         self.forward_step = forward_step or default_forward_step
         self.loss_reduction: MegatronLossReduction = loss_reduction
         self.ddp_config = ddp_config
+        self.convert_module_fn = convert_module_fn
 
     def forward(
         self,
@@ -475,6 +450,82 @@ def infer_num_microbatches(self, data: Union[DataT, Iterator[DataT], List[Iterat
 
         raise ValueError("Cannot infer `num_microbatches` from data, please specify it manually")
 
+    def init_model_parallel(self):
+        from megatron.core import parallel_state
+        from megatron.core.tensor_parallel.layers import set_defaults_if_not_set_tensor_model_parallel_attributes
+
+        for model_module in self:
+            if not self._cpu:
+                model_module.cuda(torch.cuda.current_device())
+
+            for param in model_module.parameters():
+                set_defaults_if_not_set_tensor_model_parallel_attributes(param)
+
+            if hasattr(model_module, "configure_model"):
+                if not hasattr(model_module, "set_input_tensor"):
+                    if hasattr(model_module.module, "set_input_tensor"):
+                        model_module.set_input_tensor = model_module.module.set_input_tensor
+                    else:
+                        # TODO: What to do here?
+                        pass
+
+            # Print number of parameters.
+            if parallel_state.model_parallel_is_initialized() and parallel_state.get_data_parallel_rank() == 0:
+                from nemo.utils import logging
+
+                num_params = _calc_number_of_params(list(self))
+                num_trainable_params = _calc_number_of_trainable_params(list(self))
+
+                msg = (
+                    f" > number of parameters on (tensor, pipeline) model parallel rank "
+                    f"({parallel_state.get_tensor_model_parallel_rank()}, {parallel_state.get_pipeline_model_parallel_rank()}): "
+                    f"{num_params}"
+                )
+                logging.info(msg)
+
+                if num_params != num_trainable_params:
+                    logging.info(
+                        f" > number of trainable parameters: {num_trainable_params} ({num_trainable_params / num_params:.2%} of total)"
+                    )
+
+        if self.convert_module_fn:
+            self.apply_convert_module_fn()
+
+        self.init_ddp()
+
+    def apply_convert_module_fn(self):
+        for i in range(len(self)):
+            self[i] = self.convert_module_fn(self[i])
+
+    def init_ddp(self):
+        if not isinstance(self.ddp_config, DistributedDataParallelConfig):
+            return
+
+        from megatron.core import parallel_state
+
+        for model_chunk_idx, model_chunk in enumerate(self):
+            module = model_chunk.module
+
+            ddp = DDP(
+                module.config,
+                self.ddp_config,
+                module,
+                data_parallel_group=parallel_state.get_data_parallel_group(with_context_parallel=True),
+                expert_data_parallel_group=parallel_state.get_data_modulo_expert_parallel_group(),
+                # Turn off bucketing for model_chunk 2 onwards, since communication for these
+                # model chunks is overlapped with compute anyway.
+                disable_bucketing=(model_chunk_idx > 0),
+            )
+            model_chunk.module = ddp
+            model_chunk.buffers = ddp.buffers  # We need to do this explicitly since this is a attr pytorch uses
+            model_chunk.__class__.__getattr__ = getattr_proxy  # type: ignore
+
+        # param_sync_func is set in nemo.lightning.pytorch.optim.megatron
+        no_sync_func, grad_sync_func = extract_ddp_funcs(self.ddp_config, self)
+        for module in self:
+            module.config.no_sync_func = no_sync_func
+            module.config.grad_sync_func = grad_sync_func
+
     def _build_context(self, context: Dict[str, Any]) -> Dict[str, Any]:
         if "self" in context:
             del context["self"]
@@ -565,21 +616,29 @@ def forward_backward_func(self) -> "MegatronStepProtocol":
 
     @override
     def __getattr__(self, item: Any) -> Any:
-        if len(self) == 0:
-            return super().__getattr__(item)
-
         try:
-            # __getattr__ gets called as a last resort if the attribute does not exist
-            # call nn.Module's implementation first
+            # First, try to get the attribute from the superclass (nn.ModuleList)
             return super().__getattr__(item)
         except AttributeError:
-            # If the attribute is not available on the _FabricModule wrapper, redirect to the wrapped nn.Module
-            attr = getattr(self._modules[self._get_abs_string_index(0)], item)
+            # If not found in superclass, check if we have any modules
+            if len(self) == 0:
+                raise AttributeError(
+                    f"'{self.__class__.__name__}' object has no attribute '{item}' and contains no modules"
+                )
 
-            return attr
+            # Try to get it from the first module
+            try:
+                return getattr(self._modules[self._get_abs_string_index(0)], item)
+            except AttributeError:
+                raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{item}'")
 
 
 class _ModuleStepFunction:
+    """
+    This class acts as a bridge between Megatron core's lower-level functional API and PTL's object-oriented API,
+        making it possible to use PTL-compatible functions in Megatron core.
+    """
+
     def __init__(self, name: str, is_property: bool = False, includes_self: bool = False):
         self.name = name
         self.is_property = is_property
@@ -608,7 +667,9 @@ def wrapped(self, *args):
 def getattr_proxy(self, item: Any) -> Any:
     try:
         return super(self.__class__, self).__getattr__(item)
-    except AttributeError:
+    except AttributeError as e:
+        if item == 'module':  ## this is a hacky WAR and may cause misleading error messages
+            raise e
         try:
             return getattr(self.module, item)
         except AttributeError:
@@ -915,6 +976,12 @@ def _calc_number_of_params(model: List[nn.Module]) -> int:
     return sum([sum([p.nelement() for p in model_module.parameters()]) for model_module in model])
 
 
+def _calc_number_of_trainable_params(model: List[nn.Module]) -> int:
+    assert isinstance(model, list)
+
+    return sum([sum([p.numel() for p in model_module.parameters() if p.requires_grad]) for model_module in model])
+
+
 def is_list_of_iterators(var) -> bool:
     if not isinstance(var, list):
         return False
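Since `default_data_step` in this file now refuses to run under context parallelism, users are expected to pass their own `data_step`. A minimal hedged sketch of such an override (the context-parallel slicing itself is intentionally omitted):

```python
import torch
from pytorch_lightning.utilities import move_data_to_device


def my_data_step(dataloader_iter):
    batch = next(dataloader_iter)
    # PTL's fetcher may wrap the batch as (batch, batch_idx, dataloader_idx)
    if isinstance(batch, tuple) and len(batch) == 3:
        batch = batch[0]
    # ... slice `batch` along the sequence dimension for this CP rank here ...
    return move_data_to_device(batch, torch.cuda.current_device())

# passed as MegatronParallel(..., data_step=my_data_step) or via the equivalent model hook
```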
diff --git a/nemo/lightning/nemo_logger.py b/nemo/lightning/nemo_logger.py
index 5ed783fdbefe..1df40cf659ae 100644
--- a/nemo/lightning/nemo_logger.py
+++ b/nemo/lightning/nemo_logger.py
@@ -7,7 +7,6 @@
 
 import lightning_fabric as fl
 import pytorch_lightning as pl
-from fiddle._src.experimental import serialization
 from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint as PTLModelCheckpoint
 from pytorch_lightning.loggers import Logger, TensorBoardLogger, WandbLogger
 
@@ -30,7 +29,12 @@ class NeMoLogger(IOMixin):
         log_local_rank_0_only (bool): Log only on local rank 0.
         log_global_rank_0_only (bool): Log only on global rank 0.
         files_to_copy (Optional[List[str]]): List of files to copy to log directory.
-        update_logger_directory (bool): Whether to update logger directory.
+        update_logger_directory (bool): Whether to update logger directory to write to `exp_dir`.
+            If True, the `save_dir` passed to the logger will be treated as a relative path and
+            the logger will be reconfigured to write to `exp_dir / save_dir`. This ensures that
+            all output from an experiment is written to a common directory. If False, the logger's
+            save_dir will not be overwritten. This argument applies only to TensorBoardLogger and
+            WandbLogger instances.
         ckpt (Optional[ModelCheckpoint]): Model checkpoint callback.
     """
 
@@ -73,30 +77,45 @@ def setup(self, trainer: Union[pl.Trainer, fl.Fabric], resume_if_exists: bool =
         logging.rank = self.global_rank
 
         if self.explicit_log_dir and isinstance(trainer, pl.Trainer):  # If explicit log_dir was passed, short circuit
-            return check_explicit_log_dir(trainer, self.explicit_log_dir, self.dir, self.name, self.version)
-
-        # Default dir to ./nemo_experiments if None was passed
-        _dir = self.dir
-        if self.dir is None:
-            _dir = str(Path.cwd() / 'nemo_experiments')
-
-        if not self.name:
-            self.name = "default"
-
-        version = self.version or os.environ.get(NEMO_ENV_VARNAME_VERSION, None)
-        if is_global_rank_zero():
-            if self.use_datetime_version:
-                version = time.strftime('%Y-%m-%d_%H-%M-%S')
-        if resume_if_exists:
-            logging.warning(
-                "No version folders would be created under the log folder as 'resume_if_exists' is enabled."
-            )
-            version = None
-        if version:
+            if trainer.logger is not None and not self.update_logger_directory:
+                logging.warning(
+                    f"nemo logger received explicit_log_dir: {self.explicit_log_dir} and the pytorch lightning trainer "
+                    f"that was passed to nemo_logger contains a logger, but update_logger_directory is False. This means "
+                    f"that the trainer's logger directory may not match the explicit_log_dir."
+                )
+            if self.dir or self.version:
+                logging.error(
+                    f"nemo logger received explicit_log_dir: {self.explicit_log_dir} and at least one of dir: {self.dir}, "
+                    f"or version: {self.version}. Please note that dir, name, and version will be ignored."
+                )
+            if is_global_rank_zero() and Path(self.explicit_log_dir).exists():
+                logging.warning(f"NeMoLogger is logging to {self.explicit_log_dir}, but it already exists.")
+            log_dir, _dir, self.name, version = Path(self.explicit_log_dir), str(self.explicit_log_dir), "", ""
+
+        else:
+            # Default dir to ./nemo_experiments if None was passed
+            _dir = self.dir
+            if self.dir is None:
+                _dir = str(Path.cwd() / 'nemo_experiments')
+
+            if not self.name:
+                self.name = "default"
+
+            version = self.version or os.environ.get(NEMO_ENV_VARNAME_VERSION, None)
             if is_global_rank_zero():
-                os.environ[NEMO_ENV_VARNAME_VERSION] = version
+                if self.use_datetime_version:
+                    version = time.strftime('%Y-%m-%d_%H-%M-%S')
+            if resume_if_exists:
+                logging.warning(
+                    "No version folders would be created under the log folder as 'resume_if_exists' is enabled."
+                )
+                version = None
+            if version:
+                if is_global_rank_zero():
+                    os.environ[NEMO_ENV_VARNAME_VERSION] = version
+
+            log_dir = Path(_dir) / Path(str(self.name)) / Path("" if version is None else str(version))
 
-        log_dir = Path(_dir) / Path(str(self.name)) / Path("" if version is None else str(version))
         # update app_state with log_dir, exp_dir, etc
         app_state = AppState()
         app_state.log_dir = log_dir
@@ -124,25 +143,29 @@ def _setup_trainer_loggers(self, trainer, dir, version):
         loggers = [self.tensorboard, self.wandb, *self.extra_loggers]
         loggers = [logger for logger in loggers if logger is not None]
 
-        if self.update_logger_directory and self.wandb:
-            self.wandb._save_dir = dir
-            self.wandb._wandb_init["dir"] = dir
-            self.wandb._wandb_init["name"] = self.name
-            self.wandb._name = self.name
-
         if loggers:
             if trainer.logger is not None and not self.tensorboard:
                 loggers = [trainer.logger] + loggers
             trainer._logger_connector.configure_logger(loggers)
 
-        if trainer.logger is not None:
-            trainer.logger._version = version or ""
-            if self.update_logger_directory:
-                logging.warning(
-                    f'"update_logger_directory" is True. Overwriting logger "save_dir" to {dir} and "name" to {self.name}'
-                )
-                trainer.logger._root_dir = dir
-                trainer.logger._name = self.name
+        if self.update_logger_directory:
+            for logger in trainer.loggers:
+                if isinstance(logger, TensorBoardLogger):
+                    logger._version = version or ""
+                    logger._root_dir = Path(dir) / logger.save_dir
+                    trainer.logger._name = self.name
+                    logging.warning(
+                        f'"update_logger_directory" is True. Overwriting tensorboard logger "save_dir" to {logger._root_dir}'
+                    )
+                elif isinstance(logger, WandbLogger):
+                    logger._id = version or ""
+                    logger._save_dir = Path(dir) / logger.save_dir
+                    logger._wandb_init["dir"] = Path(dir) / logger.save_dir
+                    logger._wandb_init["name"] = self.name
+                    logger._name = self.name
+                    logging.warning(
+                        f'"update_logger_directory" is True. Overwriting wandb logger "save_dir" to {logger._save_dir}'
+                    )
 
     def _setup_trainer_model_checkpoint(self, trainer, log_dir, ckpt=None):
         if ckpt:
@@ -187,10 +210,15 @@ def _setup_trainer_model_checkpoint(self, trainer, log_dir, ckpt=None):
                 ModelCheckpoint.CHECKPOINT_NAME_LAST = callback.filename + '-last'
 
     def _handle_task_config(self, task_config, log_dir):
-        task_config.save_config_img(log_dir / "task.png")
-        task_json = serialization.dump_json(task_config)
-        with open(log_dir / "task.json", "w") as f:
-            f.write(task_json)
+        try:
+            from fiddle._src.experimental import serialization
+
+            task_config.save_config_img(log_dir / "task.png")
+            task_json = serialization.dump_json(task_config)
+            with open(log_dir / "task.json", "w") as f:
+                f.write(task_json)
+        except Exception as e:
+            logging.warning(f'Saving task config failed: {e}. Skipping saving')
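A hedged usage sketch of the reworked `update_logger_directory` behaviour; the constructor arguments shown (`name`, `update_logger_directory`, `tensorboard`) are inferred from this file, and the import path is an assumption.

```python
from pytorch_lightning.loggers import TensorBoardLogger

from nemo.lightning import NeMoLogger   # assumed import path

nemo_logger = NeMoLogger(
    name="my_experiment",
    update_logger_directory=True,        # TensorBoard's save_dir is re-rooted under the exp dir
    tensorboard=TensorBoardLogger(save_dir="tb_logs"),
)
# with update_logger_directory=False the logger keeps whatever save_dir it was built with
```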
 
     def _setup_file_logging(self, log_dir):
         """Set up file logging based on rank settings."""
diff --git a/nemo/lightning/pytorch/callbacks/model_checkpoint.py b/nemo/lightning/pytorch/callbacks/model_checkpoint.py
index 83e750ff281e..a2068468a0f7 100644
--- a/nemo/lightning/pytorch/callbacks/model_checkpoint.py
+++ b/nemo/lightning/pytorch/callbacks/model_checkpoint.py
@@ -17,7 +17,7 @@
 import shutil
 from datetime import timedelta
 from pathlib import Path
-from typing import Any, Dict, Iterable, Optional, Union
+from typing import Any, Dict, Iterable, List, Optional, Union
 
 import pytorch_lightning
 import torch
@@ -26,15 +26,33 @@
 from pytorch_lightning.callbacks.model_checkpoint import _is_local_file_protocol
 from pytorch_lightning.utilities import rank_zero_info
 
-from nemo.lightning.io.mixin import IOMixin
 from nemo.lightning.io.pl import TrainerContext
 from nemo.utils import logging
 from nemo.utils.app_state import AppState
-from nemo.utils.callbacks.dist_ckpt_io import AsyncFinalizableCheckpointIO
 from nemo.utils.model_utils import ckpt_to_dir
 
 
-class ModelCheckpoint(PTLModelCheckpoint, IOMixin):
+class ModelCheckpoint(PTLModelCheckpoint):
+    """Light wrapper around Lightning's ModelCheckpoint to force a saved checkpoint on train_end.
+    Adds support for asynchronous checkpointing and provides some additional logic to clean up invalid checkpoints.
+    Args:
+        monitor: Metric to monitor when saving top-k checkpoints.
+        verbose: Verbosity mode.
+        save_last: When ``True``, saves a `*-last` copy whenever a checkpoint file gets saved.
+        save_top_k: Save the best ``save_top_k`` checkpoints according to ``monitor``.
+        save_weights_only:  if ``True``, then only the model's weights will be saved.
+        mode: One of {min, max}. Whether the objective is to minimize or maximize the monitored quantity.
+        every_n_epochs: Number of epochs between checkpoints.
+        every_n_train_steps: Number of train steps between checkpoints.
+        train_time_interval: Checkpoints are monitored at the specified time interval. Not to be used with
+            ``every_n_epochs`` or ``every_n_train_steps``.
+        save_best_model: When ``True``, reloads and saves the best checkpoint.
+        save_on_train_epoch_end: Whether to run checkpointing at the end of the training epoch.
+        enable_nemo_ckpt_io: Whether to dump the current model state, including the
+            config file, to allow for reproducibility of experiments.
+        async_save: Whether to enable asynchronous checkpointing.
+        try_restore_best_ckpt: Whether to restore the best model path.
+    """
 
     UNFINISHED_CHECKPOINT_SUFFIX = "-unfinished"
 
@@ -52,14 +70,12 @@ def __init__(
         save_best_model: bool = False,
         save_on_train_epoch_end: Optional[bool] = False,  # Save after training, not after validation
         enable_nemo_ckpt_io: bool = True,
-        async_save: bool = False,
         try_restore_best_ckpt: bool = True,
         **kwargs,
     ):
         self.save_best_model = save_best_model
         self.previous_best_path = ""
         self.enable_nemo_ckpt_io = enable_nemo_ckpt_io
-        self.async_save = async_save
         # Checkpoints which removal is deferred until async save is done.
         # Each element of `deferred_ckpts_to_remove` is a growing list
         # that `self._remove_checkpoint` adds to. Once `self._save_checkpoint`
@@ -167,7 +183,7 @@ def nemo_topk_check_previous_run(self):
             if index != len(self.monitor):
                 match = re.search('[A-z]', checkpoint[index:])
                 if match:
-                    value = checkpoint[index : index + match.start() - 1]  # -1 due to separator hypen
+                    value = checkpoint[index : index + match.start() - 1]  # -1 due to separator hyphen
                     self.best_k_models[checkpoint] = float(value)
         if len(self.best_k_models) < 1:
             return  # No saved checkpoints yet
@@ -222,7 +238,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
         super().load_state_dict(state_dict)
         self._remove_invalid_entries_from_topk()
 
-    def setup(self, *args, **kwargs) -> None:
+    def setup(self, trainer, *args, **kwargs) -> None:
         from nemo.utils.get_rank import is_global_rank_zero
 
         if is_global_rank_zero():
@@ -231,7 +247,9 @@ def setup(self, *args, **kwargs) -> None:
         # Ensure that all ranks continue with unfinished checkpoints removed
         if torch.distributed.is_initialized():
             torch.distributed.barrier()
-        super().setup(*args, **kwargs)
+
+        self.async_save = getattr(trainer.strategy, "async_save", False)
+        super().setup(trainer, *args, **kwargs)
 
     def on_save_checkpoint(self, trainer, pl_module, checkpoint):
         output = super().on_save_checkpoint(trainer, pl_module, checkpoint)
@@ -381,6 +399,8 @@ def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str)
         self.set_checkpoint_unfinished_marker(filepath, barrier_after=True)
         ema_callback = self._ema_callback(trainer)
 
+        self._last_global_step_saved = trainer.global_step
+
         if ema_callback is not None:
             if self.async_save:
                 raise ValueError('async_save with EMA not supported')
@@ -401,6 +421,8 @@ def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str)
             finalize_fn = self._get_finalize_save_checkpoint_callback(trainer, filepath, trainer.global_step)
             if self.async_save:
                 checkpoint_io = trainer.strategy.checkpoint_io
+                from nemo.utils.callbacks.dist_ckpt_io import AsyncFinalizableCheckpointIO
+
                 if not isinstance(checkpoint_io, AsyncFinalizableCheckpointIO):
                     raise ValueError('Async save requires async compatible CheckpointIO')
                 storage_options = dict(finalize_fn=finalize_fn)
@@ -409,6 +431,12 @@ def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str)
             else:
                 storage_options = None
             trainer.save_checkpoint(filepath, self.save_weights_only, storage_options=storage_options)
+
+            ## NOTE: saving the context always happens synchronously
+            from nemo.utils.get_rank import is_global_rank_zero
+
+            if self.enable_nemo_ckpt_io and is_global_rank_zero():
+                TrainerContext.from_trainer(trainer).io_dump(ckpt_to_dir(filepath))
             if self.async_save:
                 logging.info(f'Scheduled async checkpoint save for {filepath}')
             else:
@@ -421,14 +449,8 @@ def _get_finalize_save_checkpoint_callback(
 
         def _cb():
             logging.debug(f'Finalize callback called for step {global_step}, filepath {filepath}')
-            self._last_global_step_saved = global_step
             self._last_checkpoint_saved = filepath
 
-            from nemo.utils.get_rank import is_global_rank_zero
-
-            if self.enable_nemo_ckpt_io and is_global_rank_zero():
-                TrainerContext.from_trainer(trainer).io_dump(ckpt_to_dir(filepath))
-
             # notify loggers
             if trainer.is_global_zero:
                 for logger in trainer.loggers:
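`async_save` is no longer a `ModelCheckpoint` argument; `setup()` above reads it from the strategy. A hedged sketch of the resulting wiring (the `async_save` attribute on the strategy is an assumption implied by the `getattr` above; the other arguments are standard PTL `ModelCheckpoint` kwargs plus `enable_nemo_ckpt_io` from this file):

```python
from nemo.lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",
    save_top_k=2,
    every_n_train_steps=100,
    enable_nemo_ckpt_io=True,   # TrainerContext is dumped synchronously on rank 0
)
# asynchronous saving is now controlled by the strategy, e.g. a strategy constructed
# with async_save=True; the callback picks that flag up in setup() via getattr.
```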
diff --git a/nemo/lightning/pytorch/callbacks/model_transform.py b/nemo/lightning/pytorch/callbacks/model_transform.py
index 68b3db16f473..5d48851843fc 100644
--- a/nemo/lightning/pytorch/callbacks/model_transform.py
+++ b/nemo/lightning/pytorch/callbacks/model_transform.py
@@ -4,11 +4,10 @@
 import pytorch_lightning as pl
 from torch import nn
 
-from nemo.lightning.io.mixin import IOMixin
 from nemo.utils import logging
 
 
-class ModelTransform(pl.Callback, IOMixin):
+class ModelTransform(pl.Callback):
     """
     A PyTorch Lightning callback that applies a model transformation function at the start of fitting or validation.
 
@@ -63,9 +62,15 @@ def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: s
     def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
         self._maybe_apply_transform(trainer)
 
+    def on_validation_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        self._maybe_apply_transform(trainer)
+
     def _maybe_apply_transform(self, trainer):
         if self._needs_to_call:
-            self.model_transform(trainer.model)
+            self.apply_transform(trainer)
+
+    def apply_transform(self, trainer):
+        self.model_transform(trainer.model)
 
     @property
     def _needs_to_call(self) -> bool:
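For context, a hedged sketch of what a `model_transform` callable might look like; the freezing logic is purely illustrative and not tied to any specific NeMo API.

```python
import torch.nn as nn


def freeze_all_but_head(model: nn.Module) -> nn.Module:
    # illustrative transform: train only parameters whose name contains "head"
    for name, param in model.named_parameters():
        param.requires_grad = "head" in name
    return model


model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
freeze_all_but_head(model)   # with no "head" in any name, every parameter ends up frozen
```

With the callback above, such a function is applied once via `apply_transform(trainer)` at the start of fitting or validation.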
diff --git a/nemo/lightning/pytorch/callbacks/nsys.py b/nemo/lightning/pytorch/callbacks/nsys.py
index d24d7fd974be..9848fdb2b8fd 100644
--- a/nemo/lightning/pytorch/callbacks/nsys.py
+++ b/nemo/lightning/pytorch/callbacks/nsys.py
@@ -3,12 +3,11 @@
 import torch
 from pytorch_lightning.callbacks.callback import Callback
 
-from nemo.lightning.io.mixin import IOMixin
 from nemo.utils import logging
 from nemo.utils.get_rank import get_rank
 
 
-class NsysCallback(Callback, IOMixin):
+class NsysCallback(Callback):
     """
     A PyTorch Lightning callback for NVIDIA Nsight Systems (Nsys) profiling.
 
diff --git a/nemo/lightning/pytorch/callbacks/peft.py b/nemo/lightning/pytorch/callbacks/peft.py
index 26325bf549d0..869882671096 100644
--- a/nemo/lightning/pytorch/callbacks/peft.py
+++ b/nemo/lightning/pytorch/callbacks/peft.py
@@ -7,6 +7,7 @@
 import torch.nn as nn
 from lightning_fabric.utilities.types import _PATH
 from pytorch_lightning.plugins.io.wrapper import _WrappingCheckpointIO
+from pytorch_lightning.trainer.states import TrainerFn
 from typing_extensions import override
 
 from nemo.lightning.io.pl import ckpt_to_dir
@@ -84,23 +85,44 @@ def __call__(self, model: nn.Module) -> nn.Module:
     def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str) -> None:
         super().setup(trainer, pl_module, stage=stage)
 
+        trainer.strategy.trainer = trainer
         self.wrapped_io = WrappedAdapterIO(trainer.strategy.checkpoint_io)
         trainer.strategy._checkpoint_io = self.wrapped_io
+        trainer.strategy._init_model_parallel = False
+        trainer.strategy._setup_optimizers = False
 
-    def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
-        needs_to_call = self._needs_to_call
-        self._maybe_apply_transform(trainer)
+    def apply_transform(self, trainer):
+        super().apply_transform(trainer)
 
-        # Check if we need to load the adapters
-        if needs_to_call and self.wrapped_io.adapter_ckpt_path is not None:
+        if self.wrapped_io.adapter_ckpt_path is not None:
             logging.info(f"Loading adapters from {self.wrapped_io.adapter_ckpt_path}")
             adapter_state = self.wrapped_io.load_checkpoint(self.wrapped_io.adapter_ckpt_path)
             trainer.strategy.load_model_state_dict(adapter_state, strict=False)
 
-    def on_load_checkpoint(
+        if hasattr(trainer.strategy, "init_model_parallel"):
+            logging.info("Initializing model parallel")
+            trainer.strategy.init_model_parallel()
+
+        if trainer.state.fn == TrainerFn.FITTING:
+            logging.info("Setting up optimizers")
+            trainer.strategy.setup_optimizers(trainer)
+
+    def on_save_checkpoint(
         self, trainer: pl.Trainer, pl_module: pl.LightningModule, checkpoint: Dict[str, Any]
     ) -> None:
-        pl_module.strict_loading = False
+        # Filter out non-trainable parameters
+        trainable_params = set(name for name, param in pl_module.named_parameters() if param.requires_grad)
+        filtered_state_dict = {}
+        for name, value in checkpoint['state_dict'].items():
+            if name in trainable_params:
+                filtered_state_dict[name] = value
+            elif self.adapter_key_filter(name):  # Include all adapter-related parameters
+                filtered_state_dict[name] = value
+
+        checkpoint['state_dict'] = filtered_state_dict
+
+    def adapter_key_filter(self, key: str) -> bool:
+        return ".adapter." in key or key.endswith(".adapters")
 
 
 class AdapterWrapper(nn.Module):
@@ -224,9 +246,6 @@ class WrappedAdapterIO(_WrappingCheckpointIO):
     def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None:
         assert self.checkpoint_io is not None
 
-        key = "sharded_state_dict" if "sharded_state_dict" in checkpoint else "state_dict"
-        checkpoint[key] = dict(filter(lambda x: ".adapter." in x[0], checkpoint[key].items()))
-
         self.checkpoint_io.save_checkpoint(checkpoint, path, storage_options=storage_options)
 
         from nemo.utils.get_rank import is_global_rank_zero
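The checkpoint filtering above keeps trainable parameters plus anything matching `adapter_key_filter`; a small standalone check of that predicate with made-up key names:

```python
def adapter_key_filter(key: str) -> bool:
    # same predicate as PEFT.adapter_key_filter above
    return ".adapter." in key or key.endswith(".adapters")


assert adapter_key_filter("decoder.layers.0.self_attention.adapter.linear_in.weight")
assert adapter_key_filter("decoder.layers.0.self_attention.adapters")
assert not adapter_key_filter("decoder.layers.0.self_attention.linear_qkv.weight")
```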
diff --git a/nemo/lightning/pytorch/callbacks/preemption.py b/nemo/lightning/pytorch/callbacks/preemption.py
index 7f1dd94256d2..69ac378ed698 100644
--- a/nemo/lightning/pytorch/callbacks/preemption.py
+++ b/nemo/lightning/pytorch/callbacks/preemption.py
@@ -14,16 +14,19 @@
 
 import contextlib
 import signal
+import sys
 from typing import Optional
 
 import torch
 from pytorch_lightning.callbacks import Callback
 from pytorch_lightning.trainer.trainer import Trainer
 
+from nemo.lightning.io.mixin import IOMixin
 from nemo.utils import logging
+from nemo.utils.callbacks.dist_ckpt_io import AsyncFinalizableCheckpointIO
 
 
-class PreemptionCallback(Callback):
+class PreemptionCallback(Callback, IOMixin):
     """
     PreemptionCallback checks for preemption during training at the end of every step.
     Upon preemption, it signals the trainer to stop gracefully.
@@ -61,13 +64,15 @@ def on_train_end(self, trainer: Trainer, pl_module) -> None:
 
     def on_train_batch_end(self, trainer: Trainer, pl_module, outputs, batch, batch_idx: int) -> None:
         if self.interrupted:
-            logging.info("Preemption detected, signaling trainer to stop")
-            trainer.should_stop = True
-
-    def on_exception(self, trainer: Trainer, pl_module, exception: BaseException) -> None:
-        if isinstance(exception, PreemptionException):
-            logging.info("Handling PreemptionException")
+            logging.info("Preemption detected, saving checkpoint and exiting")
             trainer.should_stop = True
+            if trainer.checkpoint_callback:
+                monitor_candidates = trainer.checkpoint_callback._monitor_candidates(trainer)
+                trainer.checkpoint_callback._save_last_checkpoint(trainer, monitor_candidates)
+                if isinstance(trainer.strategy.checkpoint_io, AsyncFinalizableCheckpointIO):
+                    logging.info("Async checkpointing detected, waiting for it to complete")
+                    trainer.strategy.checkpoint_io.maybe_finalize_save_checkpoint(blocking=True)
+                sys.exit(0)
 
     @contextlib.contextmanager
     def _preemption_handler(self):
@@ -81,7 +86,6 @@ def _preemption_handler(self):
         def master_handler(signum, frame):
             logging.info(f"Received signal {signum}, initiating graceful stop")
             self._interrupted = True
-            raise PreemptionException("Preemption signal received")
 
         def ignoring_handler(signum, frame):
             logging.debug(f"Received signal {signum} on non-master rank, ignoring")
@@ -109,7 +113,3 @@ def interrupted(self) -> bool:
         interrupted = torch.tensor(self._interrupted, device=torch.cuda.current_device(), dtype=torch.int32)
         torch.distributed.broadcast(interrupted, 0)
         return bool(interrupted.item())
-
-
-class PreemptionException(Exception):
-    """Custom exception for preemption events."""
diff --git a/nemo/lightning/pytorch/callbacks/progress.py b/nemo/lightning/pytorch/callbacks/progress.py
index 17178618852f..9ccf871f820f 100644
--- a/nemo/lightning/pytorch/callbacks/progress.py
+++ b/nemo/lightning/pytorch/callbacks/progress.py
@@ -22,7 +22,7 @@ def init_train_tqdm(self):
         Override bar_format to not have 's/it'.
         """
         self.bar = super().init_train_tqdm()
-        self.bar.bar_format = "{desc}: {percentage:3.0f}%|{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}{postfix}]"
+        self.bar.bar_format = "{desc} {n_fmt}/{total_fmt}{postfix}"
         return self.bar
 
     def on_train_epoch_start(self, trainer, *_):
diff --git a/nemo/lightning/pytorch/optim/base.py b/nemo/lightning/pytorch/optim/base.py
index 8e857a156649..c6160fa14b0e 100644
--- a/nemo/lightning/pytorch/optim/base.py
+++ b/nemo/lightning/pytorch/optim/base.py
@@ -149,10 +149,10 @@ def optimizers(self, model) -> List[Optimizer]:
         """
         raise NotImplementedError("The optimizers method should be implemented by subclasses.")
 
-    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx) -> None:
+    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx) -> None:
         if self._optimizers is not None:
             lr = self._optimizers[0].param_groups[0]['lr']
-            pl_module.log('lr', lr, rank_zero_only=True, batch_size=1)
+            pl_module.log('lr', lr, rank_zero_only=True, batch_size=1, prog_bar=True)
 
     def __call__(self, model: L.LightningModule, megatron_parallel=None) -> OptimizerLRScheduler:
         """Calls the setup and optimizers methods.
diff --git a/nemo/lightning/pytorch/optim/lr_scheduler.py b/nemo/lightning/pytorch/optim/lr_scheduler.py
index 298a6e7a7f45..4e865443b8fc 100644
--- a/nemo/lightning/pytorch/optim/lr_scheduler.py
+++ b/nemo/lightning/pytorch/optim/lr_scheduler.py
@@ -25,7 +25,7 @@ def __init__(
         warmup_ratio: Optional[float] = None,
         max_steps: int = 10,
         min_lr: float = 0.0,
-        interval: str = "epoch",
+        interval: str = "step",
         frequency: int = 1,
         monitor: str = "val_loss",
     ):
@@ -68,7 +68,7 @@ def __init__(
         hold_ratio: Optional[float] = None,
         max_steps: int = 10,
         min_lr: float = 0.0,
-        interval: str = "epoch",
+        interval: str = "step",
         frequency: int = 1,
         monitor: str = "val_loss",
     ):
@@ -111,7 +111,7 @@ def __init__(
         self,
         max_steps: int = 10,
         min_lr: float = 1e-5,
-        interval: str = "epoch",
+        interval: str = "step",
         frequency: int = 1,
         monitor: str = "val_loss",
     ):
@@ -142,7 +142,7 @@ def __init__(
         self,
         max_steps: int = 10,
         min_lr: float = 0.0,
-        interval: str = "epoch",
+        interval: str = "step",
         frequency: int = 1,
         monitor: str = "val_loss",
     ):
@@ -176,7 +176,7 @@ def __init__(
         warmup_ratio: Optional[float] = None,
         max_steps: int = 10,
         min_lr: float = 0.0,
-        interval: str = "epoch",
+        interval: str = "step",
         frequency: int = 1,
         monitor: str = "val_loss",
     ):
@@ -218,7 +218,7 @@ def __init__(
         max_steps: int = 10,
         decay_rate: float = 0.5,
         min_lr: float = 0.0,
-        interval: str = "epoch",
+        interval: str = "step",
         frequency: int = 1,
         monitor: str = "val_loss",
     ):
@@ -252,7 +252,7 @@ def __init__(
         self,
         max_steps: int = 10,
         min_lr: float = 0.0,
-        interval: str = "epoch",
+        interval: str = "step",
         frequency: int = 1,
         monitor: str = "val_loss",
     ):
@@ -283,7 +283,7 @@ def __init__(
         self,
         max_steps: int = 10,
         min_lr: float = 0.0,
-        interval: str = "epoch",
+        interval: str = "step",
         frequency: int = 1,
         monitor: str = "val_loss",
     ):
@@ -314,7 +314,7 @@ def __init__(
         self,
         max_steps: int = 10,
         min_lr: float = 0.0,
-        interval: str = "epoch",
+        interval: str = "step",
         frequency: int = 1,
         monitor: str = "val_loss",
     ):
@@ -347,7 +347,7 @@ def __init__(
         min_lr: float = 0.0,
         power: float = 1.0,
         cycle: bool = False,
-        interval: str = "epoch",
+        interval: str = "step",
         frequency: int = 1,
         monitor: str = "val_loss",
     ):
@@ -384,7 +384,7 @@ def __init__(
         min_lr: float = 0.0,
         power: float = 1.0,
         cycle: bool = False,
-        interval: str = "epoch",
+        interval: str = "step",
         frequency: int = 1,
         monitor: str = "val_loss",
     ):
@@ -415,13 +415,13 @@ def scheduler(self, model, optimizer):
 class CosineAnnealingScheduler(LRSchedulerModule):
     def __init__(
         self,
-        max_steps=10,
-        warmup_steps=750,
-        constant_steps=80000,
-        min_lr=int(6e-5),
-        interval="epoch",
-        frequency=1,
-        monitor="val_loss",
+        max_steps: int = 10,
+        warmup_steps: int = 750,
+        constant_steps: int = 80000,
+        min_lr: float = 6e-5,
+        interval: str = "step",
+        frequency: int = 1,
+        monitor: str = "val_loss",
     ):
         super().__init__()
         self.max_steps = max_steps
@@ -445,7 +445,6 @@ def scheduler(self, model, optimizer):
 
         return {
             "optimizer": optimizer,
-            "scheduler": lr_scheduler,
             "lr_scheduler": {
                 # REQUIRED: The scheduler instance
                 "scheduler": lr_scheduler,
diff --git a/nemo/lightning/pytorch/optim/megatron.py b/nemo/lightning/pytorch/optim/megatron.py
index 7faa53f32b65..1eb5290652a4 100644
--- a/nemo/lightning/pytorch/optim/megatron.py
+++ b/nemo/lightning/pytorch/optim/megatron.py
@@ -1,3 +1,4 @@
+import inspect
 from typing import Callable, List, Optional
 
 import pytorch_lightning as pl
@@ -92,8 +93,12 @@ def sharded_state_dict(
                 is_loading=False,
                 sharding_type='fully_sharded_model_space',
             ):
+                mcore_optimizer_sig = inspect.signature(self.mcore_optimizer.sharded_state_dict).parameters
+                distrib_optim_kwargs = {}
+                if "sharding_type" in mcore_optimizer_sig:
+                    distrib_optim_kwargs["sharding_type"] = sharding_type
                 state_dict = self.mcore_optimizer.sharded_state_dict(
-                    model_sharded_state_dict, is_loading=is_loading, sharding_type=sharding_type
+                    model_sharded_state_dict, is_loading=is_loading, **distrib_optim_kwargs
                 )
                 return state_dict
 
diff --git a/nemo/lightning/pytorch/plugins/data_sampler.py b/nemo/lightning/pytorch/plugins/data_sampler.py
index 378375e3bc0c..7796e4bb92de 100644
--- a/nemo/lightning/pytorch/plugins/data_sampler.py
+++ b/nemo/lightning/pytorch/plugins/data_sampler.py
@@ -1,3 +1,4 @@
+import logging
 from typing import Any, Dict, List, Literal, Optional
 
 import pytorch_lightning as pl
@@ -43,12 +44,13 @@ def setup(self, global_rank: int) -> None:
     def transform_dataloader(self, dataloader: DataLoader, consumed_samples: int = 0) -> DataLoader:
         from nemo.lightning.data import add_megatron_sampler
 
+        mode = getattr(dataloader, 'mode', 'train')
         return add_megatron_sampler(
             dataloader,
             micro_batch_size=self.micro_batch_size,
             global_batch_size=self.global_batch_size,
             rampup_batch_size=self.rampup_batch_size,
-            consumed_samples=self.init_consumed_samples,
+            consumed_samples=self.init_consumed_samples if mode == 'train' else 0,
             dataloader_type=self.dataloader_type,
         )
 
@@ -60,12 +62,8 @@ def compute_consumed_samples(self, steps_since_resume=0) -> int:
             return 0
 
         app_state = AppState()
-
         if self.rampup_batch_size is not None:
-            from apex.transformer.pipeline_parallel.utils import _GLOBAL_NUM_MICROBATCHES_CALCULATOR
-
-            current_global_batch_size = getattr(_GLOBAL_NUM_MICROBATCHES_CALCULATOR, "current_global_batch_size", 1)
-            consumed_samples = self.prev_consumed_samples + self.if_first_step * current_global_batch_size
+            consumed_samples = self.prev_consumed_samples + self.if_first_step * self.current_global_batch_size
         else:
             consumed_samples = (
                 self.init_consumed_samples
@@ -85,15 +83,17 @@ def on_megatron_step_start(self, trainer: pl.Trainer) -> None:
             trainer.should_stop = True
 
     def on_megatron_step_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
-        import apex.transformer.pipeline_parallel.utils
+        try:
+            from megatron.core.num_microbatches_calculator import update_num_microbatches
 
-        if self.rampup_batch_size is None:
-            return
+        except (ImportError, ModuleNotFoundError):
+            logging.warning("Megatron num_microbatches_calculator not found, using Apex version.")
+            from apex.transformer.pipeline_parallel.utils import update_num_microbatches
 
         self.prev_global_batch_size = self.current_global_batch_size
 
         # TODO: Add consumed samples
-        consumed_samples = self.compute_consumed_samples(trainer.global_step + 1 - self.init_global_step)
+        consumed_samples = self.compute_consumed_samples(trainer.global_step + 1 - self.init_consumed_samples)
 
         pl_module.log(
             'consumed_samples',
@@ -105,18 +105,13 @@ def on_megatron_step_end(self, trainer: pl.Trainer, pl_module: pl.LightningModul
 
         self.prev_consumed_samples = consumed_samples
 
-        num_microbatch_calculator = (
-            apex.transformer.pipeline_parallel.utils._GLOBAL_NUM_MICROBATCHES_CALCULATOR  # noqa: SLF001
-        )
-
-        num_microbatch_calculator.update(
+        update_num_microbatches(
             consumed_samples=consumed_samples,
             consistency_check=False,
         )
-        current_global_batch_size = num_microbatch_calculator.current_global_batch_size
         pl_module.log(
             "global_batch_size",
-            current_global_batch_size,
+            self.current_global_batch_size,
             prog_bar=True,
             rank_zero_only=True,
             batch_size=1,
@@ -133,17 +128,27 @@ def megatron_data_kwargs(self) -> Dict[str, Any]:
 
     @property
     def num_microbatches(self) -> int:
-        from apex.transformer.pipeline_parallel.utils import get_num_microbatches
+        try:
+            from megatron.core.num_microbatches_calculator import get_num_microbatches
+
+        except (ImportError, ModuleNotFoundError):
+            logging.warning("Megatron num_microbatches_calculator not found, using Apex version.")
+            from apex.transformer.pipeline_parallel.utils import get_num_microbatches
 
         return get_num_microbatches()
 
     @property
     def current_global_batch_size(self) -> int:
-        import apex.transformer.pipeline_parallel.utils
+        try:
+            from megatron.core.num_microbatches_calculator import get_current_global_batch_size
 
-        num_microbatch_calculator = (
-            apex.transformer.pipeline_parallel.utils._GLOBAL_NUM_MICROBATCHES_CALCULATOR  # noqa: SLF001
-        )
-        current_global_batch_size = num_microbatch_calculator.current_global_batch_size
+        except (ImportError, ModuleNotFoundError):
+            logging.warning("Megatron num_microbatches_calculator not found, using Apex version.")
+            from apex.transformer.pipeline_parallel.utils import get_current_global_batch_size
+
+        current_global_batch_size = get_current_global_batch_size()
+        if not current_global_batch_size:
+            # The calculator may not be initialized yet; fall back to a batch size of 1.
+            current_global_batch_size = 1
 
         return current_global_batch_size
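
For readers skimming the data-sampler changes above: the recurring try/except blocks all implement the same fallback, preferring the Megatron-Core microbatch calculator and only dropping back to the legacy Apex location with a warning. A minimal standalone sketch of that pattern (both packages are optional dependencies; the helper name is ours):

```python
import logging


def import_get_num_microbatches():
    """Prefer megatron.core's accessor; fall back to the legacy Apex location."""
    try:
        from megatron.core.num_microbatches_calculator import get_num_microbatches
    except (ImportError, ModuleNotFoundError):
        logging.warning("Megatron num_microbatches_calculator not found, using Apex version.")
        from apex.transformer.pipeline_parallel.utils import get_num_microbatches
    return get_num_microbatches
```
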
diff --git a/nemo/lightning/pytorch/plugins/mixed_precision.py b/nemo/lightning/pytorch/plugins/mixed_precision.py
index 751141d8111b..5e43e09c0420 100644
--- a/nemo/lightning/pytorch/plugins/mixed_precision.py
+++ b/nemo/lightning/pytorch/plugins/mixed_precision.py
@@ -61,7 +61,6 @@ def convert_module(self, module: Module) -> Module:
         This is optional and depends on the precision limitations during optimization.
 
         """
-        from megatron.core.distributed import DistributedDataParallel
         from megatron.core.transformer.module import Float16Module
         from megatron.core.utils import get_model_config
 
@@ -69,7 +68,10 @@ def convert_module(self, module: Module) -> Module:
             config = get_model_config(module.module)
             config.fp16 = self.precision == "16-mixed"
             config.bf16 = self.precision == "bf16-mixed"
-            if not isinstance(module.module, Float16Module):
+            if isinstance(module.module, Float16Module):
+                new_float16_module = Float16Module(config, module.module.module)
+                module.module = new_float16_module
+            else:
                 module.module = Float16Module(config, module.module)
 
         return module
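
The `convert_module` change above re-wraps a module that is already a `Float16Module` with a freshly built wrapper around the inner, unwrapped module, so that a new fp16/bf16 config actually takes effect instead of silently keeping the old wrapper. A self-contained sketch of that decision, using a stand-in wrapper class (the real one is `megatron.core.transformer.module.Float16Module`):

```python
import torch.nn as nn


class ToyFloat16Wrapper(nn.Module):
    """Stand-in for Float16Module: pairs a precision config with a wrapped module."""

    def __init__(self, config, module):
        super().__init__()
        self.config = config
        self.module = module


def wrap_for_precision(holder, config):
    # `holder.module` mirrors the `.module` attribute used in the diff above.
    if isinstance(holder.module, ToyFloat16Wrapper):
        # Rebuild the wrapper around the unwrapped inner module with the new config.
        holder.module = ToyFloat16Wrapper(config, holder.module.module)
    else:
        holder.module = ToyFloat16Wrapper(config, holder.module)
    return holder
```
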
diff --git a/nemo/lightning/pytorch/strategies.py b/nemo/lightning/pytorch/strategies.py
index 6a84319b4fa2..f44906ef7b7b 100644
--- a/nemo/lightning/pytorch/strategies.py
+++ b/nemo/lightning/pytorch/strategies.py
@@ -5,6 +5,7 @@
 import shutil
 from collections import OrderedDict
 from contextlib import ExitStack
+from dataclasses import dataclass
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, ContextManager, Dict, List, Literal, Mapping, Optional, TypeVar, Union, cast
 
@@ -45,6 +46,18 @@
 DDPLiteral = Literal["megatron", "pytorch"]
 
 
+@dataclass
+class ParallelismConfig:
+    tensor_model_parallel_size: int
+    pipeline_model_parallel_size: int
+    virtual_pipeline_model_parallel_size: int
+    context_parallel_size: int
+    sequence_parallel: bool
+    expert_model_parallel_size: int
+    moe_extended_tp: bool
+    pipeline_dtype: torch.dtype
+
+
 class MegatronStrategy(DDPStrategy, io.IOMixin):
     """Megatron plugin for Pytorch Lightning.
 
@@ -71,12 +84,35 @@ class MegatronStrategy(DDPStrategy, io.IOMixin):
         cluster_environment: Cluster environment for distributed training. Defaults to None.
         checkpoint_io: Checkpoint I/O handler. Defaults to None.
         find_unused_parameters (bool): Find unused parameters in DDP. Defaults to False.
-        enable_nemo_ckpt_io (bool): Enable NeMo checkpoint I/O. Defaults to True.
         ckpt_type (TrainerCkptProtocol): Checkpoint type. Defaults to TrainerCheckpoint.
-        ckpt_include_optimizer (bool): Include optimizer state in checkpoint. Defaults to False.
+        ckpt_include_optimizer (bool): Include optimizer state in checkpoint. Defaults to True.
         ddp (Union[DDPLiteral, DistributedDataParallelConfig]): DDP configuration. Defaults to "megatron".
         lazy_init (bool): Use lazy initialization for model parallel parameters. Defaults to False.
         pipeline_dtype (Optional[torch.dtype]): Data type for pipeline parallelism. Defaults to None.
+        save_ckpt_format (str): Distributed checkpoint format to use for checkpoint saving. Should be one of
+            'torch_dist' or 'zarr'. Defaults to 'torch_dist'.
+        ckpt_async_save (bool): Whether to save checkpoints asynchronously to reduce checkpointing overhead.
+            Defaults to False.
+        ckpt_torch_dist_multiproc (int): Number of extra processes per rank used during ckpt save
+            with PyTorch distributed format. Defaults to None.
+        ckpt_assume_constant_structure (bool): Allows caching some computation across checkpoint saves.
+            Set to True only if the state dict structure doesn't change within a single job.
+        ckpt_parallel_save (bool): If true, each worker will write its own part of the dist checkpoint.
+            Defaults to True.
+        ckpt_parallel_save_within_dp (bool): If true, save will be parallelized only within a DP group
+            (whole world otherwise), which might slightly reduce the save overhead. Defaults to False.
+        ckpt_parallel_load (bool): If true, each worker will load part of the dist checkpoint
+            and exchange with NCCL. Might use some extra GPU memory. Defaults to False.
+        ckpt_parallel_save_optim (bool): Parallel save/load of a DistributedOptimizer. 'True'
+            allows performant save and reshardable checkpoints. Set to 'False' only in order to minimize
+            the number of checkpoint files.
+        ckpt_load_directly_on_device (bool): if True, loads the weights directly on GPU.
+            Has effect only for `zarr` based checkpoints (PyT Distributed always loads on device).
+            Defaults to True.
+        setup_optimizers (bool): Whether to call the trainer's setup_optimizers function to perform any
+            necessary conversions of optimizer parameters and move optimizer parameters to the correct device.
+            Defaults to True.
+        init_model_parallel (bool): Whether to initialize the model parallel groups. Defaults to True.
         **kwargs: Additional keyword arguments.
 
     Note:
@@ -100,16 +136,21 @@ def __init__(
         cluster_environment=None,  # TODO: Add type-hint
         checkpoint_io=None,  # TODO: Add type-hint
         find_unused_parameters: bool = False,
-        ckpt_include_optimizer: bool = False,
+        ckpt_include_optimizer: bool = True,
         ddp: Union[DDPLiteral, DistributedDataParallelConfig] = "megatron",
         lazy_init: bool = False,
         pipeline_dtype: Optional[torch.dtype] = None,
-        save_ckpt_format='torch_dist',
-        ckpt_torch_dist_multiproc=None,  ## TODO(ashors): put elsewhere?
-        ckpt_assume_constant_structure=False,
-        ckpt_parallel_save=True,
-        ckpt_parallel_load=False,
-        ckpt_parallel_save_optim=True,
+        save_ckpt_format: str = 'torch_dist',
+        ckpt_async_save: bool = False,
+        ckpt_torch_dist_multiproc: int = None,  ## TODO(ashors): put elsewhere?
+        ckpt_assume_constant_structure: bool = False,
+        ckpt_parallel_save: bool = True,
+        ckpt_parallel_save_within_dp: bool = False,
+        ckpt_parallel_load: bool = False,
+        ckpt_parallel_save_optim: bool = True,
+        ckpt_load_directly_on_device: bool = True,
+        setup_optimizers: bool = True,
+        init_model_parallel: bool = True,
         **kwargs,
     ) -> None:
         super().__init__(
@@ -132,19 +173,24 @@ def __init__(
         self.lazy_init = lazy_init
         self.ckpt_include_optimizer = ckpt_include_optimizer
         self.pipeline_dtype = pipeline_dtype
+        self._setup_optimizers = setup_optimizers
+        self._init_model_parallel = init_model_parallel
         self.log_train_loss = bool(int(os.getenv("NEMO_LOG_TRAIN_LOSS", 1)))
         self.log_memory_usage = bool(int(os.getenv("NEMO_LOG_MEMORY_USAGE", 0)))
 
         self.save_ckpt_format = save_ckpt_format
+        self.async_save = ckpt_async_save
         self.torch_dist_multiproc = ckpt_torch_dist_multiproc
         self.assume_constant_structure = ckpt_assume_constant_structure
         self.parallel_save = ckpt_parallel_save
+        self.parallel_save_within_dp = ckpt_parallel_save_within_dp
         self.parallel_load = ckpt_parallel_load
         self.parallel_save_optim = ckpt_parallel_save_optim
+        self.load_directly_on_device = ckpt_load_directly_on_device
 
         self._ddp = ddp
         if ddp == "megatron":
-            self.ddp_config = DistributedDataParallelConfig()
+            self.ddp_config = DistributedDataParallelConfig(check_for_nan_in_grad=True)
         elif isinstance(ddp, DistributedDataParallelConfig):
             self.ddp_config = ddp
         elif ddp == "pytorch":
@@ -180,7 +226,7 @@ def connect(self, model: pl.LightningModule) -> None:
                     ddp_config.use_distributed_optimizer = mcore_opt_config.use_distributed_optimizer
 
     @override
-    def setup(self, trainer: pl.Trainer, setup_optimizers: bool = True) -> None:
+    def setup(self, trainer: pl.Trainer) -> None:
         assert self.accelerator is not None
         self.accelerator.setup(trainer)
         self.trainer = trainer
@@ -199,12 +245,14 @@ def setup(self, trainer: pl.Trainer, setup_optimizers: bool = True) -> None:
         if not self.data_sampler and hasattr(datamodule, "data_sampler"):
             self.data_sampler = datamodule.data_sampler
             self.data_sampler.setup(self.cluster_environment.global_rank())
+            if hasattr(datamodule, "reconfigure_limit_batches"):
+                datamodule.reconfigure_limit_batches()
 
         if self.data_sampler:
             self.data_sampler.connect(trainer)
 
         self._fix_progress_bar(trainer)
-        self.setup_megatron_parallel(trainer, setup_optimizers=setup_optimizers)
+        self.setup_megatron_parallel(trainer)
         self.setup_precision_plugin()
 
         if getattr(self.lightning_module, "model_transform", None):
@@ -244,6 +292,16 @@ def setup(self, trainer: pl.Trainer, setup_optimizers: bool = True) -> None:
             assert self.model is not None
             _sync_module_states(self.model)
 
+        ## add AsyncFinalizerCallback if using async
+        if self.async_save:
+            have_async_callback = False
+            for callback in self.trainer.callbacks:
+                if isinstance(callback, AsyncFinalizerCallback):
+                    have_async_callback = True
+                    break
+            if not have_async_callback:
+                self.trainer.callbacks.append(AsyncFinalizerCallback())
+
     @override
     def setup_distributed(self) -> None:
         self._setup_parallel_ranks()
@@ -271,7 +329,7 @@ def process_dataloader(self, dataloader: DataLoader) -> DataLoader:
 
         return dataloader
 
-    def setup_megatron_parallel(self, trainer: pl.Trainer, setup_optimizers: bool = True) -> None:
+    def setup_megatron_parallel(self, trainer: pl.Trainer) -> None:
         assert self.model is not None, "Model is not set"
 
         convert_module_fn = None
@@ -286,6 +344,10 @@ def setup_megatron_parallel(self, trainer: pl.Trainer, setup_optimizers: bool =
             ddp_config=self.ddp_config,
             convert_module_fn=convert_module_fn,
         )
+
+        if self._init_model_parallel:
+            self.init_model_parallel()
+
         self.megatron_parallel.trainer = trainer
 
         # check signature-def of self.model.configure_optimizers to check if there's an optional arg: megatron_parallel
@@ -295,18 +357,9 @@ def setup_megatron_parallel(self, trainer: pl.Trainer, setup_optimizers: bool =
                 self.model.configure_optimizers, megatron_parallel=self.megatron_parallel
             )
 
-        if setup_optimizers:
+        if self._setup_optimizers:
             self.setup_optimizers(trainer)
 
-        # TODO: Throw an execption if we have a mcore optimizer and no ddp_config
-
-        if hasattr(self.precision_plugin, "convert_optimizer"):
-            _optimizers = [*self.optimizers]
-            _optimizers[0] = self.precision_plugin.convert_optimizer(self.optimizers[0])
-            self.optimizers = _optimizers
-
-        _optimizers_to_device(self.optimizers, self.root_device)
-
         self.model = self.megatron_parallel
         self.model.callbacks.add(getattr(trainer, "callbacks"))
 
@@ -317,6 +370,9 @@ def setup_megatron_parallel(self, trainer: pl.Trainer, setup_optimizers: bool =
         if datamodule:
             self.model.callbacks.add(datamodule)
 
+    def init_model_parallel(self):
+        self.megatron_parallel.init_model_parallel()
+
     @override
     def configure_ddp(self) -> None:
         logging.debug(f"{self.__class__.__name__}: configuring MegatronParallel")
@@ -349,6 +405,16 @@ def _setup_model(self, model: nn.Module) -> nn.Module:
 
         return model
 
+    @override
+    def setup_optimizers(self, trainer: "pl.Trainer") -> None:
+        super().setup_optimizers(trainer)
+        if hasattr(self.precision_plugin, "convert_optimizer"):
+            _optimizers = [*self.optimizers]
+            _optimizers[0] = self.precision_plugin.convert_optimizer(self.optimizers[0])
+            self.optimizers = _optimizers
+
+        _optimizers_to_device(self.optimizers, self.root_device)
+
     def _setup_parallel_ranks(self) -> None:
         self.set_world_ranks()
         env = cast(ClusterEnvironment, self.cluster_environment)
@@ -425,7 +491,9 @@ def validation_step(self, dataloader_iter, *args: Any, **kwargs: Any) -> STEP_OU
         kwargs = self._update_step_kwargs(dataloader_iter, kwargs, "validation")
 
         with self.precision_plugin.val_step_context():  # TODO: Do we need this?
-            return self.model(dataloader_iter, forward_only=True, *args, **kwargs)
+            out = self.model(dataloader_iter, forward_only=True, *args, **kwargs)
+            self.lightning_module.log('val_loss', out, rank_zero_only=True, batch_size=1)
+            return out
 
     @override
     def test_step(self, dataloader_iter, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
@@ -558,25 +626,18 @@ def load_model_state_dict(self, checkpoint: Mapping[str, Any], strict: bool = Tr
     @override
     def checkpoint_io(self) -> CheckpointIO:
         if self._checkpoint_io is None:
-            checkpoint_callback = self.trainer.checkpoint_callback
-            async_save = getattr(checkpoint_callback, "async_save", False)
             self._checkpoint_io = MegatronCheckpointIO(
                 save_ckpt_format=self.save_ckpt_format,
-                async_save=async_save,
+                async_save=self.async_save,
                 torch_dist_multiproc=self.torch_dist_multiproc,
                 assume_constant_structure=self.assume_constant_structure,
                 parallel_save=self.parallel_save,
+                parallel_save_within_dp=self.parallel_save_within_dp,
                 parallel_load=self.parallel_load,
+                load_directly_on_device=self.load_directly_on_device,
             )
-            if async_save:
+            if self.async_save:
                 self._checkpoint_io = AsyncFinalizableCheckpointIO(self._checkpoint_io)
-                have_async_callback = False
-                for callback in self.trainer.callbacks:
-                    if isinstance(callback, AsyncFinalizerCallback):
-                        have_async_callback = True
-                        break
-                if not have_async_callback:
-                    self.trainer.callbacks.append(AsyncFinalizerCallback())
         elif isinstance(self._checkpoint_io, _WrappingCheckpointIO):
             self._checkpoint_io.checkpoint_io = MegatronCheckpointIO()
 
@@ -648,10 +709,8 @@ def restore_checkpoint_after_setup(self) -> bool:
         return True
 
     @property
-    def parallelism(self):
-        from megatron.core.model_parallel_config import ModelParallelConfig
-
-        return ModelParallelConfig(
+    def parallelism(self) -> ParallelismConfig:
+        return ParallelismConfig(
             tensor_model_parallel_size=self.tensor_model_parallel_size,
             pipeline_model_parallel_size=self.pipeline_model_parallel_size,
             virtual_pipeline_model_parallel_size=self.virtual_pipeline_model_parallel_size,
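
As a usage sketch of the constructor surface documented above (a hedged example: it assumes the class is importable from this module in a matching NeMo 2.0 build, and the parallelism sizes are purely illustrative):

```python
from nemo.lightning.pytorch.strategies import MegatronStrategy

strategy = MegatronStrategy(
    tensor_model_parallel_size=2,      # illustrative parallelism settings
    pipeline_model_parallel_size=1,
    ckpt_include_optimizer=True,       # new default in this change
    save_ckpt_format="torch_dist",
    ckpt_async_save=True,              # setup() will attach an AsyncFinalizerCallback
    ckpt_parallel_save=True,
    ckpt_parallel_save_within_dp=False,
    ckpt_load_directly_on_device=True,
    setup_optimizers=True,
    init_model_parallel=True,
)
```
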
diff --git a/nemo/lightning/pytorch/trainer.py b/nemo/lightning/pytorch/trainer.py
index 8b453832d56e..da1a77c3c731 100644
--- a/nemo/lightning/pytorch/trainer.py
+++ b/nemo/lightning/pytorch/trainer.py
@@ -17,6 +17,10 @@ def io_init(self, **kwargs) -> fdl.Config[Self]:
         for val in cfg_kwargs.values():
             if not serialization.find_node_traverser(type(val)):
                 track_io(type(val))
+            elif isinstance(val, list):
+                for v in val:
+                    if not serialization.find_node_traverser(type(v)):
+                        track_io(type(v))
 
         return fdl.Config(type(self), **cfg_kwargs)
 
diff --git a/nemo/lightning/resume.py b/nemo/lightning/resume.py
index fc2e21eb37fd..ca87628d699e 100644
--- a/nemo/lightning/resume.py
+++ b/nemo/lightning/resume.py
@@ -19,10 +19,10 @@
 
 
 class Resume(IOMixin):
-    def nemo_path(self, model) -> Optional[Path]:
-        raise NotImplementedError
+    def nemo_path(self, model=None) -> Optional[Path]:
+        """Returns the checkpoint to resume from."""
 
-    def setup(self, model, trainer: Union[pl.Trainer, fl.Fabric]):
+    def setup(self, trainer: Union[pl.Trainer, fl.Fabric], model=None):
         if isinstance(trainer, fl.Fabric):
             raise NotImplementedError("Fabric is not supported yet.")
 
@@ -52,10 +52,11 @@ def __init__(
             path (str): Can be used to specify a path to a specific checkpoint file to load from.
                 This will override any checkpoint found when resume_if_exists is True.
                 Defaults to None
-            dirpath (str): Path to save the checkpoints to. Defaults to /checkpoints
+            dirpath (str): Path to the checkpointing directory to restore from. Defaults to /checkpoints
             import_path (str): Path to specify if importing a checkpoint from HF or
                 another non-NeMo checkpoint format. If import_path is provided, other arguments
                 are unused.
+            adapter_path (str): Path to any adapter checkpoints.
             resume_if_exists (bool): Whether this experiment is resuming from a previous run. If
                 True, it sets trainer._checkpoint_connector._ckpt_path so that the trainer should
                 auto-resume. exp_manager will move files under log_dir to log_dir/run_{int}.
@@ -139,7 +140,11 @@ def nemo_path(self, model=None) -> Optional[Path]:
                     checkpoint = last_checkpoints[0]
                     checkpoint = uninject_model_parallel_rank(checkpoint)
                 else:
-                    raise ValueError(f"Multiple checkpoints {last_checkpoints} that matches *last.ckpt.")
+                    # Select the checkpoint with the latest modified time
+                    checkpoint = sorted(last_checkpoints, key=lambda pth: pth.lstat().st_mtime, reverse=True)[0]
+                    logging.warning(
+                        f"Multiple checkpoints {last_checkpoints} match *last.ckpt. Selecting the one with the latest modification time."
+                    )
             else:
                 checkpoint = last_checkpoints[0]
 
diff --git a/nemo/package_info.py b/nemo/package_info.py
index 59805e0e04d3..1cd6ef729936 100644
--- a/nemo/package_info.py
+++ b/nemo/package_info.py
@@ -16,7 +16,7 @@
 MAJOR = 2
 MINOR = 0
 PATCH = 0
-PRE_RELEASE = 'rc1'
+PRE_RELEASE = 'rc2'
 
 # Use the following formatting: (major, minor, patch, pre-release)
 VERSION = (MAJOR, MINOR, PATCH, PRE_RELEASE)
diff --git a/nemo/utils/callbacks/dist_ckpt_io.py b/nemo/utils/callbacks/dist_ckpt_io.py
index 144c07addaa8..437c8b0c5887 100644
--- a/nemo/utils/callbacks/dist_ckpt_io.py
+++ b/nemo/utils/callbacks/dist_ckpt_io.py
@@ -17,7 +17,7 @@
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
 from time import time
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Union
 
 import pytorch_lightning as pl
 from lightning_fabric.plugins import CheckpointIO
@@ -44,6 +44,7 @@
         FullyParallelSaveStrategyWrapper,
     )
     from megatron.core.dist_checkpointing.strategies.torch import TorchDistSaveShardedStrategy
+    from megatron.core.dist_checkpointing.validation import StrictHandling
     from megatron.core.parallel_state import get_data_parallel_group
 
     HAVE_MEGATRON_CORE = True
@@ -188,6 +189,9 @@ class DistributedCheckpointIO(AsyncCompatibleCheckpointIO):
         load_directly_on_device (bool, optional): if True, loads the weights directly
             on GPU. Has effect only for `zarr` based checkpoints (PyT Distributed
             always loads on device). Defaults to True.
+        load_strictness (StrictHandling, optional): defines loading strictness.
+            If not None, overwrites the `strict` flag passed to `load_checkpoint`.
+            Defaults to None.
         async_save (bool): whether to save asynchronously. Should be set to True if
             this class will be wrapped with AsyncFinalizableCheckpointIO.
         torch_dist_multiproc (int, optional): number of extra processes per rank
@@ -202,10 +206,12 @@ def __init__(
         self,
         save_ckpt_format: str,
         load_directly_on_device: bool = True,
+        load_strictness: Optional['StrictHandling'] = None,
         async_save: bool = False,
         torch_dist_multiproc: Optional[int] = None,
         assume_constant_structure: bool = False,
         parallel_save: bool = False,
+        parallel_save_within_dp: bool = False,
         parallel_load: bool = False,
     ):
         super().__init__()
@@ -214,10 +220,12 @@ def __init__(
 
         self.save_ckpt_format = save_ckpt_format
         self.load_directly_on_device = load_directly_on_device
+        self.load_strictness = load_strictness
         self.async_save = async_save
         self.torch_dist_multiproc = torch_dist_multiproc
         self.assume_constant_structure = assume_constant_structure
         self.parallel_save = parallel_save
+        self.parallel_save_within_dp = parallel_save_within_dp
         self.parallel_load = parallel_load
 
         self._save_sharded_strategy = None
@@ -236,9 +244,11 @@ def from_config(cls, model_cfg: dict, async_save: bool = False):
         return cls(
             save_ckpt_format=model_cfg.get('dist_ckpt_format', 'zarr'),
             load_directly_on_device=model_cfg.get('dist_ckpt_load_on_device', True),
+            load_strictness=model_cfg.get('dist_ckpt_load_strictness', None),
             async_save=async_save,
             torch_dist_multiproc=model_cfg.get('dist_ckpt_torch_dist_multiproc', None),
             parallel_save=model_cfg.get('dist_ckpt_parallel_save', False),
+            parallel_save_within_dp=model_cfg.get('dist_ckpt_parallel_save_within_dp', False),
             parallel_load=model_cfg.get('dist_ckpt_parallel_load', False),
         )
 
@@ -272,7 +282,7 @@ def load_checkpoint(
         path: _PATH,
         map_location: Optional[Any] = None,
         sharded_state_dict: Dict[str, Any] = None,
-        strict: Optional[bool] = True,
+        strict: Union[None, bool, 'StrictHandling'] = None,
         validate_access_integrity: Optional[bool] = True,
     ) -> Dict[str, Any]:
         """Loads a distributed checkpoint.
@@ -284,6 +294,10 @@ def load_checkpoint(
                 defines the loading procedure for the distributed checkpoint.
                 Defaults to None to comply with the CheckpointIO interface,
                 but it's a required argument.
+            strict (bool, StrictHandling, optional): adjusts load strictness. A bool value
+                is translated to a StrictHandling instance. It is overwritten by
+                `self.load_strictness` when that is set. Defaults to None. If `self.load_strictness`
+                is also None, strict becomes StrictHandling.ASSUME_OK_UNEXPECTED.
 
         Returns:
             Dict[str, Any]: loaded checkpoint.
@@ -308,14 +322,27 @@ def load_checkpoint(
         if sharded_strategy is not None:
             logging.info(f'Using {sharded_strategy} dist-ckpt load strategy.')
 
-        if not strict:
-            sharded_state_dict = self.adjust_non_strict_load(path, sharded_state_dict)
+        if isinstance(strict, bool):
+            # For backward-compatibility reasons and a bug in MCore (strict check not applied to factories)
+            # we must apply a simple strict check here.
+            if not strict:
+                sharded_state_dict = self.adjust_non_strict_load(path, sharded_state_dict)
+            strict = StrictHandling.ASSUME_OK_UNEXPECTED if strict else StrictHandling.LOG_ALL
+        if self.load_strictness is not None:
+            # Overwrites function argument
+            strict = self.load_strictness
+        if strict is None:
+            # Default behavior
+            strict = StrictHandling.ASSUME_OK_UNEXPECTED
+
+        logging.debug(f'Dist ckpt load strictness: {strict}')
 
         return dist_checkpointing.load(
             sharded_state_dict=sharded_state_dict,
             checkpoint_dir=path,
             sharded_strategy=sharded_strategy,
             validate_access_integrity=validate_access_integrity,
+            strict=strict,
         )
 
     def adjust_non_strict_load(self, path: _PATH, sharded_state_dict: Dict[str, Any]):
@@ -377,8 +404,11 @@ def _determine_dist_ckpt_save_strategy(self):
             save_strategy.use_cached_ckpt_structure = self.assume_constant_structure
 
         if self.parallel_save:
+            parallelization_group = (
+                get_data_parallel_group(with_context_parallel=True) if self.parallel_save_within_dp else None
+            )
             save_strategy = FullyParallelSaveStrategyWrapper(
-                save_strategy, get_data_parallel_group(with_context_parallel=True), self.assume_constant_structure
+                save_strategy, parallelization_group, self.assume_constant_structure
             )
 
         logging.info(f'Using {save_strategy} dist-ckpt save strategy.')
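
The precedence implemented in `load_checkpoint` above is easy to get wrong, so here is the same resolution order as a self-contained sketch. `StrictHandling` is stood in by a local enum so the snippet runs without megatron.core; the real enum lives in `megatron.core.dist_checkpointing.validation`.

```python
from enum import Enum
from typing import Optional, Union


class StrictHandling(Enum):
    # Stand-in for megatron.core.dist_checkpointing.validation.StrictHandling
    ASSUME_OK_UNEXPECTED = "assume_ok_unexpected"
    LOG_ALL = "log_all"


def resolve_strictness(
    strict: Union[None, bool, StrictHandling],
    load_strictness: Optional[StrictHandling],
) -> StrictHandling:
    if isinstance(strict, bool):
        # bool arguments are translated to enum values (non-strict loads also prune
        # missing keys via adjust_non_strict_load in the real implementation)
        strict = StrictHandling.ASSUME_OK_UNEXPECTED if strict else StrictHandling.LOG_ALL
    if load_strictness is not None:
        strict = load_strictness  # the IO-level setting overrides the call argument
    if strict is None:
        strict = StrictHandling.ASSUME_OK_UNEXPECTED  # default behavior
    return strict


assert resolve_strictness(False, None) is StrictHandling.LOG_ALL
assert resolve_strictness(None, StrictHandling.LOG_ALL) is StrictHandling.LOG_ALL
```
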
diff --git a/nemo/utils/exp_manager.py b/nemo/utils/exp_manager.py
index f4bfb8ec95c4..ca18b22c00bc 100644
--- a/nemo/utils/exp_manager.py
+++ b/nemo/utils/exp_manager.py
@@ -165,6 +165,7 @@ class FaultToleranceParams:
     initial_rank_heartbeat_timeout: Optional[float] = 60.0 * 60.0
     rank_heartbeat_timeout: Optional[float] = 45.0 * 60.0
     calculate_timeouts: bool = True
+    safety_factor: float = 5.0
     rank_termination_signal: signal.Signals = signal.SIGKILL
     log_level: str = 'INFO'
     max_rank_restarts: int = 0
@@ -229,6 +230,8 @@ class ExpManagerConfig:
     # Fault tolerance
     create_fault_tolerance_callback: Optional[bool] = False
     fault_tolerance: Optional[FaultToleranceParams] = field(default_factory=FaultToleranceParams)
+    # logs TFLOPs per sec per gpu
+    log_tflops_per_sec_per_gpu: Optional[bool] = True
 
 
 class TimingCallback(Callback):
@@ -558,7 +561,7 @@ def exp_manager(trainer: 'pytorch_lightning.Trainer', cfg: Optional[Union[DictCo
         if HAVE_STRAGGLER_DET:
             logging.info("Enabling straggler detection...")
             straggler_det_args_dict = dict(cfg.straggler_detection_params)
-            straggler_det_callback = StragglerDetectionCallback(**straggler_det_args_dict, logger=logging)
+            straggler_det_callback = StragglerDetectionCallback(**straggler_det_args_dict)
             trainer.callbacks.append(straggler_det_callback)
         else:
             raise ValueError(
@@ -573,6 +576,7 @@ def exp_manager(trainer: 'pytorch_lightning.Trainer', cfg: Optional[Union[DictCo
             # here we only need to know if the autoresume is enabled.
             ft_use_autoresume = ft_params.max_subsequent_job_failures > 0
             fault_tol_callback = FaultToleranceCallback(
+                exp_dir=Path(log_dir).parent,  # log_dir is "/results/"
                 autoresume=ft_use_autoresume,
                 calculate_timeouts=ft_params.calculate_timeouts,
                 simulated_fault_params=ft_params.simulated_fault,
@@ -583,6 +587,11 @@ def exp_manager(trainer: 'pytorch_lightning.Trainer', cfg: Optional[Union[DictCo
                 'FaultToleranceCallback was enabled with create_fault_tolerance_callback, but fault_tolerance package is not installed.'
             )
 
+    if cfg.log_tflops_per_sec_per_gpu:
+        logging.info(
+            "TFLOPs per sec per GPU will be calculated for supported models; the value defaults to -1 if the calculation fails."
+        )
+
     if is_global_rank_zero():
         # Move files_to_copy to folder and add git information if present
         if cfg.files_to_copy:
diff --git a/nemo/utils/export_utils.py b/nemo/utils/export_utils.py
index c44530944051..534598097bf4 100644
--- a/nemo/utils/export_utils.py
+++ b/nemo/utils/export_utils.py
@@ -72,10 +72,12 @@ def __init__(self, weight, bias, skip_bias_add):
         self.weight = weight
         self.skip_bias_add = skip_bias_add
 
-    def forward(self, x):
+    def forward(self, x, weight=None):
+        if weight is None:
+            weight = self.weight
         if self.skip_bias_add:
-            return F.linear(x, self.weight), self.bias
-        return F.linear(x, self.weight, self.bias), None
+            return F.linear(x, weight), self.bias
+        return F.linear(x, weight, self.bias), None
 
 
 def get_export_format(filename: str):
@@ -239,7 +241,8 @@ def run_ort_and_compare(sess, ort_input, output_example, check_tolerance=0.01):
     from apex.contrib.layer_norm.layer_norm import FastLayerNorm
     from apex.normalization import MixedFusedRMSNorm
     from apex.normalization.fused_layer_norm import FusedLayerNorm, MixedFusedLayerNorm
-    from apex.transformer.functional.fused_softmax import FusedScaleMaskSoftmax
+    from megatron.core.fusions.fused_layer_norm import FusedLayerNorm as MCoreFusedLayerNorm
+    from megatron.core.fusions.fused_softmax import FusedScaleMaskSoftmax
     from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
 
     def replace_FusedLayerNorm(n: nn.Module) -> Optional[nn.LayerNorm]:
@@ -255,21 +258,17 @@ def replace_FusedLayerNorm(n: nn.Module) -> Optional[nn.LayerNorm]:
 
         if isinstance(n, FusedLayerNorm) or isinstance(n, MixedFusedLayerNorm):
             shape, eps, affine = n.normalized_shape, n.eps, n.elementwise_affine
-            n_state = n.state_dict()
+        elif isinstance(n, MCoreFusedLayerNorm):
+            shape, eps, affine = n.weight.shape, n.eps, True
         elif isinstance(n, FastLayerNorm):
             shape, eps, affine = n.weight.shape, n.epsilon, True
-            n_state = n.state_dict()
-        elif isinstance(n, MixedFusedRMSNorm):
-            shape, eps, affine = n.normalized_shape, n.eps, n.elementwise_affine
-            tmp_n_state = n.state_dict()
-            n_state = {'weight': tmp_n_state['weight'], 'bias': torch.zeros_like(tmp_n_state['weight'])}
         else:
             return None
 
         n_state = n.state_dict()
         mod = nn.LayerNorm(shape, eps=eps, elementwise_affine=affine, device=p.device, dtype=p.dtype)
 
-        mod.load_state_dict(n_state)
+        mod.load_state_dict(n_state, strict=True)
 
         return mod
 
@@ -306,7 +305,7 @@ def replace_ParallelLinear(n: nn.Module) -> Optional[nn.Linear]:
         mod = LinearWithBiasSkip(n.weight, n.bias, n.skip_bias_add).to(dev)
 
         n_state = n.state_dict()
-        mod.load_state_dict(n_state)
+        mod.load_state_dict(n_state, strict=False)
         return mod
 
     def replace_FusedScaleMaskSoftmax(n: nn.Module) -> Optional[nn.Linear]:
@@ -318,7 +317,7 @@ def replace_FusedScaleMaskSoftmax(n: nn.Module) -> Optional[nn.Linear]:
            Equivalent LayerNorm module
         """
         if not isinstance(n, FusedScaleMaskSoftmax):
-            logging.warning("This function can only change the FusedScaleMaskSoftmax module.")
+            logging.warning(f"This function can only change the FusedScaleMaskSoftmax module, got: {n.__class__}")
             return n
 
         # disable the fusion only
@@ -331,6 +330,7 @@ def replace_FusedScaleMaskSoftmax(n: nn.Module) -> Optional[nn.Linear]:
     default_Apex_replacements = {
         "FusedLayerNorm": replace_FusedLayerNorm,
         "MixedFusedLayerNorm": replace_FusedLayerNorm,
+        "MCoreFusedLayerNorm": replace_FusedLayerNorm,
         "FastLayerNorm": replace_FusedLayerNorm,
         "RowParallelLinear": replace_ParallelLinear,
         "ColumnParallelLinear": replace_ParallelLinear,
diff --git a/requirements/requirements.txt b/requirements/requirements.txt
index e2a558929146..3169d31dbeed 100644
--- a/requirements/requirements.txt
+++ b/requirements/requirements.txt
@@ -1,5 +1,5 @@
 fiddle
-huggingface_hub>=0.20.3
+huggingface_hub>=0.24
 numba
 numpy>=1.22
 onnx>=1.7.0
diff --git a/requirements/requirements_lightning.txt b/requirements/requirements_lightning.txt
index c7e67d21a693..1b3397f69033 100644
--- a/requirements/requirements_lightning.txt
+++ b/requirements/requirements_lightning.txt
@@ -4,6 +4,6 @@ hydra-core>1.3,<=1.3.2
 omegaconf<=2.3
 pytorch-lightning>2.2.1
 torchmetrics>=0.11.0
-transformers>=4.36.0,<=4.40.2
+transformers
 wandb
 webdataset>=0.2.86
diff --git a/requirements/requirements_multimodal.txt b/requirements/requirements_multimodal.txt
index 1fdce2c160d9..b7e6119fd7b7 100644
--- a/requirements/requirements_multimodal.txt
+++ b/requirements/requirements_multimodal.txt
@@ -6,7 +6,7 @@ einops_exts
 imageio
 kornia
 nerfacc>=0.5.3
-open_clip_torch
+open_clip_torch==2.24.0
 PyMCubes
 taming-transformers
 torchdiffeq
diff --git a/requirements/requirements_nlp.txt b/requirements/requirements_nlp.txt
index a1dad5b64a8a..f98f7c318c56 100644
--- a/requirements/requirements_nlp.txt
+++ b/requirements/requirements_nlp.txt
@@ -20,4 +20,5 @@ rouge_score
 sacrebleu  # manually install sacrebleu[ja] for Japanese support; MeCab is unsupported in Python 3.11+
 sentence_transformers
 tensorstore<0.1.46
+tiktoken==0.7.0
 zarr
diff --git a/requirements/requirements_tts.txt b/requirements/requirements_tts.txt
index 9536faec8c78..0d499feb3b1f 100644
--- a/requirements/requirements_tts.txt
+++ b/requirements/requirements_tts.txt
@@ -1,5 +1,6 @@
 attrdict
 einops
+janome
 jieba
 kornia
 librosa
diff --git a/requirements/requirements_vllm.txt b/requirements/requirements_vllm.txt
index a603b3c4ec53..414e05078680 100644
--- a/requirements/requirements_vllm.txt
+++ b/requirements/requirements_vllm.txt
@@ -1 +1 @@
-vllm==0.5.0
+vllm==0.5.3.post1
diff --git a/scripts/checkpoint_converters/convert_bert_hf_to_nemo.py b/scripts/checkpoint_converters/convert_bert_hf_to_nemo.py
index a81fd33f47a2..14baca53f165 100644
--- a/scripts/checkpoint_converters/convert_bert_hf_to_nemo.py
+++ b/scripts/checkpoint_converters/convert_bert_hf_to_nemo.py
@@ -15,12 +15,11 @@
 """
 Example to run this conversion script:
 ```
-    python convert_bert_hf_to_nemo.py \
-     --input_name_or_path "thenlper/gte-large" \
+    python /opt/NeMo/scripts/checkpoint_converters/convert_bert_hf_to_nemo.py \
+     --input_name_or_path /path/to/hf/checkpoints/folder \
      --output_path /path/to/output/nemo/file.nemo \
      --mcore True \
-     --post_process False \
-     --precision 32
+     --precision bf16
 ```
 """
 
@@ -37,7 +36,10 @@
 
 
 def adjust_nemo_config(model_config, ref_config, mcore_bert=True):
-    model_config.tokenizer["type"] = "intfloat/e5-large-unsupervised"  # ref_config["_input_name_or_path"]
+    model_config.tokenizer["type"] = ref_config["_name_or_path"]
+    model_config.tokenizer["library"] = "huggingface"
+    model_config.tokenizer["use_fast"] = True
+    model_config["max_position_embeddings"] = ref_config['max_position_embeddings']
     model_config["num_layers"] = ref_config["num_hidden_layers"]
     model_config["hidden_size"] = ref_config["hidden_size"]
     model_config["ffn_hidden_size"] = ref_config["intermediate_size"]
@@ -67,7 +69,7 @@ def get_args():
         "--post_process", type=bool, default=False, required=False, help="Whether to have the postprocessing modules"
     )
     parser.add_argument(
-        "--precision", type=str, default="32", choices=["bf16", "32"], help="Precision for checkpoint weights saved"
+        "--precision", type=str, default="bf16", choices=["bf16", "32"], help="Precision for checkpoint weights saved"
     )
 
     args = parser.parse_args()
@@ -86,7 +88,12 @@ def convert(args):
     model = MegatronBertModel(nemo_config.model, trainer)
 
     if not args.post_process:
-        model.model.lm_head, model.model.encoder.final_layernorm, model.model.binary_head, model.model.output_layer = (
+        (
+            model.model.module.lm_head,
+            model.model.module.encoder.final_layernorm,
+            model.model.module.binary_head,
+            model.model.module.output_layer,
+        ) = (
             None,
             None,
             None,
@@ -263,6 +270,16 @@ def convert(args):
         else:
             nemo_state_dict['model.language_model.embedding.word_embeddings.weight'] = padded_embedding
 
+    modified_dict = {}
+    for key, value in nemo_state_dict.items():
+        if key.startswith('model.'):
+            new_key = 'model.module.' + key[len('model.') :]
+            modified_dict[new_key] = value
+        else:
+            modified_dict[key] = value
+
+    nemo_state_dict = modified_dict
+
     model.load_state_dict(nemo_state_dict, strict=True)
     dtype = torch_dtype_from_precision(args.precision)
     model = model.to(dtype=dtype)
@@ -271,5 +288,6 @@ def convert(args):
 
 
 if __name__ == '__main__':
+    os.environ['NVTE_FLASH_ATTN'] = '0'  # Bert doesn't support FLASH_ATTN
     args = get_args()
     convert(args)
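
The BERT converter above now assigns into `model.model.module.*` and remaps state-dict keys, apparently because the underlying Megatron module is reached through an extra `.module` wrapper (e.g. a Float16Module-style O2 wrapper). The key remapping by itself, as a standalone helper (name is ours):

```python
from typing import Dict

import torch


def add_module_prefix(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Rewrite 'model.<rest>' keys as 'model.module.<rest>', leaving other keys untouched."""
    remapped = {}
    for key, value in state_dict.items():
        if key.startswith("model."):
            remapped["model.module." + key[len("model."):]] = value
        else:
            remapped[key] = value
    return remapped
```
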
diff --git a/scripts/checkpoint_converters/convert_clip_hf_to_nemo.py b/scripts/checkpoint_converters/convert_clip_hf_to_nemo.py
index 690fa74abccd..2b8156ad4b26 100644
--- a/scripts/checkpoint_converters/convert_clip_hf_to_nemo.py
+++ b/scripts/checkpoint_converters/convert_clip_hf_to_nemo.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.
+# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/scripts/checkpoint_converters/convert_llama_hf_to_nemo.py b/scripts/checkpoint_converters/convert_llama_hf_to_nemo.py
index e1dc00c77439..0837e0e6ccf2 100644
--- a/scripts/checkpoint_converters/convert_llama_hf_to_nemo.py
+++ b/scripts/checkpoint_converters/convert_llama_hf_to_nemo.py
@@ -44,7 +44,11 @@
 def get_args():
     parser = ArgumentParser()
     parser.add_argument(
-        "--input_name_or_path", type=str, default=None, required=True, help="Path to Huggingface LLaMA checkpoints",
+        "--input_name_or_path",
+        type=str,
+        default=None,
+        required=True,
+        help="Path to Huggingface LLaMA checkpoints",
     )
     parser.add_argument("--output_path", type=str, default=None, required=True, help="Path to output .nemo file.")
     parser.add_argument(
@@ -92,7 +96,10 @@ def load_config(args, llama_config):
         nemo_config.tokenizer = tokenizer_dict
 
     if llama_config['rope_scaling'] is not None:
-        if llama_config['rope_scaling']['type'] == 'linear':
+        rope_type = llama_config['rope_scaling'].get('rope_type')
+        if rope_type is None:
+            rope_type = llama_config['rope_scaling'].get('type')
+        if rope_type in ('linear', 'llama3'):
             nemo_config['seq_len_interpolation_factor'] = llama_config['rope_scaling']['factor']
         else:
-            raise ValueError("Only linear rope scaling type is supported now")
+            raise ValueError("Only linear and llama3 rope scaling types are supported now")
@@ -139,7 +146,7 @@ def convert(args):
         scaler = None
         if precision in [16, '16', '16-mixed']:
             scaler = GradScaler(
-                init_scale=nemo_config.get('native_amp_init_scale', 2 ** 32),
+                init_scale=nemo_config.get('native_amp_init_scale', 2**32),
                 growth_interval=nemo_config.get('native_amp_growth_interval', 1000),
                 hysteresis=nemo_config.get('hysteresis', 2),
             )
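
The Llama converter now accepts both the older `type` field and the newer `rope_type` field in HF `rope_scaling` configs, and treats `llama3` scaling like `linear` when deriving the interpolation factor. The same lookup in isolation (helper name is ours):

```python
def rope_interpolation_factor(rope_scaling: dict) -> float:
    """Extract the seq-len interpolation factor for supported rope scaling types."""
    rope_type = rope_scaling.get('rope_type') or rope_scaling.get('type')
    if rope_type in ('linear', 'llama3'):
        return rope_scaling['factor']
    raise ValueError(f"Unsupported rope scaling type: {rope_type}")
```
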
diff --git a/scripts/checkpoint_converters/convert_llava_hf_to_nemo.py b/scripts/checkpoint_converters/convert_llava_hf_to_nemo.py
index d91899348e8c..85f65ca05ecf 100644
--- a/scripts/checkpoint_converters/convert_llava_hf_to_nemo.py
+++ b/scripts/checkpoint_converters/convert_llava_hf_to_nemo.py
@@ -292,7 +292,7 @@ def convert(args):
         batch_dict = hf_tokenizer(input_texts, max_length=512, padding=True, truncation=True, return_tensors='pt')
         batch_dict_cuda = {k: v.cuda() for k, v in batch_dict.items()}
         hf_model = hf_model.cuda().eval()
-        model = model.eval()
+        model = model.cuda().eval()
 
         hf_outputs = hf_model(**batch_dict_cuda, output_hidden_states=True)
         ids = batch_dict_cuda['input_ids']
@@ -307,7 +307,7 @@ def convert(args):
             attn_mask, _, pos_ids = attn_mask_and_pos_ids
 
             outputs = model(
-                tokens=tokens, text_position_ids=pos_ids.cuda(), attention_mask=attn_mask.cuda(), labels=None
+                tokens=tokens.cuda(), text_position_ids=pos_ids.cuda(), attention_mask=attn_mask.cuda(), labels=None
             )
 
         hf_next_token = hf_outputs.logits[0, -1].argmax()
diff --git a/scripts/checkpoint_converters/convert_llava_nemo_to_hf.py b/scripts/checkpoint_converters/convert_llava_nemo_to_hf.py
index 430a74567ec2..4681bac41a6f 100644
--- a/scripts/checkpoint_converters/convert_llava_nemo_to_hf.py
+++ b/scripts/checkpoint_converters/convert_llava_nemo_to_hf.py
@@ -150,10 +150,13 @@ def reverse_adjust_tensor_shapes(model, hf_model, nemo_state_dict):
     dict: The updated state dictionary with original tensor shapes and structures.
     """
     model_config = model.cfg
-    num_query_groups = model_config["num_query_groups"]
     head_num = model_config["num_attention_heads"]
     hidden_size = model_config["hidden_size"]
     head_size = model_config["kv_channels"]
+    if "num_query_groups" in model_config and model_config["num_query_groups"] is not None:
+        num_query_groups = model_config["num_query_groups"]
+    else:
+        num_query_groups = head_num
     if head_size is None:
         head_size = hidden_size // head_num
     heads_per_group = head_num // num_query_groups
@@ -300,7 +303,7 @@ def convert(args):
         batch_dict = hf_tokenizer(input_texts, max_length=512, padding=True, truncation=True, return_tensors='pt')
         batch_dict_cuda = {k: v.cuda() for k, v in batch_dict.items()}
         hf_model = hf_model.cuda().eval()
-        model = model.eval()
+        model = model.cuda().eval()
 
         hf_outputs = hf_model(**batch_dict_cuda, output_hidden_states=True)
         ids = batch_dict_cuda['input_ids']
@@ -315,7 +318,7 @@ def convert(args):
             attn_mask, _, pos_ids = attn_mask_and_pos_ids
 
             outputs = model(
-                tokens=tokens, text_position_ids=pos_ids.cuda(), attention_mask=attn_mask.cuda(), labels=None
+                tokens=tokens.cuda(), text_position_ids=pos_ids.cuda(), attention_mask=attn_mask.cuda(), labels=None
             )
 
         hf_next_token = hf_outputs.logits[0, -1].argmax()
diff --git a/scripts/checkpoint_converters/convert_mamba2_pyt_to_nemo.py b/scripts/checkpoint_converters/convert_mamba2_pyt_to_nemo.py
index 9dfd9565179d..1a0a13709421 100644
--- a/scripts/checkpoint_converters/convert_mamba2_pyt_to_nemo.py
+++ b/scripts/checkpoint_converters/convert_mamba2_pyt_to_nemo.py
@@ -95,8 +95,12 @@ def convert(args):
 
         for i in range(num_layers):
             for attr in layer_attributes:
-                new_key = f'model.decoder.layers.{i}.{attr}'
-                old_key = f'backbone.layers.{i}.{attr}'
+                if attr == 'norm.weight':
+                    new_key = f'model.decoder.layers.{i}.mixer.in_proj.layer_norm_weight'
+                    old_key = f'backbone.layers.{i}.norm.weight'
+                else:
+                    new_key = f'model.decoder.layers.{i}.{attr}'
+                    old_key = f'backbone.layers.{i}.{attr}'
                 new_state_dict[new_key] = checkpoint_weights[old_key]
 
         # Tokenizer settings
@@ -110,7 +114,15 @@ def convert(args):
         layer_numbers = set(int(re.search(r'decoder\.layers\.(\d+)\.', key).group(1)) for key in layer_keys)
         num_layers = max(layer_numbers) + 1
 
-        new_state_dict = {"model." + key: value for key, value in checkpoint_weights.items()}
+        for key, value in checkpoint_weights.items():
+            if '.norm.weight' in key and 'mixer' not in key:
+                key = key[:-11] + 'mixer.in_proj.layer_norm_weight'
+            new_state_dict["model." + key] = value
+
         # Tokenizer settings
         tokenizer_library = 'megatron'
@@ -164,7 +176,9 @@ def convert(args):
     trainer = MegatronLMPPTrainerBuilder(nemo_config).create_trainer()
     nemo_model_from_pyt = MegatronMambaModel(nemo_config.model, trainer)
 
-    nemo_model_from_pyt.load_state_dict(new_state_dict, strict=True)
+    # Setting strict=False for the _extra_state
+
+    nemo_model_from_pyt.load_state_dict(new_state_dict, strict=False)
     dtype = torch_dtype_from_precision(args.precision)
     nemo_model_from_pyt = nemo_model_from_pyt.to(dtype=dtype)
     nemo_model_from_pyt.save_to(args.output_path)
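
The Mamba2 converter folds standalone `*.norm.weight` tensors into the mixer's fused in_proj layer norm slot. The rename rule by itself (helper name is ours):

```python
def remap_mamba_norm_key(key: str) -> str:
    """Move a standalone decoder-layer norm weight into the mixer's in_proj layer norm."""
    if '.norm.weight' in key and 'mixer' not in key:
        # strip the trailing 'norm.weight' (11 chars) and append the fused-layer-norm name
        return key[:-len('norm.weight')] + 'mixer.in_proj.layer_norm_weight'
    return key


assert (
    remap_mamba_norm_key('backbone.layers.0.norm.weight')
    == 'backbone.layers.0.mixer.in_proj.layer_norm_weight'
)
```
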
diff --git a/scripts/checkpoint_converters/convert_mistral_7b_hf_to_nemo.py b/scripts/checkpoint_converters/convert_mistral_7b_hf_to_nemo.py
index 3a72661499bf..5785db656217 100644
--- a/scripts/checkpoint_converters/convert_mistral_7b_hf_to_nemo.py
+++ b/scripts/checkpoint_converters/convert_mistral_7b_hf_to_nemo.py
@@ -25,6 +25,7 @@
 import os
 from argparse import ArgumentParser
 from collections import OrderedDict
+from pathlib import Path
 
 import torch
 import torch.nn
@@ -55,11 +56,13 @@ def get_args():
     )
     parser.add_argument("--output_path", type=str, default=None, required=True, help="Path to output .nemo file.")
     parser.add_argument("--precision", type=str, default="bf16", help="Model precision")
+    parser.add_argument('--low-ram', '--low-mem', action='store_true', dest='low_ram')
+    parser.add_argument('--tmp-dir', default='/tmp/mistral_ckpt_parts/')
     args = parser.parse_args()
     return args
 
 
-def load_model(cls, checkpoint, strict, **kwargs):
+def restore_model_from_checkpoint(cls, checkpoint, strict, **kwargs):
     try:
         if 'cfg' in kwargs:
             model = ptl_load_state(cls, checkpoint, strict=strict, **kwargs)
@@ -67,7 +70,8 @@ def load_model(cls, checkpoint, strict, **kwargs):
             model = cls(cfg=checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY], **kwargs)
             for name, module in model.named_parameters():
                 if name in checkpoint['state_dict']:
-                    module.data = checkpoint['state_dict'][name]
+                    # cast the checkpoint tensor to the module's dtype before assigning
+                    module.data = checkpoint['state_dict'][name].to(dtype=module.data.dtype)
                     checkpoint['state_dict'].pop(name)
                 else:
                     print(f"Unexpected key: {name} not in checkpoint but in model.")
@@ -84,6 +88,9 @@ def load_model(cls, checkpoint, strict, **kwargs):
 
             # register the artifacts
             cfg = checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY]
+            # assert os.path.exists(
+            #     cfg.tokenizer.model
+            # ), f"Expected cfg.tokenizer.model {cfg.tokenizer.model} to be present"
             if cfg.tokenizer.model is not None:
                 model.register_artifact("tokenizer.tokenizer_model", cfg.tokenizer.model)
             if cfg.tokenizer.vocab_file is not None:
@@ -95,18 +102,22 @@ def load_model(cls, checkpoint, strict, **kwargs):
     return model
 
 
-def load_config(mistral_config, tokenizer_path):
+def load_config(mistral_config, tokenizer, config_path):
     nemo_config = OmegaConf.load(
         os.path.join(os.path.dirname(__file__), '../../examples/nlp/language_modeling/conf/megatron_llama_config.yaml')
     ).model
     # akoumparouli: verify this.
-    nemo_config.encoder_seq_length = mistral_config['sliding_window']
+    if mistral_config.get('sliding_window', None) is not None:
+        nemo_config.encoder_seq_length = mistral_config['sliding_window']
+    else:
+        nemo_config.encoder_seq_length = mistral_config['max_position_embeddings']
     nemo_config.num_layers = int(mistral_config['num_hidden_layers'])
     nemo_config.hidden_size = mistral_config['hidden_size']
     nemo_config.ffn_hidden_size = mistral_config['intermediate_size']
     nemo_config.num_attention_heads = mistral_config['num_attention_heads']
     nemo_config.max_position_embeddings = mistral_config['max_position_embeddings']
-    nemo_config.window_size = [mistral_config['sliding_window'], 0]
+    if mistral_config.get('sliding_window', None) is not None:
+        nemo_config.window_size = [mistral_config['sliding_window'], 0]
     nemo_config.init_method_std = mistral_config['initializer_range']
     # RMSNorm's epsilon.
     nemo_config.layernorm_epsilon = mistral_config['rms_norm_eps']
@@ -118,7 +129,34 @@ def load_config(mistral_config, tokenizer_path):
     # Mistral uses SiLU, but it is the same as swish with beta = 1.
     nemo_config.activation = 'fast-swiglu'
 
-    nemo_config.tokenizer.model = tokenizer_path
+    # Tokenizer config
+    if hasattr(tokenizer, 'vocab_file'):
+        nemo_config.tokenizer.model = tokenizer.vocab_file
+    else:
+        # Load tekken.json, extract the 'vocab' field & write it to file.
+        vocab_path = os.path.join(config_path, 'tekken.json')
+        assert os.path.exists(vocab_path), f"Expected {vocab_path} to exist"
+        with open(vocab_path, 'rt') as fp:
+            tok_vocab = json.load(fp)
+        vocab_output_path = '/tmp/tekken.json'
+        if os.path.exists(vocab_output_path):
+            os.remove(vocab_output_path)
+        with open(vocab_output_path, 'wt') as fp:
+            json.dump(tok_vocab['vocab'], fp)
+        assert os.path.exists(vocab_output_path), f"Expected {vocab_output_path} to exist"
+        assert os.path.getsize(vocab_output_path) > 0, f"Expected {vocab_output_path} to be non-empty"
+
+        tokenizer_dict = {
+            'library': 'tiktoken',
+            'type': 'tiktoken',
+            'vocab_file': vocab_output_path,
+            'model': None,
+            'merge_file': None,
+            'delimiter': None,
+            'sentencepiece_legacy': False,
+        }
+        nemo_config.tokenizer = tokenizer_dict
+
     # TODO(@akoumparouli): rope_scaling.
     nemo_config['rotary_base'] = mistral_config['rope_theta']
 
@@ -130,38 +168,63 @@ def load_config(mistral_config, tokenizer_path):
     return nemo_config
 
 
-def load_mistral_ckpt(in_dir):
+class LazyStateDict:
+    def __init__(self, ckpt_index, root):
+        self.map = ckpt_index
+        self.root = root
+
+    def __getitem__(self, key):
+        from safetensors import safe_open
+
+        assert key in self.map, f'Got unknown key: {key}'
+        ckpt_part_path = os.path.join(self.root, self.map[key])
+        assert os.path.exists(ckpt_part_path), f'Expected ckpt-part to exist {ckpt_part_path}'
+        with safe_open(ckpt_part_path, framework="pt", device="cpu") as fp:
+            return fp.get_tensor(key)
+
+
+def load_mistral_ckpt(in_dir, load_model=True):
     params_file = os.path.join(in_dir, 'config.json')
     assert os.path.exists(params_file)
     with open(params_file, 'r') as fp:
         model_args = json.load(fp)
 
-    model = AutoModelForCausalLM.from_pretrained(in_dir)
-    ckpt = model.state_dict()
+    ckpt = None
+    if load_model:
+        # If it's in safetensors format, then use lazyloading
+        ckpt_parts_map_path = os.path.join(in_dir, 'model.safetensors.index.json')
+        if os.path.exists(ckpt_parts_map_path):
+            ckpt_parts_map = {}
+            with open(ckpt_parts_map_path, 'rt') as fp:
+                ckpt_parts_map = json.load(fp)
+            print('ckpt_parts_map= ', ckpt_parts_map)
+            ckpt = LazyStateDict(ckpt_parts_map['weight_map'], in_dir)
+        else:
+            model = AutoModelForCausalLM.from_pretrained(in_dir)
+            ckpt = model.state_dict()
 
     tokenizer = AutoTokenizer.from_pretrained(in_dir)
     assert tokenizer.vocab_size == model_args['vocab_size']
     return model_args, ckpt, tokenizer
 
 
-def convert(args):
-    logging.info(f"loading checkpoint {args.input_name_or_path}")
-
-    model_args, ckpt, tokenizer = load_mistral_ckpt(args.input_name_or_path)
-    nemo_config = load_config(model_args, os.path.join(args.input_name_or_path, 'tokenizer.model'))
-    logging.info(f"loaded checkpoint {args.input_name_or_path}")
-
-    if args.precision in ["32", "16"]:
-        precision = int(float(args.precision))
-    elif args.precision in ["bf16", "bf16-mixed"]:
+def parse_precision(precision):
+    if precision in ["32", "16"]:
+        return int(float(precision))
+    elif precision in ["bf16", "bf16-mixed"]:
         if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
-            precision = args.precision
+            return precision
         else:
             logging.warning("BF16 is not supported on this device. Using FP16 instead.")
-            precision = args.precision[2:]  # prune bf in string
+            return precision[2:]  # prune bf in string
     else:
-        precision = args.precision
+        return precision
+
 
+def make_trainer(args, nemo_config):
+    model_args, ckpt, tokenizer = load_mistral_ckpt(args.input_name_or_path, load_model=False)
+    nemo_config = load_config(model_args, tokenizer, args.input_name_or_path)
+    precision = parse_precision(args.precision)
     plugins = []
     if precision in [16, '16', 'bf16', '16-mixed', 'bf16-mixed']:
         scaler = None
@@ -191,13 +254,24 @@ def convert(args):
         dtype = torch.float32  # fallback
 
     nemo_config.precision = precision
-    logging.info(f"nemo_config: {nemo_config}")
+    print(f"nemo_config: {nemo_config}")
 
     trainer = Trainer(plugins=plugins, accelerator='cpu', strategy=NLPDDPStrategy())
+    return trainer, dtype
+
+
+def convert(args):
+    logging.info(f"loading checkpoint {args.input_name_or_path}")
+
+    model_args, ckpt, tokenizer = load_mistral_ckpt(args.input_name_or_path)
+    nemo_config = load_config(model_args, tokenizer, args.input_name_or_path)
+    logging.info(f"loaded checkpoint {args.input_name_or_path}")
 
     hidden_size = nemo_config.hidden_size
     head_num = nemo_config.num_attention_heads
-    head_size = hidden_size // head_num
+    head_size = model_args.get('head_dim', hidden_size // head_num)
+    # Set this explicitly because 2407 does not use hidden_size // num_attention_heads
+    nemo_config.kv_channels = head_size
     num_layers = nemo_config.num_layers
 
     mcore_gpt = nemo_config.mcore_gpt
@@ -226,6 +300,10 @@ def convert(args):
     if mcore_gpt:
         assert nemo_config.activation.startswith('fast-'), 'mcore only supports fast version of gated linear unit.'
 
+    yield checkpoint
+    checkpoint = OrderedDict()
+    checkpoint['state_dict'] = OrderedDict()
+
     for l in range(int(num_layers)):
         print(f"converting layer {l}")
         old_tensor_shape = ckpt[f'model.layers.{l}.self_attn.q_proj.weight'].size()
@@ -298,6 +376,9 @@ def convert(args):
         checkpoint['state_dict'][post_attn_ln_base_name] = param_to_weights(post_attn_ln_weight)
 
         print(f"done layer {l}")
+        yield checkpoint
+        checkpoint = OrderedDict()
+        checkpoint['state_dict'] = OrderedDict()
 
     final_ln_weight = ckpt[f'model.norm.weight']
     if mcore_gpt:
@@ -314,36 +395,72 @@ def convert(args):
     checkpoint['state_dict'][output_layer_base_name] = param_to_weights(output_layer_weight)
 
     checkpoint[MegatronGPTModel.CHECKPOINT_HYPER_PARAMS_KEY] = nemo_config
+    yield checkpoint
     del ckpt
 
+
+def merge(a: dict, b: dict, path=[]):
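+    """Recursively merge dict b into dict a in place; conflicting non-dict values raise an exception."""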
+    is_dict = lambda x: isinstance(x, OrderedDict) or isinstance(x, dict)
+    for key in b:
+        if key in a:
+            if is_dict(a[key]) and is_dict(b[key]):
+                merge(a[key], b[key], path + [str(key)])
+            elif a[key] != b[key]:
+                raise Exception('Value conflict: ' + '.'.join(path + [str(key)]))
+        else:
+            a[key] = b[key]
+    return a
+
+
+def save_to_nemo(args, checkpoint):
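+    """Rebuild the NeMo config and trainer, attach them to the merged state dict, and save the model as a .nemo file."""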
+
+    logging.info(f"loading checkpoint {args.input_name_or_path}")
+    model_args, ckpt, tokenizer = load_mistral_ckpt(args.input_name_or_path, load_model=False)
+    nemo_config = load_config(model_args, tokenizer, args.input_name_or_path)
+
+    nemo_config.precision = parse_precision(args.precision)
+    nemo_config.megatron_amp_O2 = True
+
+    hidden_size = nemo_config.hidden_size
+    head_num = nemo_config.num_attention_heads
+    head_size = model_args.get('head_dim', hidden_size // head_num)
+    # Set this explicitly because 2407 does not use hidden_size // num_attention_heads
+    nemo_config.kv_channels = head_size
+
+    trainer, dtype = make_trainer(args, nemo_config)
+
+    checkpoint[MegatronGPTModel.CHECKPOINT_HYPER_PARAMS_KEY] = nemo_config
+    checkpoint[MegatronGPTModel.CHECKPOINT_HYPER_PARAMS_KEY].use_cpu_initialization = True
+    checkpoint[MegatronGPTModel.CHECKPOINT_HYPER_PARAMS_KEY].perform_initialization = False
+
     if nemo_config.get('megatron_amp_O2', False):
         keys = list(checkpoint['state_dict'].keys())
         for key in keys:
             checkpoint['state_dict'][key.replace('model.', 'model.module.', 1)] = checkpoint['state_dict'].pop(key)
 
-    model = load_model(MegatronGPTModel, checkpoint, strict=False, trainer=trainer)
+    model = restore_model_from_checkpoint(MegatronGPTModel, checkpoint, strict=False, trainer=trainer)
 
     model._save_restore_connector = NLPSaveRestoreConnector()
 
-    # cast to target precision and disable cpu init
-    model = model.to(dtype=dtype)
+    # disable cpu init
     model.cfg.use_cpu_initialization = False
-
+    model.cfg.perform_initialization = True
     if getattr(tokenizer, 'chat_template', None) is not None:
         import hashlib
 
-        assert (
-            hashlib.md5(tokenizer.chat_template.encode('utf-8')).hexdigest() == "0b629f783db54e02509999196956ff40"
-        ), "Got unkown chat template"
-        from omegaconf import OmegaConf, open_dict
-
-        with open_dict(model.cfg):
-            model.cfg.tokenizer.chat_template = OmegaConf.create(
-                {
-                    'prefix': "{_bos_}",
-                    'roles': {'User': "[INST] {_content_} [/INST]", 'Assistant': "{_content_}{_eos_}"},
-                }
-            )
+        template_hash = hashlib.md5(tokenizer.chat_template.encode('utf-8')).hexdigest()
+        if template_hash != "0b629f783db54e02509999196956ff40":
+            logging.warning("Got unkown chat template")
+        else:
+            from omegaconf import OmegaConf, open_dict
+
+            with open_dict(model.cfg):
+                model.cfg.tokenizer.chat_template = OmegaConf.create(
+                    {
+                        'prefix': "{_bos_}",
+                        'roles': {'User': "[INST] {_content_} [/INST]", 'Assistant': "{_content_}{_eos_}"},
+                    }
+                )
 
     model.save_to(args.output_path)
     logging.info(f'NeMo model saved to: {args.output_path}')
@@ -351,4 +468,20 @@ def convert(args):
 
 if __name__ == '__main__':
     args = get_args()
-    convert(args)
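+    # convert() yields partial checkpoints; when args.low_ram is set, each part is staged to disk and merged back afterwards.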
+    if args.low_ram:
+        os.makedirs(args.tmp_dir, exist_ok=True)
+
+    checkpoint = OrderedDict()
+    for i, ckpt_part in enumerate(convert(args)):
+        if args.low_ram:
+            torch.save(ckpt_part, f'{args.tmp_dir}/nemo_ckpt_part_{i}.pth')
+        else:
+            checkpoint = merge(checkpoint, ckpt_part)
+
+    if args.low_ram:
+        print("Loading partial checkpoints")
+        for path in map(str, Path(args.tmp_dir).rglob("*.pth")):
+            print(f"Loading checkpoint: {path}")
+            checkpoint = merge(checkpoint, torch.load(path, mmap=True))
+
+    save_to_nemo(args, checkpoint)
diff --git a/scripts/checkpoint_converters/convert_mixtral_hf_to_nemo.py b/scripts/checkpoint_converters/convert_mixtral_hf_to_nemo.py
index 1bf23224357f..36e4c0c2c3ea 100644
--- a/scripts/checkpoint_converters/convert_mixtral_hf_to_nemo.py
+++ b/scripts/checkpoint_converters/convert_mixtral_hf_to_nemo.py
@@ -17,7 +17,8 @@
   Example to run this conversion script:
     python3 convert_mixtral_hf_to_nemo.py \
      --input_name_or_path  \
-     --output_path  
+     --output_path  \
+     --precision=bf16
 """
 
 import json
@@ -132,6 +133,7 @@ def load_config(mixtral_config, tokenizer_path):
     assert nemo_config.num_moe_experts > 0, "num_experts must be greater than zero."
     nemo_config.moe_router_topk = int(mixtral_config['num_experts_per_tok'])
     assert nemo_config.moe_router_topk > 0, "moe_router_topk must be greater than zero."
+    nemo_config.moe_router_pre_softmax = True
     nemo_config.use_cpu_initialization = True
     # Mixtral uses SiLU, but it is the same as swish with beta = 1.
     nemo_config.activation = 'fast-swiglu'
diff --git a/scripts/checkpoint_converters/convert_nemotron_nemo_to_hf.py b/scripts/checkpoint_converters/convert_nemotron_nemo_to_hf.py
new file mode 100644
index 000000000000..072de9e5d5f4
--- /dev/null
+++ b/scripts/checkpoint_converters/convert_nemotron_nemo_to_hf.py
@@ -0,0 +1,345 @@
+# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import os
+import shutil
+from argparse import ArgumentParser
+from collections import OrderedDict
+
+import torch
+from pytorch_lightning import Trainer
+from transformers import LlamaTokenizer, PreTrainedTokenizerFast
+from transformers.convert_slow_tokenizer import LlamaConverter
+
+from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
+from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
+from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy
+from nemo.utils import logging
+
+"""
+Script to convert a Nemotron checkpoint in NeMo (mcore path) into a HuggingFace checkpoint.
+This script can be used to 1) generate only the HF weights, or 2) generate an entire HF model folder.
+
+1) Generate only HF weights from a nemo file:
+
+    python convert_nemotron_nemo_to_hf.py \
+    --input_name_or_path /path/to/file.nemo or /path/to/extracted_folder \
+    --output_path /path/to/pytorch_model.bin
+
+2) Generate the full HF model folder
+
+    python convert_nemotron_nemo_to_hf.py \
+    --input_name_or_path /path/to/file.nemo or /path/to/extracted_folder \
+    --hf_input_path /path/to/input_hf_folder \
+    --hf_output_path /path/to/output_hf_folder \
+
+    Use the --cpu-only flag if the model cannot fit in the GPU (e.g. Nemotron4 340b).
+    However, this option makes the conversion script significantly slower.
+"""
+
+
+def get_args():
+    parser = ArgumentParser()
+    parser.add_argument(
+        "--input_name_or_path",
+        type=str,
+        default=None,
+        required=True,
+        help="Path to .nemo file or extracted folder",
+    )
+    parser.add_argument("--output_path", type=str, default=None, required=False, help="Path to HF .bin file")
+    parser.add_argument(
+        "--hf_input_path",
+        type=str,
+        default=None,
+        help="A HF model path, " "e.g. a folder containing https://huggingface.co/nvidia/Minitron-8B-Base",
+    )
+    parser.add_argument(
+        "--hf_output_path",
+        type=str,
+        default=None,
+        help="Output HF model path, " "with the same format as above but user's own weights",
+    )
+    parser.add_argument(
+        "--precision",
+        type=str,
+        default=None,
+        help="Precision of output weights."
+        "Defaults to precision of the input nemo weights (model.cfg.trainer.precision)",
+    )
+    parser.add_argument(
+        "--cpu-only",
+        action="store_true",
+        help="Load model in cpu only. Useful if the model cannot fit in GPU memory, "
+        "but this option makes the conversion script significantly slower.",
+    )
+    args = parser.parse_args()
+    return args
+
+
+def convert_hf_config(nemo_config, tokenizer, vocab_size, dtype, hf_output_path, hf_url="nvidia/Minitron-8B-Base"):
+    """
+    Convert NeMo config to HF config
+    """
+    NEMO_ACT2HF = {
+        "squared-relu": "relu2",
+        "fast-swiglu": "silu",
+    }
+    DTYPE2HF = {
+        torch.bfloat16: "bfloat16",
+        torch.float16: "float16",
+        torch.float32: "float32",
+    }
+    hf_config = {
+        "_name_or_path": hf_url,
+        "architectures": ["NemotronForCausalLM"],
+        "bos_token_id": tokenizer.bos_id,
+        "eos_token_id": tokenizer.eos_id,
+        "hidden_act": NEMO_ACT2HF[nemo_config.activation],
+        "hidden_size": nemo_config.hidden_size,
+        "initializer_range": nemo_config.init_method_std,
+        "intermediate_size": nemo_config.ffn_hidden_size,
+        "max_position_embeddings": nemo_config.max_position_embeddings,
+        "model_type": "nemotron",
+        "num_attention_heads": nemo_config.num_attention_heads,
+        "num_hidden_layers": nemo_config.num_layers,
+        "num_key_value_heads": nemo_config.get("num_query_groups", nemo_config.num_attention_heads),
+        "norm_eps": nemo_config.layernorm_epsilon,
+        "rope_theta": nemo_config.get("rotary_base", 10000),
+        "partial_rotary_factor": nemo_config.get("rotary_percentage", 1.0),
+        "tie_word_embeddings": False,
+        "torch_dtype": DTYPE2HF[dtype],
+        "transformers_version": "4.32.0.dev0",  # TODO
+        "use_cache": True,
+        "vocab_size": vocab_size,
+    }
+    if nemo_config.kv_channels is not None:
+        hf_config["kv_channels"] = nemo_config.kv_channels
+    json.dump(hf_config, open(f"{hf_output_path}/config.json", "w"), indent=2)
+
+
+def convert(input_nemo_file, output_hf_file, precision=None, cpu_only=False) -> None:
+    """
+    Convert NeMo weights to HF weights
+    """
+    dummy_trainer = Trainer(devices=1, accelerator="cpu", strategy=NLPDDPStrategy())
+    model_config = MegatronGPTModel.restore_from(input_nemo_file, trainer=dummy_trainer, return_config=True)
+    model_config.tensor_model_parallel_size = 1
+    model_config.pipeline_model_parallel_size = 1
+    model_config.sequence_parallel = False
+    model_config.transformer_engine = True
+    if cpu_only:
+        map_location = torch.device("cpu")
+        model_config.use_cpu_initialization = True
+        model_config.dist_ckpt_load_on_device = False
+    else:
+        map_location = None
+
+    if cpu_only:
+        logging.info("******** Loading model on CPU. This will take a significant amount of time.")
+
+    model = MegatronGPTModel.restore_from(
+        input_nemo_file, trainer=dummy_trainer, override_config_path=model_config, map_location=map_location
+    )
+
+    vocab_size = model.padded_vocab_size
+
+    if precision is None:
+        precision = model.cfg.precision
+    if precision in [32, "32"]:
+        dtype = torch.float32
+    elif precision in [16, "16", "16-mixed"]:
+        dtype = torch.float16
+    elif precision in ["bf16", "bf16-mixed"]:
+        dtype = torch.bfloat16
+    else:
+        logging.warning(f"Precision string {precision} is not recognized, falling back to fp32")
+        dtype = torch.float32  # fallback
+    logging.info(f"Using precision {dtype}")
+
+    def param_to_weights(param):
+        return param.to(dtype)
+
+    checkpoint = OrderedDict()
+
+    hidden_size = model.cfg.hidden_size
+    head_num = model.cfg.num_attention_heads
+    num_layers = model.cfg.num_layers
+    ffn_hidden_size = model.cfg.ffn_hidden_size
+    num_query_groups = model.cfg.get("num_query_groups", head_num)  # different num_query_groups for 70B
+    if num_query_groups is None:
+        num_query_groups = head_num
+    heads_per_group = head_num // num_query_groups
+    qkv_total_dim = head_num + 2 * num_query_groups
+
+    # Embedding
+    embed_weight = model.state_dict()["model.embedding.word_embeddings.weight"]
+    embed_weights_base_name = "model.embed_tokens.weight"
+    checkpoint[embed_weights_base_name] = param_to_weights(embed_weight)
+
+    for l in range(int(num_layers)):
+        print(f"converting layer {l}")
+
+        qkv_weights = model.state_dict()[f"model.decoder.layers.{l}.self_attention.linear_qkv.weight"]
+        qkv_weights = qkv_weights.reshape([qkv_total_dim, -1, hidden_size])
+
+        q_slice = torch.cat(
+            [
+                torch.arange((heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group)
+                for i in range(num_query_groups)
+            ]
+        )
+        k_slice = torch.arange(heads_per_group, qkv_total_dim, (heads_per_group + 2))
+        v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, (heads_per_group + 2))
+        ## Example of slices
+        ## (without GQA): num_query_groups = head_num = 32,
+        ## q_slice = [0, 3, 6, 9 , ... 90, 93]
+        ## k_slice = [1, 4, 7, 10, ... 91, 94]
+        ## v_slice = [2, 5, 8, 11, ... 92, 95]
+        ## (with GQA): num_query_groups = 8, head_num = 64
+        ## q_slice = [0, 1, .. 6, 7, 10, 11, .. 16, 17, 20, 21, .. 67, 70, ... 76, 77]
+        ## k_slice = [8, 18, 28, ... 68, 78]
+        ## v_slice = [9, 19, 29, ... 69, 79]
+
+        q_weights_base_name = f"model.layers.{l}.self_attn.q_proj.weight"
+        k_weights_base_name = f"model.layers.{l}.self_attn.k_proj.weight"
+        v_weights_base_name = f"model.layers.{l}.self_attn.v_proj.weight"
+
+        checkpoint[q_weights_base_name] = param_to_weights(qkv_weights[q_slice].reshape(-1, hidden_size))
+        checkpoint[k_weights_base_name] = param_to_weights(qkv_weights[k_slice].reshape(-1, hidden_size))
+        checkpoint[v_weights_base_name] = param_to_weights(qkv_weights[v_slice].reshape(-1, hidden_size))
+
+        # attention dense
+        o_weight = model.state_dict()[f"model.decoder.layers.{l}.self_attention.linear_proj.weight"]
+        o_weight_base_name = f"model.layers.{l}.self_attn.o_proj.weight"
+        checkpoint[o_weight_base_name] = param_to_weights(o_weight)
+
+        # mlp
+        mlp_weights = model.state_dict()[f"model.decoder.layers.{l}.mlp.linear_fc1.weight"]
+        mlp_up_proj_weight = model.state_dict()[f"model.decoder.layers.{l}.mlp.linear_fc2.weight"]
+
+        if mlp_weights.shape[0] != mlp_up_proj_weight.shape[1]:
+            # Has gated projection (used for SwiGLU)
+            logging.warning(
+                "Gated projection layers detected in NeMo checkpoint. Currently Nemotron HF does not support gated MLP."
+            )
+            assert mlp_weights.shape[0] == 2 * mlp_up_proj_weight.shape[1]
+
+            mlp_down_proj_weight = mlp_weights[:ffn_hidden_size, :]
+            mlp_gate_proj_weight = mlp_weights[ffn_hidden_size:, :]
+
+            mlp_down_proj_base_name = f"model.layers.{l}.mlp.gate_proj.weight"
+            mlp_gate_proj_base_name = f"model.layers.{l}.mlp.up_proj.weight"
+
+            checkpoint[mlp_down_proj_base_name] = param_to_weights(mlp_down_proj_weight)
+            checkpoint[mlp_gate_proj_base_name] = param_to_weights(mlp_gate_proj_weight)
+        else:
+            mlp_down_proj_weight = mlp_weights
+            mlp_down_proj_base_name = f"model.layers.{l}.mlp.up_proj.weight"
+            checkpoint[mlp_down_proj_base_name] = param_to_weights(mlp_down_proj_weight)
+
+        mlp_up_proj_base_name = f"model.layers.{l}.mlp.down_proj.weight"
+        checkpoint[mlp_up_proj_base_name] = param_to_weights(mlp_up_proj_weight)
+
+        # layernorm
+        input_ln_weight = model.state_dict()[f"model.decoder.layers.{l}.self_attention.linear_qkv.layer_norm_weight"]
+        input_ln_base_name = f"model.layers.{l}.input_layernorm.weight"
+        checkpoint[input_ln_base_name] = param_to_weights(input_ln_weight)
+        if (
+            model.state_dict().get(f"model.decoder.layers.{l}.self_attention.linear_qkv.layer_norm_bias", None)
+            is not None
+        ):
+            input_ln_bias = model.state_dict()[f"model.decoder.layers.{l}.self_attention.linear_qkv.layer_norm_bias"]
+            input_ln_bias_name = f"model.layers.{l}.input_layernorm.bias"
+            checkpoint[input_ln_bias_name] = param_to_weights(input_ln_bias)
+
+        post_attn_ln_weight = model.state_dict()[f"model.decoder.layers.{l}.mlp.linear_fc1.layer_norm_weight"]
+        post_attn_ln_base_name = f"model.layers.{l}.post_attention_layernorm.weight"
+        checkpoint[post_attn_ln_base_name] = param_to_weights(post_attn_ln_weight)
+        if model.state_dict().get(f"model.decoder.layers.{l}.mlp.linear_fc1.layer_norm_bias", None) is not None:
+            post_attn_ln_bias = model.state_dict()[f"model.decoder.layers.{l}.mlp.linear_fc1.layer_norm_bias"]
+            post_attn_ln_bias_name = f"model.layers.{l}.post_attention_layernorm.bias"
+            checkpoint[post_attn_ln_bias_name] = param_to_weights(post_attn_ln_bias)
+
+        print(f"done layer {l}")
+
+    final_ln_weight = model.state_dict()["model.decoder.final_layernorm.weight"]
+    final_ln_base_name = "model.norm.weight"
+    checkpoint[final_ln_base_name] = param_to_weights(final_ln_weight)
+    if model.state_dict().get("model.decoder.final_layernorm.bias", None) is not None:
+        final_ln_bias = model.state_dict()["model.decoder.final_layernorm.bias"]
+        final_ln_bias_name = "model.norm.bias"
+        checkpoint[final_ln_bias_name] = param_to_weights(final_ln_bias)
+
+    output_layer_weight = model.state_dict()["model.output_layer.weight"]
+    output_layer_base_name = "lm_head.weight"
+    checkpoint[output_layer_base_name] = param_to_weights(output_layer_weight)
+
+    os.makedirs(os.path.dirname(output_hf_file), exist_ok=True)
+    torch.save(checkpoint, output_hf_file)
+    logging.info(f"Weights saved to {output_hf_file}")
+
+    return model_config, model.tokenizer, dtype, vocab_size
+
+
+def extract_nemotron_tokenizer(nemo_file, model_config, output_hf_path, nemo_tokenizer):
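+    """Export the tokenizer referenced by the NeMo model config into the HF output folder."""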
+    tokenizer_cfg = model_config.tokenizer
+    if tokenizer_cfg.library == "sentencepiece":
+        # For sentencepiece tokenizer, we are wrapping with HF's LlamaTokenizer
+        # and convert it to a PreTrainedTokenizerFast
+        tokenizer_fn = tokenizer_cfg.model[5:]
+        output_tokenizer = f"{output_hf_path}/tokenizer.model"
+        if nemo_file.endswith(".nemo"):
+            import tarfile
+
+            archive = tarfile.open(nemo_file, "r")
+            tokenizer_filename = "./" + tokenizer_fn  # exclude 'nemo:' prefix
+            archive.extract(tokenizer_filename, output_hf_path)
+            archive.close()
+            os.rename(f"{output_hf_path}/{tokenizer_fn}", output_tokenizer)
+        elif os.path.isdir(nemo_file):
+            shutil.copy(f"{nemo_file}/{tokenizer_fn}", output_tokenizer)
+        # We use LlamaTokenizer for sentencepiece based tokenizer
+        tokenizer = LlamaTokenizer.from_pretrained(output_hf_path, legacy=False)
+        # Convert the LlamaTokenizer to a PreTrainedTokenizerFast instance
+        tokenizer = PreTrainedTokenizerFast(
+            tokenizer_object=LlamaConverter(tokenizer).converted(), model_input_names=["input_ids", "token_type_ids"]
+        )
+        tokenizer.save_pretrained(output_hf_path)
+        logging.info(f"Setencepiece tokenizer has been saved to {output_tokenizer}")
+    elif isinstance(nemo_tokenizer, AutoTokenizer):
+        nemo_tokenizer.tokenizer.save_pretrained(output_hf_path)
+        logging.info(f"HF AutoTokenizer has been saved to {output_hf_path}")
+    else:
+        raise ValueError(f"Unsupported tokenizer type: library: {tokenizer_cfg.library}, type: {tokenizer_cfg.type}")
+
+
+if __name__ == "__main__":
+    args = get_args()
+    if not args.hf_output_path:
+        assert args.output_path is not None, "Need to provide either output_path or hf_output_path"
+    else:
+        args.output_path = f"{args.hf_output_path}/pytorch_model.bin"
+        logging.info(f"weight will be saved to {args.output_path}")
+
+    nemo_config, nemo_tokenizer, dtype, vocab_size = convert(
+        args.input_name_or_path, args.output_path, precision=args.precision, cpu_only=args.cpu_only
+    )
+    if args.hf_input_path and args.hf_output_path:
+        convert_hf_config(nemo_config, nemo_tokenizer, vocab_size, dtype, args.hf_output_path, args.hf_input_path)
+        extract_nemotron_tokenizer(args.input_name_or_path, nemo_config, args.hf_output_path, nemo_tokenizer)
+    else:
+        logging.info("`hf_input_path` and/or `hf_output_path` not provided, not generating full HF model.")
+        logging.info(f".bin file is saved to {args.output_path}")
diff --git a/scripts/checkpoint_converters/convert_siglip_hf_to_nemo.py b/scripts/checkpoint_converters/convert_siglip_hf_to_nemo.py
index 97a9d557f78b..053b3a053884 100644
--- a/scripts/checkpoint_converters/convert_siglip_hf_to_nemo.py
+++ b/scripts/checkpoint_converters/convert_siglip_hf_to_nemo.py
@@ -13,11 +13,10 @@
 # limitations under the License.
 
 """
-Requires HF transformers updated to support Gemma Models
-   python3 /opt/NeMo/scripts/nlp_language_modeling/convert_gemma_hf_to_nemo.py \
-   --input_name_or_path /path/to/gemma/checkpoints/hf/7b \
-   --output_path /path/to/gemma-7b.nemo \
-   --tokenizer_path /path/to/tokenizer.model
+Requires HF transformers updated to support Siglip Models
+    python /opt/NeMo/scripts/checkpoint_converters/convert_siglip_hf_to_nemo.py \
+      --input_name_or_path=google/siglip-so400m-patch14-384 \
+      --output_path=test.nemo
 """
 
 import os
@@ -352,7 +351,7 @@ def get_args():
 def convert(args):
     logging.info(f"Loading checkpoint from HF: `{args.input_name_or_path}`")
     hf_model = AutoModel.from_pretrained(args.input_name_or_path)
-    # hf_processor = AutoProcessor.from_pretrained(args.input_name_or_path)
+    hf_processor = AutoProcessor.from_pretrained(args.input_name_or_path)
     logging.info("HF Model loading done.")
 
     nemo_config = OmegaConf.load(args.hparams_file)
@@ -369,6 +368,35 @@ def convert(args):
     nemo_state_dict = adjust_tensor_shapes(model, new_state_dict)
     model.load_state_dict(nemo_state_dict, strict=False)
 
+    logging.info(f'=' * 100)
+    # Verifications
+    import requests
+    from PIL import Image
+
+    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+    image = Image.open(requests.get(url, stream=True).raw)
+
+    texts = ["a photo of 2 cats", "a photo of 2 dogs"]
+    inputs = hf_processor(text=texts, images=image, padding="max_length", return_tensors="pt")
+
+    tokens = inputs["input_ids"].cuda()
+    text_model = model.model.text_encoder.cuda()
+    hf_text_model = hf_model.text_model.cuda()
+    text_model_output = text_model(tokens)
+    hf_text_model_output = hf_text_model(tokens).pooler_output
+    assert torch.allclose(text_model_output, hf_text_model_output, atol=0.01)
+    logging.info(f'! Text model results matched.')
+
+    pixels = inputs["pixel_values"].cuda()
+    vision_model = model.model.vision_encoder.cuda()
+    hf_vision_model = hf_model.vision_model.cuda()
+    vision_model_output = vision_model(pixels)
+    hf_vision_model_output = hf_vision_model(pixels).pooler_output
+    assert torch.allclose(vision_model_output, hf_vision_model_output, atol=0.01)
+    logging.info(f'! Vision model results matched.')
+
+    logging.info(f'=' * 100)
+
     dtype = torch_dtype_from_precision(args.precision)
     model = model.to(dtype=dtype)
     model.save_to(args.output_path)
diff --git a/scripts/checkpoint_converters/convert_stablediffusion_hf_to_nemo.py b/scripts/checkpoint_converters/convert_stablediffusion_hf_to_nemo.py
new file mode 100644
index 000000000000..ff10dab4bc90
--- /dev/null
+++ b/scripts/checkpoint_converters/convert_stablediffusion_hf_to_nemo.py
@@ -0,0 +1,452 @@
+# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+r"""
+Conversion script to convert HuggingFace StableDiffusion checkpoints into a NeMo checkpoint.
+  Example to run this conversion script:
+    python convert_stablediffusion_hf_to_nemo.py \
+     --input_name_or_path  \
+     --output_path  --model 
+"""
+
+import os
+from argparse import ArgumentParser
+
+import numpy as np
+import safetensors
+import torch
+import torch.nn
+
+from nemo.utils import logging
+
+
+def filter_keys(rule, d):
+    keys = list(d.keys())
+    nd = {k: d[k] for k in keys if rule(k)}
+    return nd
+
+
+def map_keys(rule, d):
+    new = {rule(k): v for k, v in d.items()}
+    return new
+
+
+def split_name(name, dots=0):
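+    """Split a dotted name into its first `dots + 1` components and the remainder."""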
+    l = name.split(".")
+    return ".".join(l[: dots + 1]), ".".join(l[dots + 1 :])
+
+
+def is_prefix(shortstr, longstr):
+    # is the first string a prefix of the second one
+    return longstr == shortstr or longstr.startswith(shortstr + ".")
+
+
+def numdots(str):
+    return str.count(".")
+
+
+class SegTree:
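+    """Prefix tree over dot-separated parameter names; each node may carry a convert_name used to build the HF-to-NeMo key mapping."""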
+    def __init__(self):
+        self.nodes = dict()
+        self.val = None
+        self.final_val = 0
+        self.convert_name = None
+
+    def __len__(self):
+        return len(self.nodes)
+
+    def is_leaf(self):
+        return len(self.nodes) == 0
+
+    def add(self, name, val=0):
+        prefix, subname = split_name(name)
+        if subname == '':
+            self.nodes[name] = SegTree()
+            self.nodes[name].val = val
+            return
+        if self.nodes.get(prefix) is None:
+            self.nodes[prefix] = SegTree()
+        self.nodes[prefix].add(subname, val)
+
+    def change(self, name, val):
+        self.add(name, val)
+
+    def __getitem__(self, name: str):
+        if hasattr(self, name):
+            return getattr(self, name)
+        val = self.nodes.get(name)
+        if val is None:
+            # straight lookup failed, do a prefix lookup
+            keys = list(self.nodes.keys())
+            p_flag = [is_prefix(k, name) for k in keys]
+            if not any(p_flag):
+                return None
+            # either more than 1 match (error) or exactly 1 (success)
+            if np.sum(p_flag) > 1:
+                logging.warning(f"warning: multiple matches of key {name} with {keys}")
+            else:
+                i = np.where(p_flag)[0][0]
+                n = numdots(keys[i])
+                prefix, substr = split_name(name, n)
+                return self.nodes[prefix][substr]
+        return val
+
+
+def model_to_tree(model):
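+    """Build a SegTree from the flat parameter names of a checkpoint dict."""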
+    keys = list(model.keys())
+    tree = SegTree()
+    for k in keys:
+        tree.add(k, "leaf")
+    return tree
+
+
+def get_args():
+    parser = ArgumentParser()
+    parser.add_argument(
+        "--input_name_or_path",
+        type=str,
+        default=None,
+        required=True,
+        help="Path to Huggingface UNet checkpoints",
+    )
+    parser.add_argument("--output_path", type=str, default=None, required=True, help="Path to output .nemo file.")
+    parser.add_argument("--precision", type=str, default="32", help="Model precision")
+    parser.add_argument("--model", type=str, default="unet", required=True, choices=['unet', 'vae'])
+    parser.add_argument("--debug", action='store_true', help="Useful for debugging purposes.")
+
+    args = parser.parse_args()
+    return args
+
+
+def load_hf_ckpt(in_dir, args):
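+    """Load every tensor from diffusion_pytorch_model.safetensors inside the given directory."""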
+    ckpt = {}
+    assert os.path.isdir(in_dir), "Currently supports only directories with a safetensors file in them."
+    with safetensors.safe_open(in_dir + "/diffusion_pytorch_model.safetensors", framework="pt") as f:
+        for k in f.keys():
+            ckpt[k] = f.get_tensor(k)
+    return args, ckpt
+
+
+def dup_convert_name_recursive(tree: SegTree, convert_name=None):
+    '''recursively copy each node's own key into its convert_name;
+    optionally set the root's convert_name to the given value (if not None)
+    '''
+    if tree is None:
+        return
+    if convert_name is not None:
+        tree.convert_name = convert_name
+    # recursively copy the name into convert_name
+    for k, v in tree.nodes.items():
+        dup_convert_name_recursive(v, k)
+
+
+def sanity_check(hf_tree, hf_unet, nemo_unet):
+    # check if i'm introducing new keys
+    for hfk, nk in hf_to_nemo_mapping(hf_tree).items():
+        if nk not in nemo_unet.keys():
+            logging.info(nk)
+        if hfk not in hf_unet.keys():
+            logging.info(hfk)
+
+
+def convert_input_keys(hf_tree: SegTree):
+    '''map the input blocks of huggingface model'''
+    # map `conv_in` to first input block
+    dup_convert_name_recursive(hf_tree['conv_in'], 'input_blocks.0.0')
+
+    # start counting blocks from now on
+    nemo_inp_blk = 1
+    down_blocks = hf_tree['down_blocks']
+    down_blocks_keys = sorted(list(down_blocks.nodes.keys()), key=int)
+    for downblockid in down_blocks_keys:
+        block = down_blocks[str(downblockid)]
+        # compute number of resnets, attentions, downsamplers in this block
+        resnets = block.nodes.get('resnets', SegTree())
+        attentions = block.nodes.get('attentions', SegTree())
+        downsamplers = block.nodes.get('downsamplers', SegTree())
+
+        if len(attentions) == 0:  # no attentions, this is a DownBlock2D
+            for resid in sorted(list(resnets.nodes.keys()), key=int):
+                resid = str(resid)
+                resnets[resid].convert_name = f"input_blocks.{nemo_inp_blk}.0"
+                map_resnet_block(resnets[resid])
+                nemo_inp_blk += 1
+        elif len(attentions) == len(resnets):
+            # there are attention blocks here -- each resnet+attention becomes a block
+            for resid in sorted(list(resnets.nodes.keys()), key=int):
+                resid = str(resid)
+                resnets[resid].convert_name = f"input_blocks.{nemo_inp_blk}.0"
+                map_resnet_block(resnets[resid])
+                attentions[resid].convert_name = f"input_blocks.{nemo_inp_blk}.1"
+                map_attention_block(attentions[resid])
+                nemo_inp_blk += 1
+        else:
+            logging.warning("number of attention blocks is not the same as resnets - whats going on?")
+        # if there is a downsampler, then also append it
+        if len(downsamplers) > 0:
+            for k in downsamplers.nodes.keys():
+                downsamplers[k].convert_name = f"input_blocks.{nemo_inp_blk}.{k}"
+                dup_convert_name_recursive(downsamplers[k]['conv'], 'op')
+            nemo_inp_blk += 1
+
+
+def clean_convert_names(tree):
+    tree.convert_name = None
+    for k, v in tree.nodes.items():
+        clean_convert_names(v)
+
+
+def map_attention_block(att_tree: SegTree):
+    '''this HF tree can either be an AttentionBlock or a DualAttention block
+    currently assumed AttentionBlock
+    '''
+
+    # TODO(@rohitrango): Add check for dual attention block, but this works for both SD and SDXL
+    def check_att_type(tree):
+        return "att_block"
+
+    if check_att_type(att_tree) == 'att_block':
+        dup_convert_name_recursive(att_tree['norm'], 'norm')
+        dup_convert_name_recursive(att_tree['proj_in'], 'proj_in')
+        dup_convert_name_recursive(att_tree['proj_out'], 'proj_out')
+        tblockids = list(att_tree['transformer_blocks'].nodes.keys())
+        for t in tblockids:
+            tblock = att_tree[f'transformer_blocks.{t}']
+            tblock.convert_name = f"transformer_blocks.{t}"
+            dup_convert_name_recursive(tblock['attn1'], 'attn1')
+            dup_convert_name_recursive(tblock['attn2'], 'attn2')
+            dup_convert_name_recursive(tblock['norm1'], 'attn1.norm')
+            dup_convert_name_recursive(tblock['norm2'], 'attn2.norm')
+            dup_convert_name_recursive(tblock['norm3'], 'ff.net.0')
+            # map ff
+            tblock['ff'].convert_name = "ff"
+            tblock['ff.net'].convert_name = 'net'
+            dup_convert_name_recursive(tblock['ff.net.0'], '1')
+            dup_convert_name_recursive(tblock['ff.net.2'], '3')
+    else:
+        logging.warning("failed to identify type of attention block here.")
+
+
+def map_resnet_block(resnet_tree: SegTree):
+    '''this HF tree is supposed to have all the keys for a resnet'''
+    dup_convert_name_recursive(resnet_tree.nodes.get('time_emb_proj'), 'emb_layers.1')
+    dup_convert_name_recursive(resnet_tree['norm1'], 'in_layers.0')
+    dup_convert_name_recursive(resnet_tree['conv1'], 'in_layers.1')
+    dup_convert_name_recursive(resnet_tree['norm2'], 'out_layers.0')
+    dup_convert_name_recursive(resnet_tree['conv2'], 'out_layers.2')
+    dup_convert_name_recursive(resnet_tree.nodes.get('conv_shortcut'), 'skip_connection')
+
+
+def hf_to_nemo_mapping(tree: SegTree):
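+    """Flatten the annotated tree into a {hf_key: nemo_key} dict by joining convert_name segments along each path to a leaf."""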
+    mapping = {}
+    for nodename, subtree in tree.nodes.items():
+        convert_name = subtree.convert_name
+        convert_name = (convert_name + ".") if convert_name is not None else ""
+        if subtree.is_leaf() and subtree.convert_name is not None:
+            mapping[nodename] = subtree.convert_name
+        else:
+            submapping = hf_to_nemo_mapping(subtree)
+            for k, v in submapping.items():
+                mapping[nodename + "." + k] = convert_name + v
+    return mapping
+
+
+def convert_cond_keys(tree: SegTree):
+    # map all conditioning keys
+    if tree.nodes.get("add_embedding"):
+        logging.info("Add embedding found...")
+        tree['add_embedding'].convert_name = 'label_emb.0'
+        dup_convert_name_recursive(tree['add_embedding.linear_1'], '0')
+        dup_convert_name_recursive(tree['add_embedding.linear_2'], '2')
+    if tree.nodes.get("time_embedding"):
+        logging.info("Time embedding found...")
+        tree['time_embedding'].convert_name = 'time_embed'
+        dup_convert_name_recursive(tree['time_embedding.linear_1'], '0')
+        dup_convert_name_recursive(tree['time_embedding.linear_2'], '2')
+
+
+def convert_middle_keys(tree: SegTree):
+    '''middle block is fixed (resnet -> attention -> resnet)'''
+    mid = tree['mid_block']
+    resnets = mid['resnets']
+    attns = mid['attentions']
+    mid.convert_name = 'middle_block'
+    resnets['0'].convert_name = '0'
+    resnets['1'].convert_name = '2'
+    attns['0'].convert_name = '1'
+    map_resnet_block(resnets['0'])
+    map_resnet_block(resnets['1'])
+    map_attention_block(attns['0'])
+
+
+def convert_output_keys(hf_tree: SegTree):
+    '''output keys are mapped similarly to input keys'''
+    nemo_inp_blk = 0
+    up_blocks = hf_tree['up_blocks']
+    up_blocks_keys = sorted(list(up_blocks.nodes.keys()), key=int)
+
+    for downblockid in up_blocks_keys:
+        block = up_blocks[str(downblockid)]
+        # compute number of resnets, attentions, upsamplers in this block
+        resnets = block.nodes.get('resnets', SegTree())
+        attentions = block.nodes.get('attentions', SegTree())
+        upsamplers = block.nodes.get('upsamplers', SegTree())
+
+        if len(attentions) == 0:  # no attentions, this is an UpBlock2D
+            for resid in sorted(list(resnets.nodes.keys()), key=int):
+                resid = str(resid)
+                resnets[resid].convert_name = f"output_blocks.{nemo_inp_blk}.0"
+                map_resnet_block(resnets[resid])
+                nemo_inp_blk += 1
+
+        elif len(attentions) == len(resnets):
+            # there are attention blocks here -- each resnet+attention becomes a block
+            for resid in sorted(list(resnets.nodes.keys()), key=int):
+                resid = str(resid)
+                resnets[resid].convert_name = f"output_blocks.{nemo_inp_blk}.0"
+                map_resnet_block(resnets[resid])
+                attentions[resid].convert_name = f"output_blocks.{nemo_inp_blk}.1"
+                map_attention_block(attentions[resid])
+                nemo_inp_blk += 1
+        else:
+            logging.warning("number of attention blocks is not the same as resnets - whats going on?")
+
+        # if there is a upsampler, then also append it
+        if len(upsamplers) > 0:
+            nemo_inp_blk -= 1
+            upsamplenum = (
+                1 if len(attentions) == 0 else 2
+            )  # if there are attention modules, the upsampler is module 2, else it is module 1 (to stay consistent with SD)
+            upsamplers['0'].convert_name = f"output_blocks.{nemo_inp_blk}.{upsamplenum}"
+            dup_convert_name_recursive(upsamplers['0.conv'], 'conv')
+            nemo_inp_blk += 1
+
+
+def convert_finalout_keys(hf_tree: SegTree):
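+    """Map the final normalization and output convolution to NeMo's out.0 and out.1 modules."""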
+    dup_convert_name_recursive(hf_tree['conv_norm_out'], "out.0")
+    dup_convert_name_recursive(hf_tree['conv_out'], "out.1")
+
+
+def convert_encoder(hf_tree: SegTree):
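+    """Map the VAE encoder: input/output convs, down blocks with optional downsamplers, and the fixed mid block (two resnets plus attention)."""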
+    encoder = hf_tree['encoder']
+    encoder.convert_name = 'encoder'
+    dup_convert_name_recursive(encoder['conv_in'], 'conv_in')
+    dup_convert_name_recursive(encoder['conv_out'], 'conv_out')
+    dup_convert_name_recursive(encoder['conv_norm_out'], 'norm_out')
+
+    # each block contains resnets and downsamplers
+    # there are also optional attention blocks in the down module, but I haven't encountered them yet
+    encoder['down_blocks'].convert_name = 'down'
+    for downid, downblock in encoder['down_blocks'].nodes.items():
+        downblock.convert_name = downid
+        downsamplers = downblock.nodes.get('downsamplers', SegTree())
+        dup_convert_name_recursive(downblock['resnets'], 'block')
+        # check for conv_shortcuts here
+        for resid, resnet in downblock['resnets'].nodes.items():
+            if resnet.nodes.get('conv_shortcut') is not None:
+                resnet.nodes['conv_shortcut'].convert_name = 'nin_shortcut'
+        if len(downsamplers) > 0:
+            dup_convert_name_recursive(downsamplers['0'], 'downsample')
+
+    # map the `mid_block` (NeMo's mid layer is hardcoded in terms of number of modules)
+    encoder['mid_block'].convert_name = 'mid'
+    dup_convert_name_recursive(encoder[f'mid_block.resnets.0'], 'block_1')
+    dup_convert_name_recursive(encoder[f'mid_block.resnets.1'], 'block_2')
+
+    # attention part
+    att = encoder['mid_block.attentions.0']
+    att.convert_name = 'attn_1'
+    dup_convert_name_recursive(att['group_norm'], 'norm')
+    dup_convert_name_recursive(att['to_k'], 'k')
+    dup_convert_name_recursive(att['to_q'], 'q')
+    dup_convert_name_recursive(att['to_v'], 'v')
+    dup_convert_name_recursive(att['to_out.0'], 'proj_out')
+
+
+def convert_decoder(hf_tree: SegTree):
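+    """Map the VAE decoder: input/output convs, the fixed mid block, and up blocks whose indices are reversed relative to HF."""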
+    decoder = hf_tree['decoder']
+    decoder.convert_name = 'decoder'
+    dup_convert_name_recursive(decoder['conv_in'], 'conv_in')
+    dup_convert_name_recursive(decoder['conv_out'], 'conv_out')
+    dup_convert_name_recursive(decoder['conv_norm_out'], 'norm_out')
+    # each block contains resnets and downsamplers
+    # map the `mid_block` (NeMo's mid layer is hardcoded in terms of number of modules)
+    decoder['mid_block'].convert_name = 'mid'
+    dup_convert_name_recursive(decoder[f'mid_block.resnets.0'], 'block_1')
+    dup_convert_name_recursive(decoder[f'mid_block.resnets.1'], 'block_2')
+    # attention blocks
+    att = decoder['mid_block.attentions.0']
+    att.convert_name = 'attn_1'
+    dup_convert_name_recursive(att['group_norm'], 'norm')
+    dup_convert_name_recursive(att['to_k'], 'k')
+    dup_convert_name_recursive(att['to_q'], 'q')
+    dup_convert_name_recursive(att['to_v'], 'v')
+    dup_convert_name_recursive(att['to_out.0'], 'proj_out')
+
+    # up blocks contain resnets and upsamplers
+    decoder['up_blocks'].convert_name = 'up'
+    num_up_blocks = len(decoder['up_blocks'])
+    for upid, upblock in decoder['up_blocks'].nodes.items():
+        upblock.convert_name = str(num_up_blocks - 1 - int(upid))
+        upsamplers = upblock.nodes.get('upsamplers', SegTree())
+        dup_convert_name_recursive(upblock['resnets'], 'block')
+        # check for conv_shortcuts here
+        for resid, resnet in upblock['resnets'].nodes.items():
+            if resnet.nodes.get('conv_shortcut') is not None:
+                resnet.nodes['conv_shortcut'].convert_name = 'nin_shortcut'
+        if len(upsamplers) > 0:
+            dup_convert_name_recursive(upsamplers['0'], 'upsample')
+
+
+def convert(args):
+    logging.info(f"loading checkpoint {args.input_name_or_path}")
+    _, hf_ckpt = load_hf_ckpt(args.input_name_or_path, args)
+    hf_tree = model_to_tree(hf_ckpt)
+
+    if args.model == 'unet':
+        logging.info("converting unet...")
+        convert_input_keys(hf_tree)
+        convert_cond_keys(hf_tree)
+        convert_middle_keys(hf_tree)
+        convert_output_keys(hf_tree)
+        convert_finalout_keys(hf_tree)
+        # get mapping
+
+    elif args.model == 'vae':
+        logging.info("converting vae...")
+        dup_convert_name_recursive(hf_tree['quant_conv'], 'quant_conv')
+        dup_convert_name_recursive(hf_tree['post_quant_conv'], 'post_quant_conv')
+        convert_encoder(hf_tree)
+        convert_decoder(hf_tree)
+
+    else:
+        logging.error("incorrect model specification.")
+        return
+
+    # check mapping
+    mapping = hf_to_nemo_mapping(hf_tree)
+    if len(mapping) != len(hf_ckpt.keys()):
+        logging.warning("not all keys are matched properly.")
+    nemo_ckpt = {}
+
+    for hf_key, nemo_key in mapping.items():
+        nemo_ckpt[nemo_key] = hf_ckpt[hf_key]
+    # save this
+    torch.save(nemo_ckpt, args.output_path)
+    logging.info(f"Saved nemo file to {args.output_path}")
+
+
+if __name__ == '__main__':
+    args = get_args()
+    convert(args)
diff --git a/scripts/deploy/multimodal/deploy_triton.py b/scripts/deploy/multimodal/deploy_triton.py
index 1e339b3405cf..d0bf8f10548a 100755
--- a/scripts/deploy/multimodal/deploy_triton.py
+++ b/scripts/deploy/multimodal/deploy_triton.py
@@ -48,8 +48,8 @@ def get_args(argv):
         "--model_type",
         type=str,
         required=True,
-        choices=["neva", "video-neva"],
-        help="Type of the model. neva and video-neva are only supported.",
+        choices=["neva", "video-neva", "lita", "vila", "vita"],
+        help="Type of the model that is supported.",
     )
     parser.add_argument(
         "-lmt",
@@ -82,8 +82,15 @@ def get_args(argv):
     )
     parser.add_argument("-mil", "--max_input_len", default=4096, type=int, help="Max input length of the model")
     parser.add_argument("-mol", "--max_output_len", default=256, type=int, help="Max output length of the model")
-    parser.add_argument("-mbs", "--max_batch_size", default=1, type=int, help="Max batch size of the model")
+    parser.add_argument("-mbs", "--max_batch_size", default=1, type=int, help="Max batch size of the llm model")
     parser.add_argument("-mml", "--max_multimodal_len", default=3072, type=int, help="Max length of multimodal input")
+    parser.add_argument(
+        "-vmb",
+        "--vision_max_batch_size",
+        default=1,
+        type=int,
+        help="Max batch size of the visual inputs, for lita/vita model with video inference, this should be set to 256",
+    )
     args = parser.parse_args(argv)
     return args
 
@@ -131,6 +138,7 @@ def get_trt_deployable(args):
                 tensor_parallel_size=args.num_gpus,
                 max_input_len=args.max_input_len,
                 max_output_len=args.max_output_len,
+                vision_max_batch_size=args.vision_max_batch_size,
                 max_batch_size=args.max_batch_size,
                 max_multimodal_len=args.max_multimodal_len,
                 dtype=args.dtype,
diff --git a/scripts/deploy/nlp/deploy_inframework_triton.py b/scripts/deploy/nlp/deploy_inframework_triton.py
new file mode 100755
index 000000000000..b698e4cbacfd
--- /dev/null
+++ b/scripts/deploy/nlp/deploy_inframework_triton.py
@@ -0,0 +1,103 @@
+# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import logging
+import sys
+
+from nemo.deploy import DeployPyTriton
+
+LOGGER = logging.getLogger("NeMo")
+
+megatron_llm_supported = True
+try:
+    from nemo.deploy.nlp import MegatronLLMDeployable
+except Exception as e:
+    LOGGER.warning(f"Cannot import MegatronLLMDeployable, it will not be available. {type(e).__name__}: {e}")
+    megatron_llm_supported = False
+
+
+def get_args(argv):
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+        description=f"Deploy nemo models to Triton",
+    )
+    parser.add_argument("-nc", "--nemo_checkpoint", type=str, help="Source .nemo file")
+    parser.add_argument("-tmn", "--triton_model_name", required=True, type=str, help="Name for the service")
+    parser.add_argument("-tmv", "--triton_model_version", default=1, type=int, help="Version for the service")
+    parser.add_argument(
+        "-trp", "--triton_port", default=8000, type=int, help="Port for the Triton server to listen for requests"
+    )
+    parser.add_argument(
+        "-tha", "--triton_http_address", default="0.0.0.0", type=str, help="HTTP address for the Triton server"
+    )
+    parser.add_argument("-ng", "--num_gpus", default=1, type=int, help="Number of GPUs for the deployment")
+    parser.add_argument("-mbs", "--max_batch_size", default=8, type=int, help="Max batch size of the model")
+    parser.add_argument("-dm", "--debug_mode", default=False, action='store_true', help="Enable debug mode")
+    args = parser.parse_args(argv)
+    return args
+
+
+def get_nemo_deployable(args):
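+    """Create an in-framework MegatronLLMDeployable from the given .nemo checkpoint."""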
+    if args.nemo_checkpoint is None:
+        raise ValueError("In-Framework deployment requires a .nemo checkpoint")
+
+    return MegatronLLMDeployable(args.nemo_checkpoint, args.num_gpus)
+
+
+def nemo_deploy(argv):
+    args = get_args(argv)
+
+    if args.debug_mode:
+        loglevel = logging.DEBUG
+    else:
+        loglevel = logging.INFO
+
+    LOGGER.setLevel(loglevel)
+    LOGGER.info("Logging level set to {}".format(loglevel))
+    LOGGER.info(args)
+
+    if not megatron_llm_supported:
+        raise ValueError("MegatronLLMDeployable is not supported in this environment.")
+    triton_deployable = get_nemo_deployable(args)
+
+    try:
+        nm = DeployPyTriton(
+            model=triton_deployable,
+            triton_model_name=args.triton_model_name,
+            triton_model_version=args.triton_model_version,
+            max_batch_size=args.max_batch_size,
+            port=args.triton_port,
+            address=args.triton_http_address,
+        )
+
+        LOGGER.info("Triton deploy function will be called.")
+        nm.deploy()
+    except Exception as error:
+        LOGGER.error("Error message has occurred during deploy function. Error message: " + str(error))
+        return
+
+    try:
+        LOGGER.info("Model serving on Triton is will be started.")
+        nm.serve()
+    except Exception as error:
+        LOGGER.error("Error message has occurred during deploy function. Error message: " + str(error))
+        return
+
+    LOGGER.info("Model serving will be stopped.")
+    nm.stop()
+
+
+if __name__ == '__main__':
+    nemo_deploy(sys.argv[1:])
diff --git a/scripts/deploy/nlp/deploy_triton.py b/scripts/deploy/nlp/deploy_triton.py
index a306231bcd61..01be9ff63a0d 100755
--- a/scripts/deploy/nlp/deploy_triton.py
+++ b/scripts/deploy/nlp/deploy_triton.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import argparse
+import json
 import logging
 import os
 import sys
@@ -73,10 +74,13 @@ def get_args(argv):
     parser.add_argument(
         "-tha", "--triton_http_address", default="0.0.0.0", type=str, help="HTTP address for the Triton server"
     )
+    parser.add_argument(
+        "-trt", "--triton_request_timeout", default=60, type=int, help="Timeout in seconds for Triton server"
+    )
     parser.add_argument(
         "-tmr", "--triton_model_repository", default=None, type=str, help="Folder for the trt-llm conversion"
     )
-    parser.add_argument("-ng", "--num_gpus", default=1, type=int, help="Number of GPUs for the deployment")
+    parser.add_argument("-ng", "--num_gpus", default=None, type=int, help="Number of GPUs for the deployment")
     parser.add_argument("-tps", "--tensor_parallelism_size", default=1, type=int, help="Tensor parallelism size")
     parser.add_argument("-pps", "--pipeline_parallelism_size", default=1, type=int, help="Pipeline parallelism size")
     parser.add_argument(
@@ -91,7 +95,13 @@ def get_args(argv):
     parser.add_argument("-mol", "--max_output_len", default=256, type=int, help="Max output length of the model")
     parser.add_argument("-mbs", "--max_batch_size", default=8, type=int, help="Max batch size of the model")
     parser.add_argument("-mnt", "--max_num_tokens", default=None, type=int, help="Max number of tokens")
+    parser.add_argument("-msl", "--max_seq_len", default=None, type=int, help="Maximum number of sequence length")
+    parser.add_argument("-mp", "--multiple_profiles", default=False, action='store_true', help="Multiple profiles")
     parser.add_argument("-ont", "--opt_num_tokens", default=None, type=int, help="Optimum number of tokens")
+    parser.add_argument(
+        "-gap", "--gpt_attention_plugin", default="auto", type=str, help="dtype of gpt attention plugin"
+    )
+    parser.add_argument("-gp", "--gemm_plugin", default="auto", type=str, help="dtype of gpt plugin")
     parser.add_argument(
         "-mpet", "--max_prompt_embedding_table_size", default=None, type=int, help="Max prompt embedding table size"
     )
@@ -183,11 +193,33 @@ def get_args(argv):
         "-sha", "--service_http_address", default="0.0.0.0", type=str, help="HTTP address for the REST Service"
     )
     parser.add_argument("-sp", "--service_port", default=8080, type=int, help="Port for the REST Service")
+    parser.add_argument(
+        "-ofr",
+        "--openai_format_response",
+        default=False,
+        action='store_true',
+        help="Return the response from PyTriton server in OpenAI compatible format",
+    )
     parser.add_argument("-dm", "--debug_mode", default=False, action='store_true', help="Enable debug mode")
     args = parser.parse_args(argv)
     return args
 
 
+def store_args_to_json(args):
+    """
+    Stores user defined arg values relevant for REST API in config.json
+    Gets called only when args.start_rest_service is True.
+    """
+    args_dict = {
+        "triton_service_ip": args.triton_http_address,
+        "triton_service_port": args.triton_port,
+        "triton_request_timeout": args.triton_request_timeout,
+        "openai_format_response": args.openai_format_response,
+    }
+    with open("nemo/deploy/service/config.json", "w") as f:
+        json.dump(args_dict, f)
+
+
 def get_trtllm_deployable(args):
     if args.triton_model_repository is None:
         trt_llm_path = "/tmp/trt_llm_model_dir/"
@@ -237,11 +269,6 @@ def get_trtllm_deployable(args):
                     "There are {0} tables and {1} task ids.".format(len(ptuning_tables_files), len(args.task_ids))
                 )
 
-    if args.start_rest_service:
-        if args.service_port == args.triton_port:
-            logging.error("REST service port and Triton server port cannot use the same port.")
-            return
-
     trt_llm_exporter = TensorRTLLM(
         model_dir=trt_llm_path,
         lora_ckpt_list=args.lora_ckpt,
@@ -263,6 +290,7 @@ def get_trtllm_deployable(args):
                 max_batch_size=args.max_batch_size,
                 max_num_tokens=args.max_num_tokens,
                 opt_num_tokens=args.opt_num_tokens,
+                max_seq_len=args.max_seq_len,
                 use_parallel_embedding=args.use_parallel_embedding,
                 max_prompt_embedding_table_size=args.max_prompt_embedding_table_size,
                 paged_kv_cache=(not args.no_paged_kv_cache),
@@ -272,6 +300,9 @@ def get_trtllm_deployable(args):
                 use_lora_plugin=args.use_lora_plugin,
                 lora_target_modules=args.lora_target_modules,
                 max_lora_rank=args.max_lora_rank,
+                multiple_profiles=args.multiple_profiles,
+                gpt_attention_plugin=args.gpt_attention_plugin,
+                gemm_plugin=args.gemm_plugin,
             )
         except Exception as error:
             raise RuntimeError("An error has occurred during the model export. Error message: " + str(error))
@@ -318,6 +349,13 @@ def nemo_deploy(argv):
     LOGGER.info("Logging level set to {}".format(loglevel))
     LOGGER.info(args)
 
+    if args.start_rest_service:
+        if args.service_port == args.triton_port:
+            logging.error("REST service port and Triton server port cannot use the same port.")
+            return
+        # Store triton ip, port and other args relevant for REST API in config.json to be accessible by rest_model_api.py
+        store_args_to_json(args)
+
     backend = args.backend.lower()
     if backend == 'tensorrt-llm':
         if not trt_llm_supported:
diff --git a/scripts/deploy/nlp/query_inframework.py b/scripts/deploy/nlp/query_inframework.py
new file mode 100644
index 000000000000..e77ab72a1f04
--- /dev/null
+++ b/scripts/deploy/nlp/query_inframework.py
@@ -0,0 +1,83 @@
+# Copyright (c) 2023, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import sys
+
+from nemo.deploy.nlp.query_llm import NemoQueryLLMPyTorch
+
+
+def get_args(argv):
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+        description=f"Queries Triton server running an in-framework Nemo model",
+    )
+    parser.add_argument("-u", "--url", default="0.0.0.0", type=str, help="url for the triton server")
+    parser.add_argument("-mn", "--model_name", required=True, type=str, help="Name of the triton model")
+    prompt_group = parser.add_mutually_exclusive_group(required=True)
+    prompt_group.add_argument("-p", "--prompt", required=False, type=str, help="Prompt")
+    prompt_group.add_argument("-pf", "--prompt_file", required=False, type=str, help="File to read the prompt from")
+    parser.add_argument("-mol", "--max_output_len", default=128, type=int, help="Max output token length")
+    parser.add_argument("-tk", "--top_k", default=1, type=int, help="top_k")
+    parser.add_argument("-tpp", "--top_p", default=0.0, type=float, help="top_p")
+    parser.add_argument("-t", "--temperature", default=1.0, type=float, help="temperature")
+    parser.add_argument("-it", "--init_timeout", default=60.0, type=float, help="init timeout for the triton server")
+
+    args = parser.parse_args(argv)
+    return args
+
+
+def query_llm(
+    url,
+    model_name,
+    prompts,
+    max_output_len=128,
+    top_k=1,
+    top_p=0.0,
+    temperature=1.0,
+    init_timeout=60.0,
+):
+    nemo_query = NemoQueryLLMPyTorch(url, model_name)
+    return nemo_query.query_llm(
+        prompts=prompts,
+        max_length=max_output_len,
+        top_k=top_k,
+        top_p=top_p,
+        temperature=temperature,
+        init_timeout=init_timeout,
+    )
+
+
+def query(argv):
+    args = get_args(argv)
+
+    if args.prompt_file is not None:
+        with open(args.prompt_file, "r") as f:
+            args.prompt = f.read()
+
+    outputs = query_llm(
+        url=args.url,
+        model_name=args.model_name,
+        prompts=[args.prompt],
+        max_output_len=args.max_output_len,
+        top_k=args.top_k,
+        top_p=args.top_p,
+        temperature=args.temperature,
+        init_timeout=args.init_timeout,
+    )
+    print(outputs["sentences"][0][0])
+
+
+if __name__ == '__main__':
+    query(sys.argv[1:])
diff --git a/scripts/installers/install_k2.sh b/scripts/installers/install_k2.sh
index 18d948209ab8..6de80ecae3eb 100755
--- a/scripts/installers/install_k2.sh
+++ b/scripts/installers/install_k2.sh
@@ -15,7 +15,7 @@
 # limitations under the License.
 
 K2_REPO=https://github.com/k2-fsa/k2
-LATEST_RELEASE=525cfa5 # fix for PyTorch 2.2.0
+LATEST_RELEASE=5735fa7 # fix for PyTorch 2.4.0
 # uncomment the following line after the next k2 version is released (>1.24.4)
 #LATEST_RELEASE=$(git -c 'versionsort.suffix=-' \
 #    ls-remote --exit-code --refs --sort='version:refname' --tags ${K2_REPO} '*.*' \
diff --git a/scripts/installers/install_torchaudio_latest.sh b/scripts/installers/install_torchaudio_latest.sh
index 9e72be5e51d6..bdad771fe267 100755
--- a/scripts/installers/install_torchaudio_latest.sh
+++ b/scripts/installers/install_torchaudio_latest.sh
@@ -92,10 +92,12 @@ echo "Installing torchaudio from branch: ${INSTALL_BRANCH}"
 pip install parameterized
 
 # Build torchaudio and run MFCC test
+# NB: setting PYTORCH_VERSION explicitly is a workaround for the case where PYTORCH_VERSION is already set
+# but contains an incorrect value, e.g., in the container nvcr.io/nvidia/pytorch:24.03-py3
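+# An illustrative way to compare the two values inside the container:
+#   echo "PYTORCH_VERSION=${PYTORCH_VERSION}"; python -c "import torch; print(torch.__version__)"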
 git clone --depth 1 --branch ${INSTALL_BRANCH} https://github.com/pytorch/audio.git && \
 cd audio && \
 git submodule update --init --recursive && \
-USE_FFMPEG=1 BUILD_SOX=1 BUILD_VERSION=${TORCHAUDIO_BUILD_VERSION} python setup.py install && \
+PYTORCH_VERSION=${TORCH_FULL_VERSION} USE_FFMPEG=1 BUILD_SOX=1 BUILD_VERSION=${TORCHAUDIO_BUILD_VERSION} python setup.py install && \
 cd .. && \
 pytest -rs audio/test/torchaudio_unittest/transforms/torchscript_consistency_cpu_test.py -k 'test_MFCC' || \
 { echo "ERROR: Failed to install torchaudio!"; exit 1; };
diff --git a/scripts/multimodal_dataset_conversion/convert_dvc_dataset_for_evaluation.py b/scripts/multimodal_dataset_conversion/convert_dvc_dataset_for_evaluation.py
new file mode 100644
index 000000000000..1427e0983b24
--- /dev/null
+++ b/scripts/multimodal_dataset_conversion/convert_dvc_dataset_for_evaluation.py
@@ -0,0 +1,160 @@
+# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+"""
+This script converts a DVC dataset to the format required for model evaluation on the RTL task.
+The DVC dataset should have the below structure:
+{
+    "-4RXOT_UfpM_3": {          # video_name is the unique video file name, extention is .mp4
+        "duration": 118.01801801801803,
+        "timestamps": [
+            [5, 58], 
+            [66, 82],
+            [82, 96]
+        ],
+        "sentences": [
+            "Apply eyeshadow on the lower area then crease with brush",
+            "Apply eyeshadow on the outer corner of eyes with brush",
+            "Apply eyeshadow on the outer half of eyes with brush",
+        ]
+    },
+    ...
+}
+
+The converted format will be as follows:
+[
+    {
+        "video": "-4RXOT_UfpM_3.mp4",
+        "question_id": "-4RXOT_UfpM_3_0",
+        "question": "When does \"Apply eyeshadow on the lower area then crease with brush\" happen in the video? Provide a response using only start and end timestamps.",
+        "ref_answer": "<5> <58> Apply eyeshadow on the lower area then crease with brush",
+        "duration": 118.01801801801803
+    },
+    {
+        "video": "-4RXOT_UfpM_3.mp4",
+        "question_id": "-4RXOT_UfpM_3_1",
+        "question": "When is \"Apply eyeshadow on the outer corner of eyes with brush\" depicted in the video? Convey your answer using start and end timestamps exclusively.",
+        "ref_answer": "<66> <82> Apply eyeshadow on the outer corner of eyes with brush",
+        "duration": 118.01801801801803
+    },
+    {
+        "video": "-4RXOT_UfpM_3.mp4",
+        "question_id": "-4RXOT_UfpM_3_2",
+        "question": "When does \"Apply eyeshadow on the outer half of eyes with brush\" happen in the video? Provide a response using only start and end timestamps.",
+        "ref_answer": "<82> <96> Apply eyeshadow on the outer half of eyes with brush",
+        "duration": 118.01801801801803
+    },
+    .....
+]
+
+For each sentence in the sentences list, one question is generated, and the reference answer is the sentence itself prefixed with its start and end timestamps.
+USAGE:
+python convert_dvc_dataset_for_evaluation.py --input <input_dvc_json> --output_file <output_json> --ratio <sampling_ratio>
+
+"""
+
+import argparse
+import json
+import os
+import random
+
+
+class RTLConverter:
+    def __init__(self, input_file, output_file, sample_ratio, ext):
+        self.input_file = input_file
+        self.output_file = output_file
+        self.sample_ratio = sample_ratio
+        self.desc_prompts = [
+            "When does \"%s\" happen in the video?",
+            "At what point in the video does \"%s\" happen?",
+            "When is \"%s\" depicted in the video?",
+            "At what time in the video does \"%s\" take place?",
+        ]
+        self.time_prompts = [
+            "Answer the question only using start and end timestamps.",
+            "Provide a response using only start and end timestamps.",
+            "Convey your answer using start and end timestamps exclusively.",
+        ]
+        self.ext = ext
+
+    def convert(self):
+        converted_data = []
+
+        # Load JSON data
+        with open(self.input_file, 'r') as file:
+            data = json.load(file)
+
+        # Fix random seed for reproducibility
+        random.seed(42)
+
+        # Randomly sample entries based on the sample ratio
+        vid_list = list(data.keys())
+        sampled_vids = random.sample(vid_list, k=int(len(vid_list) * self.sample_ratio))
+
+        # Iterate through sampled entries
+        for vid in sampled_vids:
+            details = data[vid]
+            duration = details['duration']
+            timestamps = details['timestamps']
+            sentences = details['sentences']
+
+            # Iterate through sentences
+            for i, sentence in enumerate(sentences):
+                question_id = f"{vid}_{i}"
+                desc_prompt = random.choice(self.desc_prompts)
+                time_prompt = random.choice(self.time_prompts)
+                start_time, end_time = timestamps[i]
+                answer = f"<{start_time}> <{end_time}> {sentence}"
+
+                # Construct question
+                question = (desc_prompt % sentence) + ' ' + time_prompt
+
+                # Create entry in converted data
+                converted_data.append(
+                    {
+                        "video": vid + self.ext,
+                        "question_id": question_id,
+                        "question": question,
+                        "ref_answer": answer,
+                        "duration": duration,
+                    }
+                )
+
+        # Ensure the output directory exists (skip when the output path has no directory component)
+        output_dir = os.path.dirname(self.output_file)
+        if output_dir:
+            os.makedirs(output_dir, exist_ok=True)
+
+        # Write converted data to output file
+        with open(self.output_file, 'w') as file:
+            json.dump(converted_data, file, indent=2)
+
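+# A quick, illustrative sanity check of the converted output (assumes the default output
+# file name, rtl_eval.json):
+#   python -c "import json; d = json.load(open('rtl_eval.json')); print(len(d), d[0]['question_id'])"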
+
+def main():
+    parser = argparse.ArgumentParser(description="Convert makeup QA JSON format")
+    parser.add_argument("--input", help="Input DVC JSON file", required=True)
+    parser.add_argument("--output_file", help="Output file", default="rtl_eval.json", required=True)
+    parser.add_argument("--ratio", help="Sampling ratio between 0 and 1", type=float, default=1.0, required=False)
+    parser.add_argument("--ext", help="Extension of the video files", default=".mp4", required=False)
+    args = parser.parse_args()
+
+    if args.ratio < 0 or args.ratio > 1:
+        raise ValueError("Sampling ratio must be between 0 and 1")
+
+    converter = RTLConverter(args.input, args.output_file, args.ratio, args.ext)
+    converter.convert()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/multimodal_dataset_conversion/convert_dvc_dataset_for_training.py b/scripts/multimodal_dataset_conversion/convert_dvc_dataset_for_training.py
new file mode 100644
index 000000000000..4aa366bc4007
--- /dev/null
+++ b/scripts/multimodal_dataset_conversion/convert_dvc_dataset_for_training.py
@@ -0,0 +1,323 @@
+# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+
+"""
+This script converts a DVC dataset to the format required by the model training script.
+The DVC dataset should have the below structure:
+{
+    "1043215450": {          # video_name is the unique video file name (the extension should be .mp4)
+        "duration": 125.0,
+        "timestamps": [
+            [0, 5], 
+            [3, 9]
+        ],
+        "sentences": [                  # For custom caption or event localization task
+            "Here is your caption 1",
+            "Here is your caption 2",
+        ],
+        "events": [                   # For custom event task
+            "Event 1",
+            "Event 2",
+        ]
+    },
+    ...
+}
+
+The converted dataset format is as follows:
+[
+    # 1st example: dense video captioning  (custom event or custom caption task)
+    {
+        "id": "xxxx",
+        "video: "xxxx.mp4",
+        "conversations":
+        [
+            {"from": "human", "value": "