diff --git a/.github/workflows/_test_template.yml b/.github/workflows/_test_template.yml index 3a1f69243c39..1e184a8d4160 100644 --- a/.github/workflows/_test_template.yml +++ b/.github/workflows/_test_template.yml @@ -60,7 +60,16 @@ jobs: ARG=("--runtime=nvidia --gpus all") fi - docker run --rm -d --name nemo_container_${{ github.run_id }} ${ARG[@]} --shm-size=64g --env TRANSFORMERS_OFFLINE=0 --env HYDRA_FULL_ERROR=1 --volume /mnt/datadrive/TestData:/home/TestData nemoci.azurecr.io/nemo_container:${{ github.run_id }} bash -c "sleep $(( ${{ inputs.TIMEOUT }} * 60 + 60 ))" + docker run \ + --rm \ + -d \ + --name nemo_container_${{ github.run_id }} ${ARG[@]} \ + --shm-size=64g \ + --env TRANSFORMERS_OFFLINE=0 \ + --env HYDRA_FULL_ERROR=1 \ + --env HF_HOME=/home/TestData/HF_HOME \ + --volume /mnt/datadrive/TestData:/home/TestData nemoci.azurecr.io/nemo_container:${{ github.run_id }} \ + bash -c "sleep $(( ${{ inputs.TIMEOUT }} * 60 + 60 ))" - id: main name: Run main script @@ -95,4 +104,4 @@ jobs: if: always() run: | docker container stop nemo_container_${{ github.run_id }} || true - docker container rm nemo_container_${{ github.run_id }} || true \ No newline at end of file + docker container rm nemo_container_${{ github.run_id }} || true diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index e41e96cfa794..a4b2baa59550 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -38,6 +38,7 @@ jobs: outputs: test_to_run: ${{ steps.test_to_run.outputs.main }} all: ${{ steps.all.outputs.main }} + event_name: ${{ steps.github-event.outputs.main }} steps: - name: Parse test_to_run id: test_to_run @@ -47,11 +48,16 @@ jobs: - name: Parse all id: all run: | - echo "main=${{ contains(fromJSON(steps.test_to_run.outputs.main), 'all') }}" | tee -a "$GITHUB_OUTPUT" + echo "main=${{ contains(fromJSON(steps.test_to_run.outputs.main), 'all') }}" | tee -a "$GITHUB_OUTPUT" + - name: Infer github event + id: github-event + run: | + echo "main=${{ github.event_name }}" | tee -a "$GITHUB_OUTPUT" cicd-test-container-build: - if: ${{ github.event.label.name == 'Run CICD' || github.event_name == 'workflow_dispatch' }} - uses: NVIDIA/NeMo-FW-CI-templates/.github/workflows/_build_container.yml@v0.1.0 + if: ${{ github.event.label.name == 'Run CICD' || needs.pre-flight.outputs.event_name == 'workflow_dispatch' }} + uses: NVIDIA/NeMo-FW-CI-templates/.github/workflows/_build_container.yml@v0.14.0 + needs: pre-flight with: image-name: nemo_container dockerfile: Dockerfile.ci @@ -2103,6 +2109,121 @@ jobs: # } # } + L2_Megatron_LM_To_NeMo_Conversion: + needs: [cicd-test-container-setup] + uses: ./.github/workflows/_test_template.yml + if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_Megatron_LM_To_NeMo_Conversion') || needs.cicd-test-container-setup.outputs.all == 'true' + with: + RUNNER: self-hosted-azure + SCRIPT: | + CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=1 Megatron-LM/pretrain_gpt.py \ + --mock-data \ + --distributed-timeout-minutes 60 \ + --use-mcore-models \ + --no-mmap-bin-files \ + --untie-embeddings-and-output-weights \ + --disable-bias-linear \ + --train-samples 80 \ + --init-method-std 0.014 \ + --position-embedding-type rope \ + --rotary-base 1000000 \ + --rotary-percent 1.0 \ + --squared-relu \ + --num-layers 4 \ + --hidden-size 384 \ + --num-attention-heads 8 \ + --group-query-attention \ + --num-query-groups 8 \ + --ffn-hidden-size 1536 \ + --kv-channels 128 \ + --normalization RMSNorm \ + --attention-dropout 0.0 \ + --hidden-dropout 0.0 \ + --exit-duration-in-mins 5750 \ + --tensor-model-parallel-size 1 \ + --pipeline-model-parallel-size 1 \ + --seq-length 8192 \ + --max-position-embeddings 8192 \ + --micro-batch-size 1 \ + --global-batch-size 8 \ + --lr 6e-4 \ + --min-lr 6e-6 \ + --weight-decay 0.1 \ + --clip-grad 1.0 \ + --lr-decay-style cosine \ + --log-interval 1 \ + --eval-iters 1 \ + --eval-interval 10 \ + --tokenizer-type GPT2BPETokenizer \ + --tokenizer-model /home/TestData/nlp/gpt2_tokenizer \ + --vocab-file /home/TestData/nlp/gpt2_tokenizer/vocab.json \ + --merge-file /home/TestData/nlp/gpt2_tokenizer/merges.txt \ + --save /tmp/mlm_conversion_ckpt \ + --save-interval 10 \ + --ckpt-format torch_dist \ + --ckpt-fully-parallel-save \ + --ckpt-fully-parallel-load \ + --async-save \ + --ckpt-assume-constant-structure \ + --timing-log-option minmax \ + --log-params-norm \ + --log-num-zeros-in-grad \ + --log-throughput \ + --bf16 \ + --adam-beta1 0.9 \ + --adam-beta2 0.95 \ + --use-distributed-optimizer \ + --overlap-grad-reduce \ + --overlap-param-gather \ + --manual-gc \ + --num-workers 2 + + python examples/nlp/language_modeling/megatron_gpt_pretraining.py \ + model.data.data_impl=mock \ + model.data.data_prefix=[] \ + model.skip_train=True \ + model.transformer_engine=True \ + model.use_flash_attention=False \ + model.normalization=rmsnorm \ + model.num_layers=4 \ + model.hidden_size=384 \ + model.ffn_hidden_size=1536 \ + model.num_attention_heads=8 \ + model.num_query_groups=8 \ + model.bias=False \ + model.bias_activation_fusion=False \ + model.bias_dropout_add_fusion=True \ + model.masked_softmax_fusion=True \ + model.encoder_seq_length=8192 \ + model.max_position_embeddings=8192 \ + model.data.seq_length=8192 \ + model.activation=squared-relu \ + model.transformer_block_type=True \ + model.micro_batch_size=1 \ + model.global_batch_size=8 \ + ++model.rotary_base=1000000 \ + model.rotary_percentage=1.0 \ + model.apply_query_key_layer_scaling=False \ + ++model.group_query_attention=True \ + model.apply_rope_fusion=True \ + model.kv_channels=128 \ + ++model.bert_binary_head=True \ + ++model.position_embedding_type=rope \ + ++model.add_position_embedding=True \ + trainer.limit_val_batches=1 \ + exp_manager.exp_dir=/tmp/nemo_conversion_ckpt + + python -m torch.distributed.launch --nproc_per_node=1 examples/nlp/language_modeling/megatron_ckpt_to_nemo.py \ + --checkpoint_folder /tmp/mlm_conversion_ckpt \ + --checkpoint_name iter_0000010 \ + --nemo_file_path /tmp/mlm_to_nemo_test.nemo \ + --tensor_model_parallel_size 1 \ + --pipeline_model_parallel_size 1 \ + --gpus_per_node 1 \ + --model_type gpt \ + --hparams_file /tmp/nemo_conversion_ckpt/megatron_gpt/version_0/hparams.yaml \ + --convert_mlm + L2_Megatron_GPT_with_ResetLR_Pretraining_and_Resume_Training_TP2: needs: [cicd-test-container-setup] uses: ./.github/workflows/_test_template.yml @@ -2943,7 +3064,7 @@ jobs: with: RUNNER: self-hosted-azure SCRIPT: | - NVTE_FUSED_ATTN=0 NVTE_FLASH_ATTN=0 python examples/nlp/language_modeling/megatron_t5_pretraining.py \ + python examples/nlp/language_modeling/megatron_t5_pretraining.py \ trainer.devices=2 \ trainer.log_every_n_steps=1 \ trainer.max_epochs=null \ @@ -2975,7 +3096,7 @@ jobs: +model.data.data_impl_kwargs.workers=null \ +model.data.data_impl_kwargs.sort_dataset_paths=False - NVTE_FUSED_ATTN=0 NVTE_FLASH_ATTN=0 python examples/nlp/language_modeling/megatron_t5_pretraining.py \ + python examples/nlp/language_modeling/megatron_t5_pretraining.py \ trainer.devices=2 \ trainer.log_every_n_steps=1 \ trainer.max_epochs=null \ @@ -3398,8 +3519,8 @@ jobs: with: RUNNER: self-hosted-azure SCRIPT: | - NVTE_FLASH_ATTN=0 NVTE_FUSED_ATTN=0 python examples/nlp/language_modeling/megatron_t5_eval.py \ - --model_file /home/TestData/nlp/megatron_t5/220m/megatron_mcore_t5_220m.nemo \ + python examples/nlp/language_modeling/megatron_t5_eval.py \ + --model_file /home/TestData/nlp/megatron_t5/220m/megatron_mcore_t5_220m_padding_attnmasktype.nemo \ --prompt "How do I fix my GPU memory issue? I am seeing out of memory." \ --tensor_model_parallel_size 1 @@ -3410,7 +3531,7 @@ jobs: with: RUNNER: self-hosted-azure SCRIPT: | - NVTE_FUSED_ATTN=0 NVTE_FLASH_ATTN=0 python examples/nlp/language_modeling/tuning/megatron_t5_finetuning.py \ + python examples/nlp/language_modeling/tuning/megatron_t5_finetuning.py \ trainer.devices=2 \ trainer.log_every_n_steps=1 \ trainer.max_epochs=9999 \ @@ -3421,7 +3542,7 @@ jobs: exp_manager.exp_dir=/tmp/nlp_mcore_t5_lora_tuning_tp2 \ model.pipeline_model_parallel_size=1 \ model.tensor_model_parallel_size=2 \ - model.restore_from_path=/home/TestData/nlp/megatron_t5/220m/megatron_mcore_t5_220m.nemo \ + model.restore_from_path=/home/TestData/nlp/megatron_t5/220m/megatron_mcore_t5_220m_padding_attnmasktype.nemo \ model.peft.peft_scheme=lora \ model.answer_only_loss=True \ model.micro_batch_size=1 \ @@ -3433,8 +3554,8 @@ jobs: model.data.validation_ds.file_names=[/home/TestData/nlp/megatron_sft/quarel.jsonl] \ model.data.validation_ds.names=[quarel] - NVTE_FUSED_ATTN=0 NVTE_FLASH_ATTN=0 python examples/nlp/language_modeling/tuning/megatron_t5_generate.py \ - model.restore_from_path=/home/TestData/nlp/megatron_t5/220m/megatron_mcore_t5_220m.nemo \ + python examples/nlp/language_modeling/tuning/megatron_t5_generate.py \ + model.restore_from_path=/home/TestData/nlp/megatron_t5/220m/megatron_mcore_t5_220m_padding_attnmasktype.nemo \ model.peft.restore_from_path=/tmp/nlp_mcore_t5_lora_tuning_tp2/megatron_t5_peft_lora_tuning/checkpoints/megatron_t5_peft_lora_tuning.nemo \ model.peft.restore_from_ckpt_name=null \ model.peft.restore_from_hparams_path=null \ @@ -3451,7 +3572,20 @@ jobs: inference.repetition_penalty=1.0 \ inference.outfile_path=/tmp/nlp_mcore_t5_lora_tuning_tp2/out.jsonl - # L2: Megatron Mock Data Generation + + L2_HF_Transformer_SFT_TE_Acceleration: + needs: [ cicd-test-container-setup ] + uses: ./.github/workflows/_test_template.yml + if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_HF_Transformer_SFT_TE_Acceleration') || needs.cicd-test-container-setup.outputs.all == 'true' + with: + RUNNER: self-hosted-azure + SCRIPT: | + python examples/llm/sft/hf.py --model /home/TestData/nlp/hf_gemma/hf_gemma_2b --model-accelerator te + AFTER_SCRIPT: | + rm -rf nemo_experiments + + + # L2: Megatron Mock Data Generation L2_Megatron_Mock_Data_Generation_MockGPTDataset: needs: [cicd-test-container-setup] uses: ./.github/workflows/_test_template.yml @@ -3572,12 +3706,12 @@ jobs: # timeout-minutes: 10 # container: # image: nemoci.azurecr.io/nemo_container:${{ github.run_id }} - # options: + # options: # # --user 0:128 # --device=/dev/nvidia0 # --gpus all # --shm-size=8g - # --env TRANSFORMERS_OFFLINE=0 + # --env TRANSFORMERS_OFFLINE=0 # --env HYDRA_FULL_ERROR=1 # --volume /mnt/datadrive/TestData:/home/TestData # steps: @@ -3637,12 +3771,12 @@ jobs: # runs-on: self-hosted-azure # container: # image: nemoci.azurecr.io/nemo_container:${{ github.run_id }} - # options: + # options: # # --user 0:128 # --device=/dev/nvidia0 # --gpus all - # --shm-size=8g - # --env TRANSFORMERS_OFFLINE=0 + # --shm-size=8g + # --env TRANSFORMERS_OFFLINE=0 # --env HYDRA_FULL_ERROR=1 # --volume /mnt/datadrive/TestData:/home/TestData # steps: @@ -3852,14 +3986,14 @@ jobs: with: RUNNER: self-hosted-azure SCRIPT: | - NVTE_FUSED_ATTN=0 NVTE_FLASH_ATTN=0 python tests/collections/llm/megatron_t5_pretraining.py \ + python tests/collections/llm/megatron_t5_pretraining.py \ --devices=2 \ --max-steps=3 \ --experiment-dir=tests/collections/llm/t5_pretrain_results/${{ github.run_id }} \ --data-path=/home/TestData/nlp/megatron_t5/data/pile_val_small_bert_tokenizer_text_document \ --index-mapping-dir=tests/collections/llm/t5_index_mappings/${{ github.run_id }} - NVTE_FUSED_ATTN=0 NVTE_FLASH_ATTN=0 python tests/collections/llm/megatron_t5_pretraining.py \ + python tests/collections/llm/megatron_t5_pretraining.py \ --devices=2 \ --max-steps=6 \ --experiment-dir=tests/collections/llm/t5_pretrain_results/${{ github.run_id }} \ @@ -3876,11 +4010,11 @@ jobs: with: RUNNER: self-hosted-azure SCRIPT: | - NVTE_FUSED_ATTN=0 NVTE_FLASH_ATTN=0 python tests/collections/llm/megatron_t5_finetuning.py \ + python tests/collections/llm/megatron_t5_finetuning.py \ --devices=2 \ --max-steps=250 \ --experiment-dir=tests/collections/llm/t5_finetune_results/${{ github.run_id }} \ - --checkpoint-path=/home/TestData/nlp/megatron_t5/220m/nemo2.0_t5_220m_150steps + --checkpoint-path=/home/TestData/nlp/megatron_t5/220m/nemo2.0_t5_220m_padding_attnmasktype_150steps AFTER_SCRIPT: | rm -rf tests/collections/llm/t5_finetune_results/${{ github.run_id }} @@ -3891,12 +4025,12 @@ jobs: with: RUNNER: self-hosted-azure SCRIPT: | - NVTE_FUSED_ATTN=0 NVTE_FLASH_ATTN=0 python tests/collections/llm/megatron_t5_finetuning.py \ + python tests/collections/llm/megatron_t5_finetuning.py \ --devices=2 \ --max-steps=250 \ --peft=lora \ --experiment-dir=tests/collections/llm/t5_peft_results/${{ github.run_id }} \ - --checkpoint-path=/home/TestData/nlp/megatron_t5/220m/nemo2.0_t5_220m_150steps + --checkpoint-path=/home/TestData/nlp/megatron_t5/220m/nemo2.0_t5_220m_padding_attnmasktype_150steps AFTER_SCRIPT: | rm -rf tests/collections/llm/t5_peft_results/${{ github.run_id }} @@ -4199,6 +4333,34 @@ jobs: --pp_size 1 \ --mbs 1 --packed + L2_NeMo_2_GPT_DoRA_TP1PP1_MBS1_PACKED: + needs: [cicd-test-container-setup] + uses: ./.github/workflows/_test_template.yml + if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_NeMo_2_GPT_DoRA_TP1PP1_MBS1_PACKED') || needs.cicd-test-container-setup.outputs.all == 'true' + with: + RUNNER: self-hosted-azure + SCRIPT: | + + python tests/collections/llm/gpt_finetuning.py \ + --restore_path /home/TestData/nemo2_ckpt/llama_68M \ + --devices 2 \ + --max_steps 3 \ + --experiment_dir /tmp/nemo2_gpt_finetune/${{ github.run_id }} \ + --peft dora \ + --tp_size 1 \ + --pp_size 1 \ + --mbs 1 --packed + + python tests/collections/llm/gpt_finetuning.py \ + --restore_path /home/TestData/nemo2_ckpt/llama_68M \ + --devices 2 \ + --max_steps 6 \ + --experiment_dir /tmp/nemo2_gpt_finetune/${{ github.run_id }} \ + --peft dora \ + --tp_size 1 \ + --pp_size 1 \ + --mbs 1 --packed + L2_NeMo_2_Mixtral_LoRA_EP2PP1_MBS2: needs: [cicd-test-container-setup] uses: ./.github/workflows/_test_template.yml @@ -4272,7 +4434,18 @@ jobs: --mbs 1 \ --model mistral \ --dist-opt + + L2_NEMO_2_LoRA_MERGE: + needs: [cicd-test-container-setup] + uses: ./.github/workflows/_test_template.yml + if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_NEMO_2_LoRA_MERGE') || needs.cicd-test-container-setup.outputs.all == 'true' + with: + RUNNER: self-hosted-azure + SCRIPT: | + python tests/collections/llm/peft/lora_merge.py \ + --lora_checkpoint_path=/home/TestData/nemo2_ckpt/llama_lora_ci_checkpoint/ \ + --output_path=/tmp/nemo2_lora_merge/${{ github.run_id }} L2_NeMo_2_NeMo_Mcore_Mixtral_bitexact: needs: [cicd-test-container-setup] @@ -4299,7 +4472,7 @@ jobs: rm -rf /tmp/nemo2_ptq_engine Nemo_CICD_Test: - needs: + needs: - pre-flight - cicd-test-container-setup @@ -4314,7 +4487,7 @@ jobs: - L0_Unit_Tests_GPU_Hydra - L0_Unit_Tests_GPU_Lightning - L0_Unit_Tests_GPU_Others - + - L0_Unit_Tests_CPU_ASR - L0_Unit_Tests_CPU_Audio - L0_Unit_Tests_CPU_Common @@ -4374,6 +4547,7 @@ jobs: - L2_RAG_Pipeline_Generating - L2_Megatron_GPT_Pretraining_and_Resume_Training_TP2 - L2_Megatron_GPT_Skip_Train + - L2_Megatron_LM_To_NeMo_Conversion - L2_Megatron_GPT_with_Rope_Pretraining_and_Resume_Training_TP2 - L2_Megatron_GPT_with_ResetLR_Pretraining_and_Resume_Training_TP2 - L2_Megatron_GPT_with_Drop_Optimizer_States_TP2 @@ -4412,7 +4586,8 @@ jobs: - L2_NeMo_2_GPT_Pretraining_no_transformer_engine - L2_NeMo_2_GPT_DDP_Param_Parity_check - L2_NeMo_2_HF_MODEL_IMPORT - - L2_NeMo_2_llama3_pretraining_recipe + - L2_NeMo_2_llama3_pretraining_recipe + - L2_HF_Transformer_SFT_TE_Acceleration - L2_NeMo_2_SSM_Pretraining - L2_NeMo_2_SSM_Finetuning - L2_NeMo_2_T5_Pretraining @@ -4428,11 +4603,13 @@ jobs: - L2_NeMo_2_GPT_LoRA_TP1PP2_MBS2 - L2_NeMo_2_GPT_LoRA_TP2PP1_MBS2 - L2_NeMo_2_GPT_LoRA_TP1PP1_MBS1_PACKED + - L2_NeMo_2_GPT_DoRA_TP1PP1_MBS1_PACKED - L2_NeMo_2_Mixtral_LoRA_EP2PP1_MBS2 - L2_NeMo_2_Mixtral_LoRA_TP1PP1_MBS1 - L2_NeMo_2_Mixtral_LoRA_TP2PP1_MBS1 - L2_NeMo_2_Mistral_LoRA_TP1PP1_MBS1 - L2_NeMo_2_Mistral_LoRA_TP2PP1_MBS1 + - L2_NEMO_2_LoRA_MERGE - L2_NeMo_2_Mixtral_Pretraining - L2_PTQ_Llama2_FP8 - L2_Community_LLM_Checkpoints_tests_Llama3 @@ -4450,7 +4627,7 @@ jobs: - L2_NeMo_2_PTQ_Llama2_FP8 if: always() runs-on: ubuntu-latest - steps: + steps: - name: Evaluate conclusion if: ${{ always() }} id: pipeline-conclusion @@ -4464,14 +4641,14 @@ jobs: echo "SUCCESS=$SUCCESS" >> $GITHUB_OUTPUT # This should depend on all the tests so we block/unblock based on all tests passing - - name: Pipeline successful, set exit code to 0 + - name: Pipeline successful, set exit code to 0 if: ${{ always() && steps.pipeline-conclusion.outputs.SUCCESS == 'true' }} run: exit 0 - - name: Pipeline successful, add PR comment + - name: Pipeline successful, add PR comment if: ${{ always() && steps.pipeline-conclusion.outputs.SUCCESS == 'true' && github.event_name == 'pull_request' && env.SLACK_WEBHOOK != '' }} uses: peter-evans/create-or-update-comment@v4 - env: + env: SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK }} REPOSITORY: ${{ github.repository }} RUN_ID: ${{ github.run_id }} @@ -4490,7 +4667,7 @@ jobs: - name: "Pipeline not successful and not cancelled: Send Slack alert & create step summary" if: ${{ always() && steps.pipeline-conclusion.outputs.FAILED == 'true' && env.SLACK_WEBHOOK != '' }} - env: + env: SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK }} SLACK_WEBHOOK_ADMIN: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} @@ -4583,4 +4760,4 @@ jobs: - name: "Pipeline not successful, set exit code to 1" if: ${{ always() && steps.pipeline-conclusion.outputs.SUCCESS == 'false' }} - run: exit 1 + run: exit 1 \ No newline at end of file diff --git a/.github/workflows/monitor-vms.yml b/.github/workflows/monitor-vms.yml index 6795f87abf68..0bb54524847a 100644 --- a/.github/workflows/monitor-vms.yml +++ b/.github/workflows/monitor-vms.yml @@ -27,7 +27,7 @@ jobs: | jq -c '[ .runners[] | select(.status == "online") - | select(.name | contains("gpu")) + | select(.name | contains("cpu") | not) | { "vm": .name, "n_gpus": [ diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index c11629776c40..81db8e1160d9 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -20,10 +20,15 @@ on: description: Ref (SHA or branch name) to release required: true type: string + dry-run: + description: Do not publish a wheel and GitHub release. + required: true + default: true + type: boolean jobs: release: - uses: NVIDIA/NeMo-FW-CI-templates/.github/workflows/_release_library.yml@v0.10.0 + uses: NVIDIA/NeMo-FW-CI-templates/.github/workflows/_release_library.yml@v0.15.0 with: release-ref: ${{ inputs.release-ref }} image-name: nemo_container @@ -35,7 +40,10 @@ jobs: python-package: nemo container-workdir: /workspace library-name: Neural Modules + dry-run: ${{ inputs.dry-run }} secrets: TWINE_USERNAME: ${{ secrets.TWINE_USERNAME }} TWINE_PASSWORD: ${{ secrets.TWINE_PASSWORD }} SLACK_RELEASE_ENDPOINT: ${{ secrets.SLACK_RELEASE_ENDPOINT }} + PAT: ${{ secrets.PAT }} + SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK }} diff --git a/.secrets.baseline b/.secrets.baseline index c26f70775c5a..09fc7a78a6ca 100644 --- a/.secrets.baseline +++ b/.secrets.baseline @@ -90,6 +90,10 @@ { "path": "detect_secrets.filters.allowlist.is_line_allowlisted" }, + { + "path": "detect_secrets.filters.common.is_baseline_file", + "filename": ".secrets.baseline" + }, { "path": "detect_secrets.filters.common.is_ignored_due_to_verification_policies", "min_level": 2 @@ -273,7 +277,7 @@ "filename": "scripts/checkpoint_converters/convert_mistral_7b_hf_to_nemo.py", "hashed_secret": "e0308bd21bffc156d79208f9ecf130370a015002", "is_verified": false, - "line_number": 460 + "line_number": 471 } ], "scripts/dataset_processing/nlp/intent_and_slot/assistant_utils.py": [ @@ -1929,7 +1933,7 @@ "filename": "tutorials/speaker_tasks/Speaker_Diarization_Inference.ipynb", "hashed_secret": "80903ddedcf4ec0a2ee5911cefa7e1ad52419dcc", "is_verified": false, - "line_number": 989 + "line_number": 990 } ], "tutorials/tools/DefinedCrowd_x_NeMo_ASR_Training_Tutorial.ipynb": [ @@ -2083,5 +2087,5 @@ } ] }, - "generated_at": "2024-10-25T13:43:17Z" + "generated_at": "2024-11-14T09:37:19Z" } diff --git a/Dockerfile.ci b/Dockerfile.ci index 5858f0aadf5b..e1b78547325a 100644 --- a/Dockerfile.ci +++ b/Dockerfile.ci @@ -54,7 +54,7 @@ RUN pip install nemo_run@git+https://github.com/NVIDIA/NeMo-Run.git@${NEMO_RUN_T # Install NeMo requirements ARG TE_TAG=7d576ed25266a17a7b651f2c12e8498f67e0baea ARG MODELOPT_VERSION=0.19.0 -ARG MCORE_TAG=aded519cfb1de2abf96f36ca059f992294b7876f +ARG MCORE_TAG=c1728c12f1f1cdbb786e52f1ffe512295d76bef3 ARG APEX_TAG=810ffae374a2b9cb4b5c5e28eaeca7d7998fca0c RUN \ diff --git a/docs/source/nlp/distillation.rst b/docs/source/nlp/distillation.rst deleted file mode 100644 index 22b2f3dd8a1c..000000000000 --- a/docs/source/nlp/distillation.rst +++ /dev/null @@ -1,58 +0,0 @@ -.. _megatron_distillation: - -Distillation -========================== - -Knowledge Distillation (KD) --------------------------------- - -KD involves using information from an existing trained model to train a second (usually smaller, faster) model, thereby "distilling" knowledge from one to the other. - -Distillation has two primary benefits: faster convergence and higher end accuracy than traditional training. - -In NeMo, distillation is enabled by the `NVIDIA TensorRT Model Optimizer (ModelOpt) `_ library -- a library to optimize deep-learning models for inference on GPUs. - -The logits-distillation process consists of the following steps: - -1. Loading both student and teacher model checkpoints (must support same parallelism strategy, if any) -2. Training until convergence, where forward passes are run on both models (and backward only on student), performing a specific loss function between the logits. -3. Saving the final student model. - - -Example -^^^^^^^ -The example below shows how to run the distillation script for LLama models. - -The script must be launched correctly with the number of processes equal to tensor parallelism. This is achieved with the ``torchrun`` command below: - -.. code-block:: bash - - STUDENT_CKPT="path/to/student.nemo" # can also be None (will use default architecture found in examples/nlp/language_modeling/conf/megatron_llama_distill.yaml) - TEACHER_CKPT="path/to/teacher.nemo" - TOKENIZER="path/to/tokenizer.model" - DATA_PATHS="[1.0,path/to/tokenized/data]" - FINAL_SAVE_FILE="final_checkpoint.nemo" - TP=4 - - NPROC=$TP - launch_config="torchrun --nproc_per_node=$NPROC" - - ${launch_config} examples/nlp/language_modeling/megatron_gpt_distillation.py \ - model.restore_from_path=$STUDENT_CKPT \ - model.kd_teacher_restore_from_path=$TEACHER_CKPT \ - model.tensor_model_parallel_size=$TP \ - model.tokenizer.model=$TOKENIZER \ - model.data.data_prefix=$DATA_PATHS \ - model.nemo_path=$FINAL_SAVE_FILE \ - trainer.precision=bf16 \ - trainer.devices=$NPROC - -For large models, the command can be used in multi-node setting. For example, this can be done with `NeMo Framework Launcher `_ using Slurm. - - -Limitations -^^^^^^^^^^^ -* Only Megatron Core-based GPT models are supported -* Only logit-pair distillation is supported for now -* Pipeline parallelism not yet supported -* FSDP strategy not yet supported diff --git a/docs/source/nlp/nemo_megatron/model_distillation/drop_layers.rst b/docs/source/nlp/nemo_megatron/model_distillation/drop_layers.rst deleted file mode 100644 index 3dc008945cc9..000000000000 --- a/docs/source/nlp/nemo_megatron/model_distillation/drop_layers.rst +++ /dev/null @@ -1,67 +0,0 @@ -.. _drop_layers: - -Drop Model Layers ------------------ - -To trim the model layers, use the following script: - -.. code-block:: bash - - python -m torch.distributed.launch --nproc_per_node= * \ - /NeMo/examples/nlp/language_modeling/megatron_gpt_drop_layers.py \ - --path_to_nemo /path/to/model.nemo \ - --path_to_save /path/to/save/trimmed_model.nemo \ - --tensor_model_parallel_size \ - --pipeline_model_parallel_size \ - --gpus_per_node \ - --drop_layers 1 2 3 4 - -**Note:** layer indices start from 1. - -To save trimmed model in ``zarr`` checkpoint format, add the following flag to the command above: - -.. code-block:: bash - - --zarr - -**Note:** the ``zarr`` checkpoint format is deprecated. - -Validate Trimmed Model ----------------------- - -To validate the trimmed model, use the following script: - -.. code-block:: bash - - python /NeMo/examples/nlp/language_modeling/megatron_gpt_pretraining.py \ - --config-path=/path/to/folder/with/model/config \ - --config-name=model_config.yaml \ - trainer.limit_val_batches= \ - model.restore_from_path=/path/to/trimmed_model.nemo \ - model.skip_train=True \ - model.data.data_impl=mock \ - model.data.data_prefix=[] - -To use a specific dataset instead of a mock dataset, modify the ``model.data`` parameters as follows: - -.. code-block:: bash - - model.data.data_impl=mmap \ - model.data.data_prefix=["path/to/datafile1", "path/to/datafile2"] - -Validate Original Model ------------------------ - -To validate the original model without specific layers, use the following script: - -.. code-block:: bash - - python /NeMo/examples/nlp/language_modeling/megatron_gpt_pretraining.py \ - --config-path=/path/to/folder/with/model/config \ - --config-name=model_config.yaml \ - trainer.limit_val_batches= \ - model.restore_from_path=/path/to/original_model.nemo \ - model.skip_train=True \ - model.data.data_impl=mock \ - model.data.data_prefix=[] \ - model.drop_layers=[1,2,3,4] diff --git a/docs/source/nlp/punctuation_and_capitalization.rst b/docs/source/nlp/punctuation_and_capitalization.rst index 4be0d2151d8e..d67332eb00c1 100755 --- a/docs/source/nlp/punctuation_and_capitalization.rst +++ b/docs/source/nlp/punctuation_and_capitalization.rst @@ -240,7 +240,7 @@ An example of a config file is - trainer config - - Parameters of - `pytorch_lightning.Trainer `_. + `lightning.pytorch.Trainer `_. * - **exp_manager** - exp manager config - diff --git a/docs/source/starthere/fundamentals.rst b/docs/source/starthere/fundamentals.rst index e3014e0f5a03..f486bf3d6e49 100644 --- a/docs/source/starthere/fundamentals.rst +++ b/docs/source/starthere/fundamentals.rst @@ -116,7 +116,7 @@ Below is an example training script for our ``ExampleEncDecModel`` model. We hig :linenos: :emphasize-lines: 10, 11, 12 - import pytorch_lightning as pl + import lightning.pytorch as pl from nemo.collections.path_to_model_class import ExampleEncDecModel from nemo.core.config import hydra_runner diff --git a/examples/asr/asr_adapters/eval_asr_adapter.py b/examples/asr/asr_adapters/eval_asr_adapter.py index bc5947f26aaf..b35cf33a6c0e 100644 --- a/examples/asr/asr_adapters/eval_asr_adapter.py +++ b/examples/asr/asr_adapters/eval_asr_adapter.py @@ -36,7 +36,7 @@ """ -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf import OmegaConf, open_dict from nemo.collections.asr.models import ASRModel diff --git a/examples/asr/asr_adapters/train_asr_adapter.py b/examples/asr/asr_adapters/train_asr_adapter.py index 3f82ef8fe554..253672e3eb89 100644 --- a/examples/asr/asr_adapters/train_asr_adapter.py +++ b/examples/asr/asr_adapters/train_asr_adapter.py @@ -84,7 +84,7 @@ import os from dataclasses import is_dataclass -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf import DictConfig, OmegaConf, open_dict from nemo.collections.asr.models import ASRModel diff --git a/examples/asr/asr_chunked_inference/aed/speech_to_text_aed_chunked_infer.py b/examples/asr/asr_chunked_inference/aed/speech_to_text_aed_chunked_infer.py index 8188bcced14d..1e63a9d820be 100644 --- a/examples/asr/asr_chunked_inference/aed/speech_to_text_aed_chunked_infer.py +++ b/examples/asr/asr_chunked_inference/aed/speech_to_text_aed_chunked_infer.py @@ -49,7 +49,7 @@ from dataclasses import dataclass from typing import Optional -import pytorch_lightning as pl +import lightning.pytorch as pl import torch from omegaconf import OmegaConf diff --git a/examples/asr/asr_chunked_inference/ctc/speech_to_text_buffered_infer_ctc.py b/examples/asr/asr_chunked_inference/ctc/speech_to_text_buffered_infer_ctc.py index 87370d278f98..ccea94f41f83 100644 --- a/examples/asr/asr_chunked_inference/ctc/speech_to_text_buffered_infer_ctc.py +++ b/examples/asr/asr_chunked_inference/ctc/speech_to_text_buffered_infer_ctc.py @@ -42,7 +42,7 @@ from dataclasses import dataclass from typing import Optional -import pytorch_lightning as pl +import lightning.pytorch as pl import torch from omegaconf import OmegaConf diff --git a/examples/asr/asr_chunked_inference/rnnt/speech_to_text_buffered_infer_rnnt.py b/examples/asr/asr_chunked_inference/rnnt/speech_to_text_buffered_infer_rnnt.py index e6e84cdfa6c4..c31fa2b9d812 100644 --- a/examples/asr/asr_chunked_inference/rnnt/speech_to_text_buffered_infer_rnnt.py +++ b/examples/asr/asr_chunked_inference/rnnt/speech_to_text_buffered_infer_rnnt.py @@ -64,7 +64,7 @@ from dataclasses import dataclass from typing import Optional -import pytorch_lightning as pl +import lightning.pytorch as pl import torch from omegaconf import OmegaConf, open_dict diff --git a/examples/asr/asr_ctc/speech_to_text_ctc.py b/examples/asr/asr_ctc/speech_to_text_ctc.py index 87b1b11633f7..ccdf3a5e09ea 100644 --- a/examples/asr/asr_ctc/speech_to_text_ctc.py +++ b/examples/asr/asr_ctc/speech_to_text_ctc.py @@ -68,7 +68,7 @@ """ -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf import OmegaConf from nemo.collections.asr.models import EncDecCTCModel diff --git a/examples/asr/asr_ctc/speech_to_text_ctc_bpe.py b/examples/asr/asr_ctc/speech_to_text_ctc_bpe.py index b4e3be5f650a..997cd6e52d5b 100644 --- a/examples/asr/asr_ctc/speech_to_text_ctc_bpe.py +++ b/examples/asr/asr_ctc/speech_to_text_ctc_bpe.py @@ -64,7 +64,7 @@ """ -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf import OmegaConf from nemo.collections.asr.models.ctc_bpe_models import EncDecCTCModelBPE diff --git a/examples/asr/asr_hybrid_transducer_ctc/speech_to_text_hybrid_rnnt_ctc_bpe.py b/examples/asr/asr_hybrid_transducer_ctc/speech_to_text_hybrid_rnnt_ctc_bpe.py index 796005a8fcee..ffda4c554a49 100644 --- a/examples/asr/asr_hybrid_transducer_ctc/speech_to_text_hybrid_rnnt_ctc_bpe.py +++ b/examples/asr/asr_hybrid_transducer_ctc/speech_to_text_hybrid_rnnt_ctc_bpe.py @@ -58,7 +58,7 @@ """ -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf import OmegaConf from nemo.collections.asr.models import EncDecHybridRNNTCTCBPEModel diff --git a/examples/asr/asr_hybrid_transducer_ctc/speech_to_text_hybrid_rnnt_ctc_char.py b/examples/asr/asr_hybrid_transducer_ctc/speech_to_text_hybrid_rnnt_ctc_char.py index 423e005d8f02..02f43f93e2c7 100644 --- a/examples/asr/asr_hybrid_transducer_ctc/speech_to_text_hybrid_rnnt_ctc_char.py +++ b/examples/asr/asr_hybrid_transducer_ctc/speech_to_text_hybrid_rnnt_ctc_char.py @@ -69,7 +69,7 @@ """ -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf import OmegaConf from nemo.collections.asr.models import EncDecHybridRNNTCTCModel diff --git a/examples/asr/asr_transducer/speech_to_text_rnnt.py b/examples/asr/asr_transducer/speech_to_text_rnnt.py index 5b4f1e8a985d..2fab3ac137e6 100644 --- a/examples/asr/asr_transducer/speech_to_text_rnnt.py +++ b/examples/asr/asr_transducer/speech_to_text_rnnt.py @@ -67,7 +67,7 @@ """ -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf import OmegaConf from nemo.collections.asr.models import EncDecRNNTModel diff --git a/examples/asr/asr_transducer/speech_to_text_rnnt_bpe.py b/examples/asr/asr_transducer/speech_to_text_rnnt_bpe.py index 1fffea55686f..d18313acc9a6 100644 --- a/examples/asr/asr_transducer/speech_to_text_rnnt_bpe.py +++ b/examples/asr/asr_transducer/speech_to_text_rnnt_bpe.py @@ -59,7 +59,7 @@ """ -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf import OmegaConf from nemo.collections.asr.models import EncDecRNNTBPEModel diff --git a/examples/asr/asr_with_tts/speech_to_text_bpe_with_text.py b/examples/asr/asr_with_tts/speech_to_text_bpe_with_text.py index b435d418fda2..acd7a8632822 100644 --- a/examples/asr/asr_with_tts/speech_to_text_bpe_with_text.py +++ b/examples/asr/asr_with_tts/speech_to_text_bpe_with_text.py @@ -49,7 +49,7 @@ """ -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf import OmegaConf from nemo.collections.asr.models.hybrid_asr_tts_models import ASRWithTTSModel diff --git a/examples/asr/asr_with_tts/speech_to_text_bpe_with_text_finetune.py b/examples/asr/asr_with_tts/speech_to_text_bpe_with_text_finetune.py index 99bc41ba966b..c1692cf6234f 100644 --- a/examples/asr/asr_with_tts/speech_to_text_bpe_with_text_finetune.py +++ b/examples/asr/asr_with_tts/speech_to_text_bpe_with_text_finetune.py @@ -45,7 +45,7 @@ """ -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf import OmegaConf from nemo.collections.asr.models.hybrid_asr_tts_models import ASRWithTTSModel diff --git a/examples/asr/conf/asr_adapters/asr_adaptation.yaml b/examples/asr/conf/asr_adapters/asr_adaptation.yaml index b9a2a003217e..bae166d18782 100644 --- a/examples/asr/conf/asr_adapters/asr_adaptation.yaml +++ b/examples/asr/conf/asr_adapters/asr_adaptation.yaml @@ -182,7 +182,7 @@ trainer: val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations accelerator: auto strategy: - _target_: pytorch_lightning.strategies.DDPStrategy + _target_: lightning.pytorch.strategies.DDPStrategy gradient_as_bucket_view: true accumulate_grad_batches: 1 gradient_clip_val: null diff --git a/examples/asr/conf/asr_adapters/asr_adaptation_hp.yaml b/examples/asr/conf/asr_adapters/asr_adaptation_hp.yaml index 958e6d23375c..d03b2eacfec4 100644 --- a/examples/asr/conf/asr_adapters/asr_adaptation_hp.yaml +++ b/examples/asr/conf/asr_adapters/asr_adaptation_hp.yaml @@ -182,7 +182,7 @@ trainer: val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations accelerator: auto strategy: - _target_: pytorch_lightning.strategies.DDPStrategy + _target_: lightning.pytorch.strategies.DDPStrategy gradient_as_bucket_view: true accumulate_grad_batches: 1 gradient_clip_val: null diff --git a/examples/asr/conf/asr_finetune/speech_to_text_finetune.yaml b/examples/asr/conf/asr_finetune/speech_to_text_finetune.yaml index 3b5717efddf9..1ae64a341e16 100644 --- a/examples/asr/conf/asr_finetune/speech_to_text_finetune.yaml +++ b/examples/asr/conf/asr_finetune/speech_to_text_finetune.yaml @@ -81,7 +81,7 @@ trainer: val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations accelerator: auto strategy: - _target_: pytorch_lightning.strategies.DDPStrategy + _target_: lightning.pytorch.strategies.DDPStrategy gradient_as_bucket_view: true accumulate_grad_batches: 1 gradient_clip_val: 0.0 diff --git a/examples/asr/conf/asr_finetune/speech_to_text_hf_finetune.yaml b/examples/asr/conf/asr_finetune/speech_to_text_hf_finetune.yaml index f111573f21eb..c044d3c8d7a8 100644 --- a/examples/asr/conf/asr_finetune/speech_to_text_hf_finetune.yaml +++ b/examples/asr/conf/asr_finetune/speech_to_text_hf_finetune.yaml @@ -145,7 +145,7 @@ trainer: val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations accelerator: auto strategy: - _target_: pytorch_lightning.strategies.DDPStrategy + _target_: lightning.pytorch.strategies.DDPStrategy gradient_as_bucket_view: true accumulate_grad_batches: 1 gradient_clip_val: 0.0 diff --git a/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_ctc_bpe_streaming.yaml b/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_ctc_bpe_streaming.yaml index 4c80d2f2e9d4..564f4b176e64 100644 --- a/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_ctc_bpe_streaming.yaml +++ b/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_ctc_bpe_streaming.yaml @@ -172,7 +172,7 @@ trainer: val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations accelerator: auto strategy: - _target_: pytorch_lightning.strategies.DDPStrategy + _target_: lightning.pytorch.strategies.DDPStrategy gradient_as_bucket_view: true accumulate_grad_batches: 1 gradient_clip_val: 1.0 diff --git a/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_ctc_char_streaming.yaml b/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_ctc_char_streaming.yaml index 0796a60260a1..6962c03ebe60 100644 --- a/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_ctc_char_streaming.yaml +++ b/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_ctc_char_streaming.yaml @@ -177,7 +177,7 @@ trainer: val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations accelerator: auto strategy: - _target_: pytorch_lightning.strategies.DDPStrategy + _target_: lightning.pytorch.strategies.DDPStrategy gradient_as_bucket_view: true accumulate_grad_batches: 1 gradient_clip_val: 1.0 diff --git a/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_transducer_bpe_streaming.yaml b/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_transducer_bpe_streaming.yaml index 4edcc38396fa..1531bf380b6d 100644 --- a/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_transducer_bpe_streaming.yaml +++ b/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_transducer_bpe_streaming.yaml @@ -228,7 +228,7 @@ trainer: val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations accelerator: auto strategy: - _target_: pytorch_lightning.strategies.DDPStrategy + _target_: lightning.pytorch.strategies.DDPStrategy gradient_as_bucket_view: true accumulate_grad_batches: 1 gradient_clip_val: 1.0 diff --git a/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_transducer_char_streaming.yaml b/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_transducer_char_streaming.yaml index 97b64ef93402..4cb508b0aff3 100644 --- a/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_transducer_char_streaming.yaml +++ b/examples/asr/conf/fastconformer/cache_aware_streaming/fastconformer_transducer_char_streaming.yaml @@ -234,7 +234,7 @@ trainer: val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations accelerator: auto strategy: - _target_: pytorch_lightning.strategies.DDPStrategy + _target_: lightning.pytorch.strategies.DDPStrategy gradient_as_bucket_view: true accumulate_grad_batches: 1 gradient_clip_val: 1.0 diff --git a/examples/asr/conf/fastconformer/fast-conformer_ctc_bpe.yaml b/examples/asr/conf/fastconformer/fast-conformer_ctc_bpe.yaml index ea6094380856..fd5f34aa43cb 100644 --- a/examples/asr/conf/fastconformer/fast-conformer_ctc_bpe.yaml +++ b/examples/asr/conf/fastconformer/fast-conformer_ctc_bpe.yaml @@ -198,7 +198,7 @@ trainer: val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations accelerator: auto strategy: - _target_: pytorch_lightning.strategies.DDPStrategy + _target_: lightning.pytorch.strategies.DDPStrategy gradient_as_bucket_view: true accumulate_grad_batches: 1 gradient_clip_val: 0.0 diff --git a/examples/asr/conf/fastconformer/fast-conformer_transducer_bpe.yaml b/examples/asr/conf/fastconformer/fast-conformer_transducer_bpe.yaml index 9e2c1a876864..deb7b7ca613a 100644 --- a/examples/asr/conf/fastconformer/fast-conformer_transducer_bpe.yaml +++ b/examples/asr/conf/fastconformer/fast-conformer_transducer_bpe.yaml @@ -251,7 +251,7 @@ trainer: val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations accelerator: auto strategy: - _target_: pytorch_lightning.strategies.DDPStrategy + _target_: lightning.pytorch.strategies.DDPStrategy gradient_as_bucket_view: true accumulate_grad_batches: 1 gradient_clip_val: 0.0 diff --git a/examples/asr/conf/fastconformer/hybrid_cache_aware_streaming/fastconformer_hybrid_transducer_ctc_bpe_streaming.yaml b/examples/asr/conf/fastconformer/hybrid_cache_aware_streaming/fastconformer_hybrid_transducer_ctc_bpe_streaming.yaml index daef1ed67a9f..6d89a6a52dfb 100644 --- a/examples/asr/conf/fastconformer/hybrid_cache_aware_streaming/fastconformer_hybrid_transducer_ctc_bpe_streaming.yaml +++ b/examples/asr/conf/fastconformer/hybrid_cache_aware_streaming/fastconformer_hybrid_transducer_ctc_bpe_streaming.yaml @@ -245,7 +245,7 @@ trainer: val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations accelerator: auto strategy: - _target_: pytorch_lightning.strategies.DDPStrategy + _target_: lightning.pytorch.strategies.DDPStrategy gradient_as_bucket_view: true accumulate_grad_batches: 1 gradient_clip_val: 1.0 diff --git a/examples/asr/conf/fastconformer/hybrid_cache_aware_streaming/fastconformer_hybrid_transducer_ctc_char_streaming.yaml b/examples/asr/conf/fastconformer/hybrid_cache_aware_streaming/fastconformer_hybrid_transducer_ctc_char_streaming.yaml index 96aee4af1803..7e6b9c4aa7b4 100644 --- a/examples/asr/conf/fastconformer/hybrid_cache_aware_streaming/fastconformer_hybrid_transducer_ctc_char_streaming.yaml +++ b/examples/asr/conf/fastconformer/hybrid_cache_aware_streaming/fastconformer_hybrid_transducer_ctc_char_streaming.yaml @@ -250,7 +250,7 @@ trainer: val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations accelerator: auto strategy: - _target_: pytorch_lightning.strategies.DDPStrategy + _target_: lightning.pytorch.strategies.DDPStrategy gradient_as_bucket_view: true accumulate_grad_batches: 1 gradient_clip_val: 1.0 diff --git a/examples/asr/conf/fastconformer/hybrid_transducer_ctc/fastconformer_hybrid_transducer_ctc_bpe.yaml b/examples/asr/conf/fastconformer/hybrid_transducer_ctc/fastconformer_hybrid_transducer_ctc_bpe.yaml index 4ba55e368bb9..12a21c6fba6c 100644 --- a/examples/asr/conf/fastconformer/hybrid_transducer_ctc/fastconformer_hybrid_transducer_ctc_bpe.yaml +++ b/examples/asr/conf/fastconformer/hybrid_transducer_ctc/fastconformer_hybrid_transducer_ctc_bpe.yaml @@ -224,7 +224,7 @@ trainer: val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations accelerator: auto strategy: - _target_: pytorch_lightning.strategies.DDPStrategy + _target_: lightning.pytorch.strategies.DDPStrategy gradient_as_bucket_view: true accumulate_grad_batches: 1 gradient_clip_val: 1.0 diff --git a/examples/asr/conf/fastconformer/hybrid_transducer_ctc/fastconformer_hybrid_transducer_ctc_char.yaml b/examples/asr/conf/fastconformer/hybrid_transducer_ctc/fastconformer_hybrid_transducer_ctc_char.yaml index ed2ad8ca9c0d..65f657b5416e 100644 --- a/examples/asr/conf/fastconformer/hybrid_transducer_ctc/fastconformer_hybrid_transducer_ctc_char.yaml +++ b/examples/asr/conf/fastconformer/hybrid_transducer_ctc/fastconformer_hybrid_transducer_ctc_char.yaml @@ -229,7 +229,7 @@ trainer: val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations accelerator: auto strategy: - _target_: pytorch_lightning.strategies.DDPStrategy + _target_: lightning.pytorch.strategies.DDPStrategy gradient_as_bucket_view: true accumulate_grad_batches: 1 gradient_clip_val: 1.0 diff --git a/examples/asr/conf/fastconformer/long_fastconformer/fast-conformer-long_ctc_bpe.yaml b/examples/asr/conf/fastconformer/long_fastconformer/fast-conformer-long_ctc_bpe.yaml index 773a500ef2db..df511883ce80 100644 --- a/examples/asr/conf/fastconformer/long_fastconformer/fast-conformer-long_ctc_bpe.yaml +++ b/examples/asr/conf/fastconformer/long_fastconformer/fast-conformer-long_ctc_bpe.yaml @@ -169,7 +169,7 @@ trainer: val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations accelerator: auto strategy: - _target_: pytorch_lightning.strategies.DDPStrategy + _target_: lightning.pytorch.strategies.DDPStrategy gradient_as_bucket_view: true accumulate_grad_batches: 1 gradient_clip_val: 0.0 @@ -204,4 +204,4 @@ exp_manager: create_wandb_logger: false wandb_logger_kwargs: name: null - project: null \ No newline at end of file + project: null diff --git a/examples/asr/conf/fastconformer/long_fastconformer/fast-conformer-long_transducer_bpe.yaml b/examples/asr/conf/fastconformer/long_fastconformer/fast-conformer-long_transducer_bpe.yaml index fec2a2839efa..0218136cbdbd 100644 --- a/examples/asr/conf/fastconformer/long_fastconformer/fast-conformer-long_transducer_bpe.yaml +++ b/examples/asr/conf/fastconformer/long_fastconformer/fast-conformer-long_transducer_bpe.yaml @@ -223,7 +223,7 @@ trainer: val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations accelerator: auto strategy: - _target_: pytorch_lightning.strategies.DDPStrategy + _target_: lightning.pytorch.strategies.DDPStrategy gradient_as_bucket_view: true accumulate_grad_batches: 1 gradient_clip_val: 0.0 diff --git a/examples/asr/conf/speech_multitask/fast-conformer_aed.yaml b/examples/asr/conf/speech_multitask/fast-conformer_aed.yaml index 3d1a8c8bdf47..50446dfd9467 100644 --- a/examples/asr/conf/speech_multitask/fast-conformer_aed.yaml +++ b/examples/asr/conf/speech_multitask/fast-conformer_aed.yaml @@ -249,7 +249,7 @@ trainer: val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations accelerator: auto strategy: - _target_: pytorch_lightning.strategies.DDPStrategy + _target_: lightning.pytorch.strategies.DDPStrategy gradient_as_bucket_view: true accumulate_grad_batches: 1 gradient_clip_val: 0.0 diff --git a/examples/asr/experimental/k2/align_speech_parallel.py b/examples/asr/experimental/k2/align_speech_parallel.py index abfffa0cdfdb..cf07fb998e95 100644 --- a/examples/asr/experimental/k2/align_speech_parallel.py +++ b/examples/asr/experimental/k2/align_speech_parallel.py @@ -77,7 +77,7 @@ from dataclasses import dataclass, field, is_dataclass from typing import Optional -import pytorch_lightning as ptl +import lightning.pytorch as ptl import torch from omegaconf import MISSING, OmegaConf diff --git a/examples/asr/experimental/k2/speech_to_text_bpe.py b/examples/asr/experimental/k2/speech_to_text_bpe.py index ee3924c7b8ac..8a941200770f 100644 --- a/examples/asr/experimental/k2/speech_to_text_bpe.py +++ b/examples/asr/experimental/k2/speech_to_text_bpe.py @@ -74,7 +74,7 @@ model.graph_module_cfg.background_cfg.intersect_pruned=False \ model.graph_module_cfg.background_cfg.boost_coeff=0.0 """ -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf import OmegaConf from nemo.collections.asr.models.configs.k2_sequence_models_config import EncDecK2SeqModelConfig diff --git a/examples/asr/experimental/k2/speech_to_text_rnnt_bpe.py b/examples/asr/experimental/k2/speech_to_text_rnnt_bpe.py index a0031fba082d..973be0cbd477 100644 --- a/examples/asr/experimental/k2/speech_to_text_rnnt_bpe.py +++ b/examples/asr/experimental/k2/speech_to_text_rnnt_bpe.py @@ -63,7 +63,7 @@ """ -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf import OmegaConf from nemo.collections.asr.models import EncDecK2RnntSeqModelBPE diff --git a/examples/asr/experimental/structured/speech_to_text_hybrid.py b/examples/asr/experimental/structured/speech_to_text_hybrid.py index 26530631498f..e6126c47305f 100644 --- a/examples/asr/experimental/structured/speech_to_text_hybrid.py +++ b/examples/asr/experimental/structured/speech_to_text_hybrid.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl from nemo.collections.asr.models import EncDecCTCModel, configs from nemo.core.config import hydra_runner diff --git a/examples/asr/experimental/structured/speech_to_text_structured.py b/examples/asr/experimental/structured/speech_to_text_structured.py index 366c6d831a7d..55934c00322e 100644 --- a/examples/asr/experimental/structured/speech_to_text_structured.py +++ b/examples/asr/experimental/structured/speech_to_text_structured.py @@ -14,7 +14,7 @@ from dataclasses import asdict -import pytorch_lightning as pl +import lightning.pytorch as pl import nemo.collections.asr as nemo_asr from nemo.collections.asr.models import EncDecCTCModel, configs @@ -64,7 +64,13 @@ ), # ... repeat 14 more times nemo_asr.modules.conv_asr.JasperEncoderConfig( - filters=1024, repeat=1, kernel=[1], stride=[1], dilation=[1], dropout=cfg.model.dropout, residual=False, + filters=1024, + repeat=1, + kernel=[1], + stride=[1], + dilation=[1], + dropout=cfg.model.dropout, + residual=False, ), ] diff --git a/examples/asr/experimental/structured/speech_to_text_structured_v2.py b/examples/asr/experimental/structured/speech_to_text_structured_v2.py index e8a865a9877a..146da425fb9b 100644 --- a/examples/asr/experimental/structured/speech_to_text_structured_v2.py +++ b/examples/asr/experimental/structured/speech_to_text_structured_v2.py @@ -14,7 +14,7 @@ from dataclasses import asdict -import pytorch_lightning as pl +import lightning.pytorch as pl from nemo.collections.asr.models import EncDecCTCModel, configs from nemo.core.config import modelPT, optimizers, schedulers diff --git a/examples/asr/speech_classification/speech_to_frame_label.py b/examples/asr/speech_classification/speech_to_frame_label.py index 04fcbdd1b61c..39a8e4415de5 100644 --- a/examples/asr/speech_classification/speech_to_frame_label.py +++ b/examples/asr/speech_classification/speech_to_frame_label.py @@ -39,7 +39,7 @@ """ -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf import OmegaConf from nemo.collections.asr.models.classification_models import EncDecFrameClassificationModel diff --git a/examples/asr/speech_classification/speech_to_label.py b/examples/asr/speech_classification/speech_to_label.py index b3deb5a4e7e5..810d2b5e7bdf 100644 --- a/examples/asr/speech_classification/speech_to_label.py +++ b/examples/asr/speech_classification/speech_to_label.py @@ -143,7 +143,7 @@ https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/asr/speech_classification/results.html# """ -import pytorch_lightning as pl +import lightning.pytorch as pl import torch from omegaconf import OmegaConf diff --git a/examples/asr/speech_multitask/speech_to_text_aed.py b/examples/asr/speech_multitask/speech_to_text_aed.py index 0c13e5289d86..943ecee59bfc 100644 --- a/examples/asr/speech_multitask/speech_to_text_aed.py +++ b/examples/asr/speech_multitask/speech_to_text_aed.py @@ -50,7 +50,7 @@ """ -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf import OmegaConf from nemo.collections.asr.models import EncDecMultiTaskModel diff --git a/examples/asr/speech_pretraining/downstream/speech_classification_mfa_train.py b/examples/asr/speech_pretraining/downstream/speech_classification_mfa_train.py index 3a256c7ab2d3..8bd56aa63450 100644 --- a/examples/asr/speech_pretraining/downstream/speech_classification_mfa_train.py +++ b/examples/asr/speech_pretraining/downstream/speech_classification_mfa_train.py @@ -14,7 +14,7 @@ from collections import OrderedDict -import pytorch_lightning as pl +import lightning.pytorch as pl import torch from omegaconf import OmegaConf diff --git a/examples/asr/speech_pretraining/masked_token_pred_pretrain.py b/examples/asr/speech_pretraining/masked_token_pred_pretrain.py index 1ea88d696643..e1c740e66412 100644 --- a/examples/asr/speech_pretraining/masked_token_pred_pretrain.py +++ b/examples/asr/speech_pretraining/masked_token_pred_pretrain.py @@ -13,7 +13,7 @@ # limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf import OmegaConf from nemo.collections.asr.models.ssl_models import EncDecDenoiseMaskedTokenPredModel diff --git a/examples/asr/speech_pretraining/speech_pre_training.py b/examples/asr/speech_pretraining/speech_pre_training.py index cec9444096c3..0c94099442a6 100644 --- a/examples/asr/speech_pretraining/speech_pre_training.py +++ b/examples/asr/speech_pretraining/speech_pre_training.py @@ -13,7 +13,7 @@ # limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf import OmegaConf from nemo.collections.asr.models.ssl_models import SpeechEncDecSelfSupervisedModel diff --git a/examples/asr/speech_to_text_finetune.py b/examples/asr/speech_to_text_finetune.py index 36a7bdc3bbdc..6b53446622ee 100644 --- a/examples/asr/speech_to_text_finetune.py +++ b/examples/asr/speech_to_text_finetune.py @@ -54,7 +54,7 @@ https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/asr/configs.html#fine-tuning-configurations """ import time -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf import OmegaConf from nemo.collections.asr.models import ASRModel diff --git a/examples/asr/speech_translation/speech_to_text_transformer.py b/examples/asr/speech_translation/speech_to_text_transformer.py index ac4dc4334164..bb7e0b3e4461 100644 --- a/examples/asr/speech_translation/speech_to_text_transformer.py +++ b/examples/asr/speech_translation/speech_to_text_transformer.py @@ -40,7 +40,7 @@ """ -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf import OmegaConf from nemo.collections.asr.models import EncDecTransfModelBPE diff --git a/examples/asr/speech_translation/translate_speech.py b/examples/asr/speech_translation/translate_speech.py index 53599e1b3511..76c8c096527f 100644 --- a/examples/asr/speech_translation/translate_speech.py +++ b/examples/asr/speech_translation/translate_speech.py @@ -17,7 +17,7 @@ from dataclasses import dataclass, is_dataclass from typing import List, Optional, Union -import pytorch_lightning as pl +import lightning.pytorch as pl import torch from omegaconf import OmegaConf diff --git a/examples/asr/transcribe_speech.py b/examples/asr/transcribe_speech.py index a543fcf5e252..5c4a636e8b1c 100644 --- a/examples/asr/transcribe_speech.py +++ b/examples/asr/transcribe_speech.py @@ -18,7 +18,7 @@ from dataclasses import dataclass, field, is_dataclass from typing import List, Optional, Union -import pytorch_lightning as pl +import lightning.pytorch as pl import torch from omegaconf import OmegaConf, open_dict @@ -276,6 +276,9 @@ def main(cfg: TranscriptionConfig) -> Union[TranscriptionConfig, List[Hypothesis # we will adjust this flag if the model does not support it compute_langs = cfg.compute_langs + if cfg.timestamps: + cfg.return_hypotheses = True + # Check whether model and decoder type match if isinstance(asr_model, EncDecCTCModel): if cfg.decoder_type and cfg.decoder_type != 'ctc': diff --git a/examples/asr/transcribe_speech_parallel.py b/examples/asr/transcribe_speech_parallel.py index eb905d3e91b0..d60099acd379 100644 --- a/examples/asr/transcribe_speech_parallel.py +++ b/examples/asr/transcribe_speech_parallel.py @@ -75,7 +75,7 @@ from dataclasses import dataclass, is_dataclass from typing import Optional -import pytorch_lightning as ptl +import lightning.pytorch as ptl import torch from omegaconf import MISSING, OmegaConf @@ -163,6 +163,14 @@ def main(cfg: ParallelTranscriptionConfig): cfg.predict_ds.return_sample_id = True cfg.predict_ds = match_train_config(predict_ds=cfg.predict_ds, train_ds=model.cfg.train_ds) + if cfg.predict_ds.use_lhotse: + OmegaConf.set_struct(cfg.predict_ds, False) + cfg.trainer.use_distributed_sampler = False + cfg.predict_ds.force_finite = True + cfg.predict_ds.force_map_dataset = True + cfg.predict_ds.do_transcribe = True + OmegaConf.set_struct(cfg.predict_ds, True) + if isinstance(model, EncDecMultiTaskModel): cfg.trainer.use_distributed_sampler = False OmegaConf.set_struct(cfg.predict_ds, False) @@ -172,7 +180,7 @@ def main(cfg: ParallelTranscriptionConfig): trainer = ptl.Trainer(**cfg.trainer) - if isinstance(model, EncDecMultiTaskModel): + if cfg.predict_ds.use_lhotse: OmegaConf.set_struct(cfg.predict_ds, False) cfg.predict_ds.global_rank = trainer.global_rank cfg.predict_ds.world_size = trainer.world_size diff --git a/examples/audio/audio_to_audio_train.py b/examples/audio/audio_to_audio_train.py index cef46dcf20b6..4d71e75176c9 100644 --- a/examples/audio/audio_to_audio_train.py +++ b/examples/audio/audio_to_audio_train.py @@ -28,7 +28,7 @@ """ from enum import Enum -import pytorch_lightning as pl +import lightning.pytorch as pl import torch from omegaconf import OmegaConf diff --git a/examples/audio/process_audio.py b/examples/audio/process_audio.py index ec88bda34954..8657d53ef957 100644 --- a/examples/audio/process_audio.py +++ b/examples/audio/process_audio.py @@ -20,7 +20,7 @@ from pathlib import Path from typing import List, Optional -import pytorch_lightning as pl +import lightning.pytorch as pl import torch from omegaconf import OmegaConf diff --git a/examples/llm/peft/hf.py b/examples/llm/peft/hf.py index 97f21d6c253e..357dc5a7bd17 100644 --- a/examples/llm/peft/hf.py +++ b/examples/llm/peft/hf.py @@ -13,7 +13,7 @@ # limitations under the License. import fiddle as fdl -from pytorch_lightning.loggers import WandbLogger +from lightning.pytorch.loggers import WandbLogger from nemo import lightning as nl from nemo.collections import llm @@ -76,11 +76,11 @@ def formatting_prompts_func(examples): # See: https://github.com/Lightning-AI/pytorch-lightning/blob/8ad3e29816a63d8ce5c00ac104b14729a4176f4f/src/lightning/pytorch/plugins/precision/fsdp.py#L81 grad_clip = None use_dist_samp = False - tokenizer = llm.HfAutoModelForCausalLM.configure_tokenizer(args.model) + tokenizer = llm.HFAutoModelForCausalLM.configure_tokenizer(args.model) llm.api.finetune( - model=llm.HfAutoModelForCausalLM(args.model), - data=llm.HfDatasetDataModule( + model=llm.HFAutoModelForCausalLM(args.model), + data=llm.HFDatasetDataModule( mk_hf_dataset(tokenizer.tokenizer), pad_token_id=tokenizer.tokenizer.eos_token_id ), trainer=nl.Trainer( diff --git a/examples/llm/sft/hf.py b/examples/llm/sft/hf.py old mode 100644 new mode 100755 index 7d4cde7866a2..ce79e136a1c2 --- a/examples/llm/sft/hf.py +++ b/examples/llm/sft/hf.py @@ -13,12 +13,14 @@ # limitations under the License. import fiddle as fdl -import pytorch_lightning as pl -from pytorch_lightning.loggers import WandbLogger +import lightning.pytorch as pl +from lightning.pytorch.loggers import WandbLogger from torch.utils.data import DataLoader from nemo import lightning as nl from nemo.collections import llm +from nemo.lightning.pytorch.accelerate.transformer_engine import is_te_accelerated +from nemo.lightning.pytorch.callbacks import ModelCallback class SquadDataModuleWithPthDataloader(llm.SquadDataModule): @@ -53,7 +55,9 @@ def squad(tokenizer) -> pl.LightningDataModule: parser.add_argument('--strategy', type=str, default='auto', choices=['auto', 'ddp', 'fsdp']) parser.add_argument('--devices', default=1) parser.add_argument('--accelerator', default='gpu', choices=['gpu']) + parser.add_argument('--model-accelerator', default=None, choices=['te']) parser.add_argument('--max-steps', type=int, default=100) + parser.add_argument("--fp8-autocast", default=False, action='store_true') parser.add_argument('--wandb-project', type=str, default=None) parser.add_argument('--model-save-path', type=str, default=None) args = parser.parse_args() @@ -71,7 +75,16 @@ def squad(tokenizer) -> pl.LightningDataModule: grad_clip = None use_dist_samp = False - model = llm.HfAutoModelForCausalLM(args.model) + model_accelerator = None + if args.model_accelerator == "te": + from functools import partial + from nemo.lightning.pytorch.accelerate.transformer_engine import te_accelerate + + model_accelerator = partial(te_accelerate, fp8_autocast=args.fp8_autocast) + + from nemo.lightning.pytorch.accelerate.transformer_engine import te_accelerate + + model = llm.HFAutoModelForCausalLM(model_name=args.model, model_accelerator=model_accelerator) tokenizer = model.tokenizer llm.api.finetune( @@ -88,11 +101,18 @@ def squad(tokenizer) -> pl.LightningDataModule: accumulate_grad_batches=10, gradient_clip_val=grad_clip, use_distributed_sampler=use_dist_samp, + callbacks=[], logger=wandb, ), optim=fdl.build(llm.adam.pytorch_adam_with_flat_lr(lr=1e-5)), log=None, ) + if args.model_accelerator: + if args.model_accelerator == "te": + te_acc = is_te_accelerated(model.model) + assert te_acc, "Transformer Engine acceleration was unsuccessful" + print("TE Accelerated: ", te_acc) + if args.model_save_path is not None: model.save_pretrained(args.model_save_path) diff --git a/examples/multimodal/multimodal_llm/neva/convert_llava_to_neva.py b/examples/multimodal/multimodal_llm/neva/convert_llava_to_neva.py index d02b737c750a..874d62dc63c9 100644 --- a/examples/multimodal/multimodal_llm/neva/convert_llava_to_neva.py +++ b/examples/multimodal/multimodal_llm/neva/convert_llava_to_neva.py @@ -34,10 +34,10 @@ from collections import OrderedDict import torch +from lightning.pytorch.core.saving import _load_state as ptl_load_state +from lightning.pytorch.trainer.trainer import Trainer from llava import LlavaLlamaForCausalLM from omegaconf import OmegaConf -from pytorch_lightning.core.saving import _load_state as ptl_load_state -from pytorch_lightning.trainer.trainer import Trainer from safetensors import safe_open from transformers import LlamaTokenizer diff --git a/examples/multimodal/speech_llm/export/extract_salm_weights.py b/examples/multimodal/speech_llm/export/extract_salm_weights.py index 0698a411110e..24c7aec3bb4d 100644 --- a/examples/multimodal/speech_llm/export/extract_salm_weights.py +++ b/examples/multimodal/speech_llm/export/extract_salm_weights.py @@ -18,9 +18,9 @@ import tempfile import torch +from lightning.pytorch.trainer.trainer import Trainer from megatron.core import dist_checkpointing from omegaconf import OmegaConf -from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.multimodal.speech_llm.modules.perception_modules import AudioPerceptionModule from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel diff --git a/examples/multimodal/text_to_image/controlnet/controlnet_train.py b/examples/multimodal/text_to_image/controlnet/controlnet_train.py index 2bb8b66cac1a..14e7e62a1cc7 100644 --- a/examples/multimodal/text_to_image/controlnet/controlnet_train.py +++ b/examples/multimodal/text_to_image/controlnet/controlnet_train.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from pytorch_lightning import Trainer +from lightning.pytorch import Trainer from nemo.collections.multimodal.models.text_to_image.controlnet.controlnet import MegatronControlNet from nemo.collections.multimodal.models.text_to_image.controlnet.util import ImageLogger diff --git a/examples/multimodal/text_to_image/convert_hf_ckpt_to_nemo.py b/examples/multimodal/text_to_image/convert_hf_ckpt_to_nemo.py index cebf159eb870..c50ad439eaec 100644 --- a/examples/multimodal/text_to_image/convert_hf_ckpt_to_nemo.py +++ b/examples/multimodal/text_to_image/convert_hf_ckpt_to_nemo.py @@ -27,10 +27,10 @@ from argparse import ArgumentParser import torch -from lightning_fabric.utilities.cloud_io import _load as pl_load +from lightning.fabric.utilities.cloud_io import _load as pl_load +from lightning.pytorch.plugins.environments import TorchElasticEnvironment +from lightning.pytorch.trainer.trainer import Trainer from omegaconf import OmegaConf -from pytorch_lightning.plugins.environments import TorchElasticEnvironment -from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.multimodal.models.text_to_image.controlnet.controlnet import MegatronControlNet from nemo.collections.multimodal.models.text_to_image.stable_diffusion.diffusion_engine import MegatronDiffusionEngine diff --git a/examples/multimodal/text_to_image/dreambooth/dreambooth_lora_infer.py b/examples/multimodal/text_to_image/dreambooth/dreambooth_lora_infer.py index 52f0aa2940d2..e1d050f83939 100644 --- a/examples/multimodal/text_to_image/dreambooth/dreambooth_lora_infer.py +++ b/examples/multimodal/text_to_image/dreambooth/dreambooth_lora_infer.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. import torch +from lightning.pytorch import Trainer +from lightning.pytorch.plugins.environments import TorchElasticEnvironment from omegaconf import open_dict -from pytorch_lightning import Trainer -from pytorch_lightning.plugins.environments import TorchElasticEnvironment from nemo.collections.multimodal.models.text_to_image.stable_diffusion.ldm.ddpm import MegatronLatentDiffusion from nemo.collections.multimodal.parts.stable_diffusion.pipeline import pipeline @@ -48,7 +48,10 @@ def model_cfg_modifier(model_cfg): plugins = [] plugins.append(TorchElasticEnvironment()) - strategy = NLPDDPStrategy(no_ddp_communication_hook=True, find_unused_parameters=False,) + strategy = NLPDDPStrategy( + no_ddp_communication_hook=True, + find_unused_parameters=False, + ) trainer = Trainer(plugins=plugins, strategy=strategy, **cfg.trainer) model = MegatronLatentDiffusion(model_cfg, trainer=trainer) diff --git a/examples/multimodal/text_to_image/imagen/generate_fid_images.py b/examples/multimodal/text_to_image/imagen/generate_fid_images.py index ea743e3e1d06..7d2df372b545 100644 --- a/examples/multimodal/text_to_image/imagen/generate_fid_images.py +++ b/examples/multimodal/text_to_image/imagen/generate_fid_images.py @@ -15,7 +15,7 @@ import os import torch -from pytorch_lightning import Trainer +from lightning.pytorch import Trainer from nemo.collections.multimodal.models.text_to_image.imagen.imagen_pipeline import ImagenPipeline from nemo.core.config import hydra_runner @@ -79,7 +79,10 @@ def main(cfg): seeds = [local_task_id * chunk_size + batch_idx * batch_size + idx for idx in range(len(batch_captions))] with torch.no_grad(): images, all_res_images, *_ = pipeline( - prompts=batch_captions, seed=seeds, single_batch_mode=True, classifier_free_guidance=current_node_cfg, + prompts=batch_captions, + seed=seeds, + single_batch_mode=True, + classifier_free_guidance=current_node_cfg, ) if cfg.fid.save_all_res: diff --git a/examples/multimodal/text_to_image/imagen/imagen_generate_images.py b/examples/multimodal/text_to_image/imagen/imagen_generate_images.py index bc002052a989..06b324367a52 100644 --- a/examples/multimodal/text_to_image/imagen/imagen_generate_images.py +++ b/examples/multimodal/text_to_image/imagen/imagen_generate_images.py @@ -16,8 +16,8 @@ import pickle import torch +from lightning.pytorch import Trainer from omegaconf import OmegaConf -from pytorch_lightning import Trainer from nemo.collections.multimodal.models.text_to_image.imagen.imagen_pipeline import ( ImagenPipeline, @@ -65,7 +65,11 @@ def main(inference_config): seed = batch_idx + chuncksize with torch.no_grad(): - images, all_res_images, throughput = pipeline(prompts=batch_captions, seed=seeds, single_batch_mode=True,) + images, all_res_images, throughput = pipeline( + prompts=batch_captions, + seed=seeds, + single_batch_mode=True, + ) for outpath, one_res in zip(outpaths, all_res_images): for idx, (caption, image) in enumerate(zip(batch_captions, one_res[0])): diff --git a/examples/multimodal/text_to_image/imagen/imagen_infer.py b/examples/multimodal/text_to_image/imagen/imagen_infer.py index 0fb291729596..9ce680cf4b09 100644 --- a/examples/multimodal/text_to_image/imagen/imagen_infer.py +++ b/examples/multimodal/text_to_image/imagen/imagen_infer.py @@ -14,8 +14,8 @@ import os +from lightning.pytorch import Trainer from omegaconf import OmegaConf -from pytorch_lightning import Trainer from nemo.collections.multimodal.models.text_to_image.imagen.imagen_pipeline import ( ImagenPipeline, diff --git a/examples/multimodal/text_to_image/imagen/imagen_training.py b/examples/multimodal/text_to_image/imagen/imagen_training.py index 23c1c9c1a1d7..211299156b69 100644 --- a/examples/multimodal/text_to_image/imagen/imagen_training.py +++ b/examples/multimodal/text_to_image/imagen/imagen_training.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl import torch from omegaconf.omegaconf import OmegaConf, open_dict from torch._dynamo import disable diff --git a/examples/multimodal/text_to_image/stable_diffusion/sd_lora_infer.py b/examples/multimodal/text_to_image/stable_diffusion/sd_lora_infer.py index 0877d4eb4b2f..0d83a8daab9f 100644 --- a/examples/multimodal/text_to_image/stable_diffusion/sd_lora_infer.py +++ b/examples/multimodal/text_to_image/stable_diffusion/sd_lora_infer.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. import torch +from lightning.pytorch import Trainer +from lightning.pytorch.plugins.environments import TorchElasticEnvironment from omegaconf import open_dict -from pytorch_lightning import Trainer -from pytorch_lightning.plugins.environments import TorchElasticEnvironment from nemo.collections.multimodal.models.text_to_image.stable_diffusion.ldm.ddpm import MegatronLatentDiffusion from nemo.collections.multimodal.parts.stable_diffusion.pipeline import pipeline @@ -45,7 +45,10 @@ def model_cfg_modifier(model_cfg): plugins = [] plugins.append(TorchElasticEnvironment()) - strategy = NLPDDPStrategy(no_ddp_communication_hook=True, find_unused_parameters=False,) + strategy = NLPDDPStrategy( + no_ddp_communication_hook=True, + find_unused_parameters=False, + ) trainer = Trainer(plugins=plugins, strategy=strategy, **cfg.trainer) model = MegatronLatentDiffusion(model_cfg, trainer=trainer) diff --git a/examples/multimodal/text_to_image/stable_diffusion/sd_xl_train.py b/examples/multimodal/text_to_image/stable_diffusion/sd_xl_train.py index 44412aee0d14..4ef22b69aa64 100644 --- a/examples/multimodal/text_to_image/stable_diffusion/sd_xl_train.py +++ b/examples/multimodal/text_to_image/stable_diffusion/sd_xl_train.py @@ -13,10 +13,11 @@ # limitations under the License. import sys + import torch import torch._dynamo.config as dynamo_config +from lightning.pytorch import Trainer from omegaconf.omegaconf import OmegaConf -from pytorch_lightning import Trainer from nemo.collections.multimodal.models.text_to_image.stable_diffusion.diffusion_engine import MegatronDiffusionEngine from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronTrainerBuilder diff --git a/examples/multimodal/vision_language_foundation/clip/convert_external_clip_to_nemo.py b/examples/multimodal/vision_language_foundation/clip/convert_external_clip_to_nemo.py index 178140aac828..abc987e07097 100644 --- a/examples/multimodal/vision_language_foundation/clip/convert_external_clip_to_nemo.py +++ b/examples/multimodal/vision_language_foundation/clip/convert_external_clip_to_nemo.py @@ -45,9 +45,9 @@ import einops import open_clip import torch +from lightning.pytorch.plugins.environments import TorchElasticEnvironment +from lightning.pytorch.trainer.trainer import Trainer from omegaconf import OmegaConf -from pytorch_lightning.plugins.environments import TorchElasticEnvironment -from pytorch_lightning.trainer.trainer import Trainer from transformers import CLIPModel from nemo.collections.multimodal.models.vision_language_foundation.clip.megatron_clip_models import MegatronCLIPModel diff --git a/examples/multimodal/x_to_nerf/benchmark_callback.py b/examples/multimodal/x_to_nerf/benchmark_callback.py index fd7d5afdc5bc..2db78d1d385a 100644 --- a/examples/multimodal/x_to_nerf/benchmark_callback.py +++ b/examples/multimodal/x_to_nerf/benchmark_callback.py @@ -15,7 +15,7 @@ import time from typing import Optional -from pytorch_lightning import Callback, LightningModule, Trainer +from lightning.pytorch import Callback, LightningModule, Trainer from nemo.utils import logging diff --git a/examples/multimodal/x_to_nerf/data.py b/examples/multimodal/x_to_nerf/data.py index fe7c47abc64b..b8dfd3aa536b 100644 --- a/examples/multimodal/x_to_nerf/data.py +++ b/examples/multimodal/x_to_nerf/data.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf.omegaconf import DictConfig from torch.utils.data import DataLoader diff --git a/examples/multimodal/x_to_nerf/main.py b/examples/multimodal/x_to_nerf/main.py index 5d7f616a3165..f3c8a6949867 100644 --- a/examples/multimodal/x_to_nerf/main.py +++ b/examples/multimodal/x_to_nerf/main.py @@ -13,8 +13,8 @@ # limitations under the License. from hydra.utils import get_class, instantiate +from lightning.pytorch import Trainer, seed_everything from omegaconf.omegaconf import DictConfig, OmegaConf -from pytorch_lightning import Trainer, seed_everything from nemo.core.config import hydra_runner from nemo.utils import logging diff --git a/examples/multimodal_autoregressive/README.md b/examples/multimodal_autoregressive/README.md new file mode 100644 index 000000000000..5934074a7d17 --- /dev/null +++ b/examples/multimodal_autoregressive/README.md @@ -0,0 +1,3 @@ +### MULTIMODAL AUTOREGRESSIVE GENERTION + +For information on how to get started with autoregressive generation for multimodal datasets using discrete tokenizers follow this [guide](nemo/collections/multimodal_autoregressive/data/README.md) diff --git a/examples/multimodal_autoregressive/conf/megatron_mm_ar_inference_image_generation.yaml b/examples/multimodal_autoregressive/conf/megatron_mm_ar_inference_image_generation.yaml new file mode 100644 index 000000000000..806800c96155 --- /dev/null +++ b/examples/multimodal_autoregressive/conf/megatron_mm_ar_inference_image_generation.yaml @@ -0,0 +1,36 @@ +inference: + greedy: True # Whether or not to use sampling ; use greedy decoding otherwise + top_k: 0 # The number of highest probability vocabulary tokens to keep for top-k-filtering. + top_p: 0.9 # If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation. + temperature: 1.0 # sampling temperature + add_BOS: True # add the bos token at the begining of the prompt + tokens_to_generate: 30 # The minimum length of the sequence to be generated. + all_probs: False # whether return the log prob for all the tokens in vocab + repetition_penalty: 1.2 # The parameter for repetition penalty. 1.0 means no penalty. + min_tokens_to_generate: 0 # The minimum length of the sequence to be generated. + compute_logprob: False # a flag used to compute logprob of all the input text, a very special case of running inference, default False + end_strings: ["<|extra_204|>"] # generation will stop when one of these tokens is generated + +trainer: + devices: 1 + num_nodes: 1 + accelerator: gpu + logger: False # logger provided by exp_manager + precision: bf16 # 16, 32, or bf16 + use_distributed_sampler: False + +tensor_model_parallel_size: -1 +pipeline_model_parallel_size: -1 +pipeline_model_parallel_split_rank: -1 # used for encoder and decoder model (0 for others) +megatron_amp_O2: False # Enable O2-level automatic mixed precision to save memory +image_encoder: Cosmos-Tokenizer-DV8x16x16 +gpt_model_file: null # GPT nemo file path +checkpoint_dir: null # checkpoint file dir. This is used to load the PTL checkpoint generated during the GPT training +checkpoint_name: null # PTL checkpoint file name, only used for PTL checkpoint loading +hparams_file: null # model configuration file, only used for PTL checkpoint loading +captions: # prompts for GPT inference + - "a drawing of a green pokemon with red eyes" + - "a red pokemon with green eyes" + - "a cartoon fish with a big smile" +images_output_path: null # Path to the directory to store the output images + diff --git a/examples/multimodal_autoregressive/conf/megatron_mm_ar_inference_vision_understanding.yaml b/examples/multimodal_autoregressive/conf/megatron_mm_ar_inference_vision_understanding.yaml new file mode 100644 index 000000000000..c392f5dcc5c2 --- /dev/null +++ b/examples/multimodal_autoregressive/conf/megatron_mm_ar_inference_vision_understanding.yaml @@ -0,0 +1,32 @@ +inference: + greedy: True # Whether or not to use sampling ; use greedy decoding otherwise + top_k: 0 # The number of highest probability vocabulary tokens to keep for top-k-filtering. + top_p: 0.9 # If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher are kept for generation. + temperature: 1.0 # sampling temperature + add_BOS: False # add the bos token at the begining of the prompt + tokens_to_generate: 30 # The minimum length of the sequence to be generated. + all_probs: False # whether return the log prob for all the tokens in vocab + repetition_penalty: 1.2 # The parameter for repetition penalty. 1.0 means no penalty. + min_tokens_to_generate: 0 # The minimum length of the sequence to be generated. + compute_logprob: False # a flag used to compute logprob of all the input text, a very special case of running inference, default False + end_strings: ["<|extra_204|>"] # generation will stop when one of these tokens is generated + +trainer: + devices: 1 + num_nodes: 1 + accelerator: gpu + logger: False # logger provided by exp_manager + precision: bf16 # 16, 32, or bf16 + use_distributed_sampler: False + +tensor_model_parallel_size: -1 +pipeline_model_parallel_size: -1 +pipeline_model_parallel_split_rank: -1 # used for encoder and decoder model (0 for others) +megatron_amp_O2: False # Enable O2-level automatic mixed precision to save memory +gpt_model_file: null # GPT nemo file path +checkpoint_dir: null # checkpoint file dir. This is used to load the PTL checkpoint generated during the GPT training +checkpoint_name: null # PTL checkpoint file name, only used for PTL checkpoint loading +hparams_file: null # model configuration file, only used for PTL checkpoint loading +images_path: # prompts for GPT inference + - "/path/to/image1" + - "/path/to/image2" diff --git a/examples/multimodal_autoregressive/megatron_mm_autoregressive_eval_image_generation.py b/examples/multimodal_autoregressive/megatron_mm_autoregressive_eval_image_generation.py new file mode 100644 index 000000000000..ae8dddb29553 --- /dev/null +++ b/examples/multimodal_autoregressive/megatron_mm_autoregressive_eval_image_generation.py @@ -0,0 +1,196 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime +import math +import os +import re + +import torch +import torchvision +from examples.nlp.language_modeling.megatron_gpt_eval import ( + load_model_from_config, + remove_padded_prompts, + round_to_mult, +) +from pytorch_lightning.trainer.trainer import Trainer + +# pylint: disable=line-too-long +from nemo.collections.common.video_tokenizers.cosmos_tokenizer import CausalVideoTokenizer +from nemo.collections.nlp.modules.common.transformer.text_generation import LengthParam, SamplingParam +from nemo.collections.nlp.parts.nlp_overrides import CustomProgressBar, NLPDDPStrategy +from nemo.core.config import hydra_runner + +""" +This is the script to run multimodal autoregresssive text generation. + +Make sure you install tiktoken==0.6.0 + +Usage: + Assume the model has TP=1, PP=1 in the following use cases. + a. run greedy inference from a nemo file: + python megatron_mm_autoregresssive_eval.py \ + gpt_model_file=PATH_TO_MODEL \ + inference.greedy=True \ + inference.add_BOS=True \ + trainer.devices=1 \ + trainer.num_nodes=1 \ + tensor_model_parallel_size=-1 \ + pipeline_model_parallel_size=-1 \ + captions=[caption1,caption2] + + b. run greedy inference from a PTL checkpoint file: + python megatron_mm_autoregresssive_eval.py \ + checkpoint_dir=PATH_TO_CHECKPOINT_FILE \ + checkpoint_name=CHECKPOINT_FILE_NAME \ + hparams_file=HPARAMS_FILE \ + inference.greedy=True \ + inference.add_BOS=True \ + trainer.devices=1 \ + trainer.num_nodes=1 \ + tensor_model_parallel_size=-1 \ + pipeline_model_parallel_size=-1 \ + captions=[caption1,caption2] + + c. run top_p inference from a nemo file: + python megatron_mm_autoregresssive_eval.py \ + gpt_model_file=PATH_TO_MODEL \ + inference.greedy=False \ + inference.top_k=0 \ + inference.top_p=0.9 \ + inference.repetition_penalty=1.2 \ + inference.add_BOS=True \ + trainer.devices=1 \ + trainer.num_nodes=1 \ + tensor_model_parallel_size=-1 \ + pipeline_model_parallel_size=-1 \ + captions=[caption1,caption2] + + d. If you don't need to generate tokens and need model to compute logprobs: + python megatron_mm_autoregresssive_eval.py \ + gpt_model_file=PATH_TO_MODEL \ + inference.compute_logprob=True \ + trainer.devices=1 \ + trainer.num_nodes=1 \ + tensor_model_parallel_size=-1 \ + pipeline_model_parallel_size=-1 \ + captions=[caption1,caption2] +""" + + +def to_img(tokens_string, image_tokenizer): + """Converts visual tokens to images + + Given input visual tokens, we extract the indices, pass it to the decoder to get the image + """ + visual_token_pattern = r"<\|visual token (\d+)\|>" + visual_tokens = [int(match) for match in re.findall(visual_token_pattern, tokens_string)] + # We assume image is square. So if 64 tokensa are present, we reshape it to 8x8 and then pass it to decoder + dim = int(math.sqrt(len(visual_tokens))) + visual_tokens_tensor = torch.tensor(visual_tokens[: dim * dim]) + # Decoder accepts input of the following format [bs, channel_dim, h, w] + visual_tokens_tensor_reshaped = visual_tokens_tensor.reshape((dim, dim)).unsqueeze(0).unsqueeze(0) + visual_tokens_final = visual_tokens_tensor_reshaped.to(image_tokenizer._device) + img = image_tokenizer.decode(visual_tokens_final) + + # Convert from bf16 to 16 and to format [channel_dim, h, w] + image = torchvision.transforms.functional.to_pil_image(img.float().squeeze()) + return image + + +def load_prompts(cfg): + """Function to return the prompts passed into the model""" + prompts = [] + for caption in cfg.captions: + prompt = f'You are a helpful assistant. Draw a picture for the caption given by the user. USER: {caption}. ASSISTANT: ' + prompts.append(prompt) + return prompts + + +if not torch.cuda.is_available(): + raise EnvironmentError("GPU is needed for the inference") + + +@hydra_runner(config_path="conf", config_name="megatron_mm_ar_inference_image_generation") +def main(cfg) -> None: + """Main function""" + + callbacks = [] + # enable_progress_bar is True by default. If cfg.trainer.enable_progress_bar=False, CustomProgressBar is not appended to callbacks + if 'enable_progress_bar' not in cfg.trainer or cfg.trainer.enable_progress_bar: + callbacks.append(CustomProgressBar()) + # trainer required for restoring model parallel models + trainer = Trainer( + strategy=NLPDDPStrategy(timeout=datetime.timedelta(seconds=18000)), + **cfg.trainer, + callbacks=callbacks, + ) + + image_tokenizer = CausalVideoTokenizer.from_pretrained( + tokenizer_type=cfg.image_encoder, load_encoder=False, load_decoder=True, load_full_model=False + ) + + model = load_model_from_config(trainer, cfg) + model.freeze() + + # Have to turn off activations_checkpoint_method for inference + try: + model.model.language_model.encoder.activations_checkpoint_method = None + except AttributeError: + pass + + length_params: LengthParam = { + "max_length": cfg.inference.tokens_to_generate, + "min_length": cfg.inference.min_tokens_to_generate, + } + + sampling_params: SamplingParam = { + "use_greedy": cfg.inference.greedy, + "temperature": cfg.inference.temperature, + "top_k": cfg.inference.top_k, + "top_p": cfg.inference.top_p, + "repetition_penalty": cfg.inference.repetition_penalty, + "add_BOS": cfg.inference.add_BOS, + "all_probs": cfg.inference.all_probs, + "compute_logprob": cfg.inference.compute_logprob, + "end_strings": cfg.inference.end_strings, + } + + prompts = [] + with torch.no_grad(): + prompts = load_prompts(cfg) + + fp8_enabled = hasattr(model.cfg, "fp8") and (model.cfg.fp8 == True) + if fp8_enabled and len(prompts) > 0: + padded_len = round_to_mult(len(prompts), 8) + nb_paddings = padded_len - len(prompts) + if nb_paddings > 0: + nb_paddings += [''] * nb_paddings + + # First method of running text generation, call model.generate method + response = model.generate(inputs=prompts, length_params=length_params, sampling_params=sampling_params) + + if fp8_enabled: + response = remove_padded_prompts(response, nb_paddings) + + output_tokens_strings = response['sentences'] + for idx, output_token_string in enumerate(output_tokens_strings): + image = to_img(output_token_string, image_tokenizer) + image.save(os.path.join(cfg.images_output_path, f'{idx}.jpg')) + + print(f'Images saved to {cfg.images_output_path}') + + +if __name__ == '__main__': + main() # noqa pylint: disable=no-value-for-parameter diff --git a/examples/multimodal_autoregressive/megatron_mm_autoregressive_eval_vision_understanding.py b/examples/multimodal_autoregressive/megatron_mm_autoregressive_eval_vision_understanding.py new file mode 100644 index 000000000000..4aea4d9898ae --- /dev/null +++ b/examples/multimodal_autoregressive/megatron_mm_autoregressive_eval_vision_understanding.py @@ -0,0 +1,220 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import datetime + +import torch +import torchvision +from examples.nlp.language_modeling.megatron_gpt_eval import ( + RequestDataSet, + load_model_from_config, + remove_padded_prompts, + round_to_mult, +) +from omegaconf import OmegaConf +from PIL import Image +from pytorch_lightning.trainer.trainer import Trainer +from torch.utils.data import DataLoader +from transformers import AutoModel, AutoTokenizer + +# pylint: disable=line-too-long +from nemo.collections.nlp.modules.common.transformer.text_generation import LengthParam, SamplingParam +from nemo.collections.nlp.parts.nlp_overrides import CustomProgressBar, NLPDDPStrategy +from nemo.core.config import hydra_runner + +""" +This is the script to run multimodal autoregresssive text generation. + +Make sure you install tiktoken==0.6.0 + +Usage: + Assume the model has TP=1, PP=1 in the following use cases. + a. run greedy inference from a nemo file: + python megatron_mm_autoregresssive_eval.py \ + gpt_model_file=PATH_TO_MODEL \ + inference.greedy=True \ + inference.add_BOS=True \ + trainer.devices=1 \ + trainer.num_nodes=1 \ + tensor_model_parallel_size=-1 \ + pipeline_model_parallel_size=-1 \ + images_path=[image_path1,image_path2] + + b. run greedy inference from a PTL checkpoint file: + python megatron_mm_autoregresssive_eval.py \ + checkpoint_dir=PATH_TO_CHECKPOINT_FILE \ + checkpoint_name=CHECKPOINT_FILE_NAME \ + hparams_file=HPARAMS_FILE \ + inference.greedy=True \ + inference.add_BOS=True \ + trainer.devices=1 \ + trainer.num_nodes=1 \ + tensor_model_parallel_size=-1 \ + pipeline_model_parallel_size=-1 \ + images_path=[image_path1,image_path2] + + c. run top_p inference from a nemo file: + python megatron_mm_autoregresssive_eval.py \ + gpt_model_file=PATH_TO_MODEL \ + inference.greedy=False \ + inference.top_k=0 \ + inference.top_p=0.9 \ + inference.repetition_penalty=1.2 \ + inference.add_BOS=True \ + trainer.devices=1 \ + trainer.num_nodes=1 \ + tensor_model_parallel_size=-1 \ + pipeline_model_parallel_size=-1 \ + images_path=[image_path1,image_path2] + + d. If you don't need to generate tokens and need model to compute logprobs: + python megatron_mm_autoregresssive_eval.py \ + gpt_model_file=PATH_TO_MODEL \ + inference.compute_logprob=True \ + trainer.devices=1 \ + trainer.num_nodes=1 \ + tensor_model_parallel_size=-1 \ + pipeline_model_parallel_size=-1 \ + images_path=[image_path1,image_path2] +""" + +EMU_HUB = "BAAI/Emu3-Gen" +VQ_HUB = "BAAI/Emu3-VisionTokenizer" + + +def to_imgstr(image_tokens, tokenizer): + """Convert integer image tokens to visual tokens string""" + image_tokens = image_tokens.cpu().numpy().tolist() + image_token_str = [ + ['<|visual token {token_id:0>6d}|>'.format(token_id=token_id) for token_id in token_row] + for token_row in image_tokens + ] + image_row_str = ["".join(token_row) for token_row in image_token_str] + imgstr = tokenizer.eol_token.join(image_row_str) + return imgstr + + +def load_prompts(cfg, image_tokenizer, tokenizer): + """Function to generate prompts + + The prompts generated here are fed to the model. + """ + prompts = [] + text = "Please describe the image" + for image_path in cfg.images_path: + image = Image.open(image_path) + image_tensor = torchvision.transforms.functional.pil_to_tensor(image).unsqueeze(0) + image_tokens = image_tokenizer.encode(image_tensor.to(image_tokenizer.device, image_tokenizer.dtype)) + bs, h, w = image_tokens.shape + imgstr = to_imgstr(image_tokens[0], tokenizer=tokenizer) + image_prompt = ( + tokenizer.boi_token + + f'{h}*{w}' + + tokenizer.img_token + + imgstr + + tokenizer.eol_token + + tokenizer.eof_token + + tokenizer.eoi_token + ) + prompt = f'{tokenizer.bos_token}You are a helpful assistant. USER: {image_prompt}{text}. ASSISTANT:' + prompts.append(prompt) + return prompts + + +if not torch.cuda.is_available(): + raise EnvironmentError("GPU is needed for the inference") + + +@hydra_runner(config_path="conf", config_name="megatron_mm_ar_inference_vision_understanding") +def main(cfg) -> None: + """Main function""" + + callbacks = [] + # enable_progress_bar is True by default. If cfg.trainer.enable_progress_bar=False, CustomProgressBar is not appended to callbacks + if 'enable_progress_bar' not in cfg.trainer or cfg.trainer.enable_progress_bar: + callbacks.append(CustomProgressBar()) + # trainer required for restoring model parallel models + trainer = Trainer( + strategy=NLPDDPStrategy(timeout=datetime.timedelta(seconds=18000)), + **cfg.trainer, + callbacks=callbacks, + ) + + tokenizer = AutoTokenizer.from_pretrained(EMU_HUB, trust_remote_code=True) + image_tokenizer = AutoModel.from_pretrained(VQ_HUB, device_map="cuda", trust_remote_code=True).eval() + + model = load_model_from_config(trainer, cfg) + model.freeze() + + # Have to turn off activations_checkpoint_method for inference + try: + model.model.language_model.encoder.activations_checkpoint_method = None + except AttributeError: + pass + + length_params: LengthParam = { + "max_length": cfg.inference.tokens_to_generate, + "min_length": cfg.inference.min_tokens_to_generate, + } + + sampling_params: SamplingParam = { + "use_greedy": cfg.inference.greedy, + "temperature": cfg.inference.temperature, + "top_k": cfg.inference.top_k, + "top_p": cfg.inference.top_p, + "repetition_penalty": cfg.inference.repetition_penalty, + "add_BOS": cfg.inference.add_BOS, + "all_probs": cfg.inference.all_probs, + "compute_logprob": cfg.inference.compute_logprob, + "end_strings": cfg.inference.end_strings, + } + + prompts = [] + with torch.no_grad(): + prompts = load_prompts(cfg, image_tokenizer, tokenizer) + + fp8_enabled = hasattr(model.cfg, "fp8") and (model.cfg.fp8 == True) + if fp8_enabled and len(prompts) > 0: + padded_len = round_to_mult(len(prompts), 8) + nb_paddings = padded_len - len(prompts) + if nb_paddings > 0: + nb_paddings += [''] * nb_paddings + + # First method of running text generation, call model.generate method + response = model.generate(inputs=prompts, length_params=length_params, sampling_params=sampling_params) + + if fp8_enabled: + response = remove_padded_prompts(response, nb_paddings) + print("***************************") + print(response) + print("***************************") + + # Second method of running text generation, call trainer.predict [recommended] + bs = 8 if fp8_enabled else 2 + ds = RequestDataSet(prompts) + request_dl = DataLoader(dataset=ds, batch_size=bs) + config = OmegaConf.to_container(cfg.inference) + model.set_inference_config(config) + response = trainer.predict(model, request_dl) + + if fp8_enabled: + response[-1] = remove_padded_prompts(response[-1], nb_paddings) + print("***************************") + print(response) + print("***************************") + + +if __name__ == '__main__': + main() # noqa pylint: disable=no-value-for-parameter diff --git a/examples/nlp/dialogue/dialogue.py b/examples/nlp/dialogue/dialogue.py index 578895a2ad43..3f4c5581eb5a 100644 --- a/examples/nlp/dialogue/dialogue.py +++ b/examples/nlp/dialogue/dialogue.py @@ -42,7 +42,7 @@ import os -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf import DictConfig, OmegaConf from nemo.collections.nlp.models.dialogue.dialogue_gpt_classification_model import DialogueGPTClassificationModel diff --git a/examples/nlp/duplex_text_normalization/helpers.py b/examples/nlp/duplex_text_normalization/helpers.py index 6c1cfe37b90d..d9b8780fd787 100644 --- a/examples/nlp/duplex_text_normalization/helpers.py +++ b/examples/nlp/duplex_text_normalization/helpers.py @@ -14,7 +14,7 @@ import os -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf import DictConfig from nemo.collections.nlp.data.text_normalization import constants @@ -29,7 +29,7 @@ def instantiate_model_and_trainer(cfg: DictConfig, model_name: str, do_training: bool): - """ Function for instantiating a model and a trainer + """Function for instantiating a model and a trainer Args: cfg: The config used to instantiate the model and the trainer. model_name: A str indicates whether the model to be instantiated is a tagger or a decoder (i.e., model_name should be either TAGGER_MODEL or DECODER_MODEL). diff --git a/examples/nlp/entity_linking/self_alignment_pretraining.py b/examples/nlp/entity_linking/self_alignment_pretraining.py index a1ac1ac327cb..58b20f384d04 100644 --- a/examples/nlp/entity_linking/self_alignment_pretraining.py +++ b/examples/nlp/entity_linking/self_alignment_pretraining.py @@ -16,8 +16,8 @@ # Please see tutorial at Nemo/tutorials/nlp/Entity_Linking_Medical.ipynb for # more information on entity linking and self alignment pretraining. +from lightning.pytorch import Trainer from omegaconf import DictConfig, OmegaConf -from pytorch_lightning import Trainer from nemo.collections.nlp.models import EntityLinkingModel from nemo.core.config import hydra_runner diff --git a/examples/nlp/glue_benchmark/glue_benchmark.py b/examples/nlp/glue_benchmark/glue_benchmark.py index 3cb5f8e4af3e..28efb9520fbd 100644 --- a/examples/nlp/glue_benchmark/glue_benchmark.py +++ b/examples/nlp/glue_benchmark/glue_benchmark.py @@ -35,7 +35,7 @@ import os -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf import DictConfig, OmegaConf from nemo.collections.nlp.models import GLUEModel diff --git a/examples/nlp/information_retrieval/bert_dpr.py b/examples/nlp/information_retrieval/bert_dpr.py index 2d9cd962ff34..4fc791da04fd 100644 --- a/examples/nlp/information_retrieval/bert_dpr.py +++ b/examples/nlp/information_retrieval/bert_dpr.py @@ -13,7 +13,7 @@ # limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf import DictConfig, OmegaConf from nemo.collections.nlp.models import BertDPRModel diff --git a/examples/nlp/information_retrieval/bert_joint_ir.py b/examples/nlp/information_retrieval/bert_joint_ir.py index 1bb164e580d1..f95cdd04e036 100644 --- a/examples/nlp/information_retrieval/bert_joint_ir.py +++ b/examples/nlp/information_retrieval/bert_joint_ir.py @@ -13,7 +13,7 @@ # limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf import DictConfig, OmegaConf from nemo.collections.nlp.models import BertJointIRModel diff --git a/examples/nlp/information_retrieval/megatron_gpt_embedding_finetuning.py b/examples/nlp/information_retrieval/megatron_gpt_embedding_finetuning.py index e1fe28cc892f..9cb5cb5d3d19 100644 --- a/examples/nlp/information_retrieval/megatron_gpt_embedding_finetuning.py +++ b/examples/nlp/information_retrieval/megatron_gpt_embedding_finetuning.py @@ -15,8 +15,8 @@ from collections.abc import MutableMapping import torch.multiprocessing as mp +from lightning.pytorch.loggers import WandbLogger from omegaconf.omegaconf import OmegaConf -from pytorch_lightning.loggers import WandbLogger from nemo.collections.nlp.models.information_retrieval.megatron_gpt_embedding_model import MegatronGPTEmbeddingModel from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronLMPPTrainerBuilder diff --git a/examples/nlp/information_retrieval/megatron_gpt_reranker_finetuning.py b/examples/nlp/information_retrieval/megatron_gpt_reranker_finetuning.py index cf65840bb843..be89e5bf5c43 100644 --- a/examples/nlp/information_retrieval/megatron_gpt_reranker_finetuning.py +++ b/examples/nlp/information_retrieval/megatron_gpt_reranker_finetuning.py @@ -15,8 +15,8 @@ from collections.abc import MutableMapping import torch.multiprocessing as mp +from lightning.pytorch.loggers import WandbLogger from omegaconf.omegaconf import OmegaConf -from pytorch_lightning.loggers import WandbLogger from nemo.collections.nlp.models.information_retrieval.megatron_gpt_reranker_model import MegatronGPTRerankerModel from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronLMPPTrainerBuilder diff --git a/examples/nlp/intent_slot_classification/intent_slot_classification.py b/examples/nlp/intent_slot_classification/intent_slot_classification.py index a112ea7785f5..2025f48f330f 100644 --- a/examples/nlp/intent_slot_classification/intent_slot_classification.py +++ b/examples/nlp/intent_slot_classification/intent_slot_classification.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf import DictConfig, OmegaConf from nemo.collections.nlp.models import IntentSlotClassificationModel diff --git a/examples/nlp/intent_slot_classification/multi_label_intent_slot_classification.py b/examples/nlp/intent_slot_classification/multi_label_intent_slot_classification.py index 2441885e2ed2..232aa7d4d230 100644 --- a/examples/nlp/intent_slot_classification/multi_label_intent_slot_classification.py +++ b/examples/nlp/intent_slot_classification/multi_label_intent_slot_classification.py @@ -27,7 +27,7 @@ """ -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf import DictConfig, OmegaConf from nemo.collections.nlp.models import MultiLabelIntentSlotClassificationModel diff --git a/examples/nlp/language_modeling/bert_pretraining.py b/examples/nlp/language_modeling/bert_pretraining.py index 75d0a1072e69..7cff43f7fc73 100644 --- a/examples/nlp/language_modeling/bert_pretraining.py +++ b/examples/nlp/language_modeling/bert_pretraining.py @@ -13,9 +13,9 @@ # limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl +from lightning.pytorch.strategies import DDPStrategy from omegaconf import DictConfig, OmegaConf -from pytorch_lightning.strategies import DDPStrategy from nemo.collections.nlp.models.language_modeling import BERTLMModel from nemo.core.config import hydra_runner diff --git a/examples/nlp/language_modeling/mamba_change_num_partition.py b/examples/nlp/language_modeling/mamba_change_num_partition.py index ced2b43cd312..349543de8e59 100644 --- a/examples/nlp/language_modeling/mamba_change_num_partition.py +++ b/examples/nlp/language_modeling/mamba_change_num_partition.py @@ -19,8 +19,8 @@ from argparse import ArgumentParser import torch +from lightning.pytorch import Trainer from omegaconf import open_dict -from pytorch_lightning import Trainer from nemo.collections.nlp.models.language_modeling.megatron_mamba_model import MegatronMambaModel from nemo.collections.nlp.parts.nlp_overrides import ( diff --git a/examples/nlp/language_modeling/megatron_bart_pretraining.py b/examples/nlp/language_modeling/megatron_bart_pretraining.py index e45b5e04ca45..a6dd6f183d72 100644 --- a/examples/nlp/language_modeling/megatron_bart_pretraining.py +++ b/examples/nlp/language_modeling/megatron_bart_pretraining.py @@ -13,11 +13,11 @@ # limitations under the License. +from lightning.pytorch import Trainer +from lightning.pytorch.callbacks import ModelSummary +from lightning.pytorch.plugins.environments import TorchElasticEnvironment +from lightning.pytorch.trainer.connectors.checkpoint_connector import _CheckpointConnector from omegaconf.omegaconf import OmegaConf, open_dict -from pytorch_lightning import Trainer -from pytorch_lightning.callbacks import ModelSummary -from pytorch_lightning.plugins.environments import TorchElasticEnvironment -from pytorch_lightning.trainer.connectors.checkpoint_connector import _CheckpointConnector from nemo.collections.nlp.models.language_modeling.megatron_bart_model import MegatronBARTModel from nemo.collections.nlp.parts.nlp_overrides import ( @@ -48,7 +48,7 @@ def main(cfg) -> None: scaler = None if cfg.trainer.precision in [16, '16', '16-mixed']: scaler = GradScaler( - init_scale=cfg.model.get('native_amp_init_scale', 2 ** 32), + init_scale=cfg.model.get('native_amp_init_scale', 2**32), growth_interval=cfg.model.get('native_amp_growth_interval', 1000), hysteresis=cfg.model.get('hysteresis', 2), ) diff --git a/examples/nlp/language_modeling/megatron_change_num_partitions.py b/examples/nlp/language_modeling/megatron_change_num_partitions.py index c035346e3bf1..49d1ef0dcb57 100644 --- a/examples/nlp/language_modeling/megatron_change_num_partitions.py +++ b/examples/nlp/language_modeling/megatron_change_num_partitions.py @@ -21,8 +21,8 @@ import torch import torch.nn as nn +from lightning.pytorch import Trainer from omegaconf import OmegaConf, open_dict -from pytorch_lightning import Trainer from nemo.collections.nlp.parts.nlp_overrides import ( NEMO_MEGATRON_MODEL_PARALLEL_APPSTATE_OVERRIDE, @@ -922,7 +922,7 @@ def main(): scaler = None if precision in [16, '16', '16-mixed']: scaler = GradScaler( - init_scale=tmp_cfg.get('native_amp_init_scale', 2 ** 32), + init_scale=tmp_cfg.get('native_amp_init_scale', 2**32), growth_interval=tmp_cfg.get('native_amp_growth_interval', 1000), hysteresis=tmp_cfg.get('hysteresis', 2), ) @@ -943,7 +943,10 @@ def main(): if tp_size < 0 or pp_size < 0: logging.info(f"Loading model config from {args.model_file} to get TP and PP size") model_config_internal = cls.restore_from( - restore_path=args.model_file, trainer=trainer, map_location=torch.device("cpu"), return_config=True, + restore_path=args.model_file, + trainer=trainer, + map_location=torch.device("cpu"), + return_config=True, ) tp_size = model_config_internal.get('tensor_model_parallel_size', 1) @@ -1137,7 +1140,9 @@ def main(): else: model = cls.load_from_checkpoint( - checkpoint_path=checkpoint_path, trainer=trainer, map_location=torch.device("cpu"), + checkpoint_path=checkpoint_path, + trainer=trainer, + map_location=torch.device("cpu"), ) model.freeze() diff --git a/examples/nlp/language_modeling/megatron_ckpt_to_nemo.py b/examples/nlp/language_modeling/megatron_ckpt_to_nemo.py index c81119489582..4b9fab987dc7 100644 --- a/examples/nlp/language_modeling/megatron_ckpt_to_nemo.py +++ b/examples/nlp/language_modeling/megatron_ckpt_to_nemo.py @@ -32,10 +32,10 @@ import torch from genericpath import isdir +from lightning.pytorch.plugins.environments import TorchElasticEnvironment +from lightning.pytorch.trainer.trainer import Trainer from megatron.core import parallel_state from omegaconf import OmegaConf, open_dict -from pytorch_lightning.plugins.environments import TorchElasticEnvironment -from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.nlp.models.language_modeling.megatron_bart_model import MegatronBARTModel from nemo.collections.nlp.models.language_modeling.megatron_bert_model import MegatronBertModel @@ -112,6 +112,11 @@ def get_args(): choices=['32-true', '16-mixed', 'bf16-mixed'], help="Precision value for the trainer that matches with precision of the ckpt", ) + parser.add_argument( + "--convert_mlm", + action="store_true", + help="Use this flag to convert megatron-lm checkpoints.", + ) args = parser.parse_args() return args @@ -195,7 +200,9 @@ def convert(local_rank, rank, world_size, args): ) if args.model_type == 'gpt': - model = MegatronGPTModel.load_from_checkpoint(checkpoint_path, hparams_file=args.hparams_file, trainer=trainer) + model = MegatronGPTModel.load_from_checkpoint( + checkpoint_path, hparams_file=args.hparams_file, trainer=trainer, load_mlm=args.convert_mlm + ) elif args.model_type == 'sft': model = MegatronGPTSFTModel.load_from_checkpoint( checkpoint_path, hparams_file=args.hparams_file, trainer=trainer diff --git a/examples/nlp/language_modeling/megatron_export.py b/examples/nlp/language_modeling/megatron_export.py index bf9157884bfc..b511a415d9b1 100644 --- a/examples/nlp/language_modeling/megatron_export.py +++ b/examples/nlp/language_modeling/megatron_export.py @@ -28,8 +28,8 @@ import os +from lightning.pytorch import Trainer from omegaconf import OmegaConf, open_dict -from pytorch_lightning import Trainer from nemo.collections.nlp.models.language_modeling.megatron_bart_model import MegatronBARTModel from nemo.collections.nlp.models.language_modeling.megatron_bert_model import MegatronBertModel diff --git a/examples/nlp/language_modeling/megatron_gpt_distillation.py b/examples/nlp/language_modeling/megatron_gpt_distillation.py index dc8614be23b2..c00470c5c81e 100644 --- a/examples/nlp/language_modeling/megatron_gpt_distillation.py +++ b/examples/nlp/language_modeling/megatron_gpt_distillation.py @@ -19,8 +19,8 @@ import modelopt.torch.distill as mtd import modelopt.torch.opt as mto import torch.multiprocessing as mp +from lightning.pytorch.trainer.trainer import Trainer from omegaconf import DictConfig, OmegaConf, open_dict -from pytorch_lightning.trainer.trainer import Trainer try: from megatron.core import parallel_state, tensor_parallel diff --git a/examples/nlp/language_modeling/megatron_gpt_eval.py b/examples/nlp/language_modeling/megatron_gpt_eval.py index b9b0d2973094..4dbbee78e898 100644 --- a/examples/nlp/language_modeling/megatron_gpt_eval.py +++ b/examples/nlp/language_modeling/megatron_gpt_eval.py @@ -20,8 +20,8 @@ from functools import partial import torch +from lightning.pytorch.trainer.trainer import Trainer from omegaconf import OmegaConf, open_dict -from pytorch_lightning.trainer.trainer import Trainer from torch.utils.data import DataLoader, Dataset from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel diff --git a/examples/nlp/language_modeling/megatron_gpt_mcore_batch_eval.py b/examples/nlp/language_modeling/megatron_gpt_mcore_batch_eval.py index 988a5f8588ff..ceb32d75f495 100644 --- a/examples/nlp/language_modeling/megatron_gpt_mcore_batch_eval.py +++ b/examples/nlp/language_modeling/megatron_gpt_mcore_batch_eval.py @@ -16,6 +16,7 @@ import os from argparse import Namespace +from lightning.pytorch.trainer.trainer import Trainer from megatron.core.inference.common_inference_params import CommonInferenceParams from megatron.core.inference.engines.mcore_engine import MCoreEngine from megatron.core.inference.inference_model_wrappers.gpt.gpt_inference_wrapper import GPTInferenceWrapper @@ -23,7 +24,6 @@ SimpleTextGenerationController, ) from omegaconf import OmegaConf, open_dict -from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel from nemo.collections.nlp.modules.common.megatron.megatron_init import fake_initialize_model_parallel diff --git a/examples/nlp/language_modeling/megatron_gpt_prune.py b/examples/nlp/language_modeling/megatron_gpt_prune.py index de12b861a1c0..44992873f362 100644 --- a/examples/nlp/language_modeling/megatron_gpt_prune.py +++ b/examples/nlp/language_modeling/megatron_gpt_prune.py @@ -16,8 +16,8 @@ import torch import torch.multiprocessing as mp from datasets import load_dataset +from lightning.pytorch.trainer.trainer import Trainer from omegaconf import OmegaConf -from pytorch_lightning.trainer.trainer import Trainer from tqdm import tqdm from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel diff --git a/examples/nlp/language_modeling/megatron_gpt_ptq.py b/examples/nlp/language_modeling/megatron_gpt_ptq.py index e41becc2d8e0..0ac0822c5fbe 100644 --- a/examples/nlp/language_modeling/megatron_gpt_ptq.py +++ b/examples/nlp/language_modeling/megatron_gpt_ptq.py @@ -15,8 +15,8 @@ import torch import torch.multiprocessing as mp from datasets import load_dataset +from lightning.pytorch.trainer.trainer import Trainer from omegaconf import OmegaConf -from pytorch_lightning.trainer.trainer import Trainer from tqdm import tqdm from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel diff --git a/examples/nlp/language_modeling/megatron_gpt_test.py b/examples/nlp/language_modeling/megatron_gpt_test.py index 62a1d40dbaed..03bc6735e891 100644 --- a/examples/nlp/language_modeling/megatron_gpt_test.py +++ b/examples/nlp/language_modeling/megatron_gpt_test.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from lightning.pytorch import Trainer from omegaconf.omegaconf import OmegaConf -from pytorch_lightning import Trainer from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel from nemo.collections.nlp.modules.common.megatron.megatron_utils import compute_model_parallel_rank @@ -38,7 +38,7 @@ def main(cfg) -> None: trainer = Trainer( plugins=[ NLPMixedPrecisionPlugin( - init_scale=cfg.model.get('native_amp_init_scale', 2 ** 32), + init_scale=cfg.model.get('native_amp_init_scale', 2**32), growth_interval=cfg.model.get('native_amp_growth_interval', 1000), ), ], @@ -46,7 +46,13 @@ def main(cfg) -> None: **cfg.trainer, ) elif cfg.trainer.precision in ['bf16', 'bf16-mixed']: - trainer = Trainer(plugins=[NLPNativeBfloat16PrecisionPlugin(),], strategy=NLPDDPStrategy(), **cfg.trainer,) + trainer = Trainer( + plugins=[ + NLPNativeBfloat16PrecisionPlugin(), + ], + strategy=NLPDDPStrategy(), + **cfg.trainer, + ) else: trainer = Trainer(plugins=[NLPPrecisionPlugin()], strategy=NLPDDPStrategy(), **cfg.trainer) @@ -55,7 +61,9 @@ def main(cfg) -> None: app_state.model_parallel_rank = compute_model_parallel_rank(trainer.local_rank, app_state.model_parallel_size) model = MegatronGPTModel.restore_from( - cfg.restore_from_path, trainer=trainer, save_restore_connector=NLPSaveRestoreConnector(), + cfg.restore_from_path, + trainer=trainer, + save_restore_connector=NLPSaveRestoreConnector(), ) # Note: most nemo models must have the data paths configured before instantiating the model diff --git a/examples/nlp/language_modeling/megatron_gpt_validate.py b/examples/nlp/language_modeling/megatron_gpt_validate.py index b5a61e627a14..fa0abb89421c 100644 --- a/examples/nlp/language_modeling/megatron_gpt_validate.py +++ b/examples/nlp/language_modeling/megatron_gpt_validate.py @@ -15,8 +15,8 @@ import os import tempfile +from lightning.pytorch.trainer.trainer import Trainer from omegaconf import OmegaConf, open_dict -from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel from nemo.collections.nlp.modules.common.megatron.megatron_init import fake_initialize_model_parallel @@ -140,7 +140,9 @@ def main(cfg) -> None: with tempfile.NamedTemporaryFile(suffix='.yaml') as f: OmegaConf.save(config=pretrained_cfg, f=f.name) model = MegatronGPTModel.load_from_checkpoint( - checkpoint_path=checkpoint_path, trainer=trainer, hparams_file=f.name, + checkpoint_path=checkpoint_path, + trainer=trainer, + hparams_file=f.name, ) else: raise ValueError("need at least a nemo file or checkpoint dir") diff --git a/examples/nlp/language_modeling/megatron_lm_ckpt_to_nemo.py b/examples/nlp/language_modeling/megatron_lm_ckpt_to_nemo.py index 72252a03d5be..64ba2a51bb71 100644 --- a/examples/nlp/language_modeling/megatron_lm_ckpt_to_nemo.py +++ b/examples/nlp/language_modeling/megatron_lm_ckpt_to_nemo.py @@ -42,12 +42,12 @@ from typing import Any, Optional import torch -from lightning_fabric.utilities.cloud_io import _load as pl_load +from lightning.fabric.utilities.cloud_io import _load as pl_load +from lightning.pytorch.core.saving import _load_state as ptl_load_state +from lightning.pytorch.core.saving import load_hparams_from_tags_csv, load_hparams_from_yaml +from lightning.pytorch.trainer.trainer import Trainer +from lightning.pytorch.utilities.migration import pl_legacy_patch from megatron.core import parallel_state -from pytorch_lightning.core.saving import _load_state as ptl_load_state -from pytorch_lightning.core.saving import load_hparams_from_tags_csv, load_hparams_from_yaml -from pytorch_lightning.trainer.trainer import Trainer -from pytorch_lightning.utilities.migration import pl_legacy_patch from nemo.collections.nlp.models.language_modeling.megatron_bert_model import MegatronBertModel from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel diff --git a/examples/nlp/language_modeling/megatron_mamba_eval.py b/examples/nlp/language_modeling/megatron_mamba_eval.py index ed12e4b904ac..ba000e6bef63 100644 --- a/examples/nlp/language_modeling/megatron_mamba_eval.py +++ b/examples/nlp/language_modeling/megatron_mamba_eval.py @@ -20,8 +20,8 @@ from functools import partial import torch +from lightning.pytorch.trainer.trainer import Trainer from omegaconf import OmegaConf, open_dict -from pytorch_lightning.trainer.trainer import Trainer from torch.utils.data import DataLoader, Dataset from nemo.collections.nlp.models.language_modeling.megatron_mamba_model import MegatronMambaModel diff --git a/examples/nlp/language_modeling/megatron_retro_cal_shape.py b/examples/nlp/language_modeling/megatron_retro_cal_shape.py index a57a927d2a36..f790d9471964 100644 --- a/examples/nlp/language_modeling/megatron_retro_cal_shape.py +++ b/examples/nlp/language_modeling/megatron_retro_cal_shape.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from lightning.pytorch import Trainer +from lightning.pytorch.plugins.environments import TorchElasticEnvironment +from lightning.pytorch.plugins.precision import MixedPrecisionPlugin from omegaconf.omegaconf import OmegaConf, open_dict -from pytorch_lightning import Trainer -from pytorch_lightning.plugins.environments import TorchElasticEnvironment -from pytorch_lightning.plugins.precision import MixedPrecisionPlugin from nemo.collections.nlp.models.language_modeling.megatron_retrieval_model import MegatronRetrievalModel from nemo.collections.nlp.modules.common.megatron.mup.shape import make_base_shapes @@ -46,7 +46,7 @@ def main(cfg) -> None: scaler = None if cfg.trainer.precision in [16, '16', '16-mixed']: scaler = GradScaler( - init_scale=cfg.model.get('native_amp_init_scale', 2 ** 32), + init_scale=cfg.model.get('native_amp_init_scale', 2**32), growth_interval=cfg.model.get('native_amp_growth_interval', 1000), hysteresis=cfg.model.get('hysteresis', 2), ) diff --git a/examples/nlp/language_modeling/megatron_retro_eval.py b/examples/nlp/language_modeling/megatron_retro_eval.py index 89e3fe9c3ddb..ac946b2adf42 100644 --- a/examples/nlp/language_modeling/megatron_retro_eval.py +++ b/examples/nlp/language_modeling/megatron_retro_eval.py @@ -16,8 +16,8 @@ import os import torch +from lightning.pytorch.trainer.trainer import Trainer from omegaconf import OmegaConf -from pytorch_lightning.trainer.trainer import Trainer from torch.utils.data import DataLoader, Dataset from nemo.collections.nlp.models.language_modeling.megatron_retro_model import MegatronRetroModel @@ -60,7 +60,9 @@ def __init__(self, sentences, neighbors): self.sentences = sentences self.neighbors = neighbors - def __len__(self,): + def __len__( + self, + ): return len(self.sentences) def __getitem__(self, idx): diff --git a/examples/nlp/language_modeling/megatron_retro_eval_legacy.py b/examples/nlp/language_modeling/megatron_retro_eval_legacy.py index 69222acedd34..c51a8f536cc1 100644 --- a/examples/nlp/language_modeling/megatron_retro_eval_legacy.py +++ b/examples/nlp/language_modeling/megatron_retro_eval_legacy.py @@ -15,8 +15,8 @@ import os from examples.nlp.language_modeling.megatron_gpt_eval import RequestDataSet +from lightning.pytorch import Trainer from omegaconf.omegaconf import OmegaConf, open_dict -from pytorch_lightning import Trainer from torch.utils.data import DataLoader from nemo.collections.nlp.models.language_modeling.megatron_retrieval_model import MegatronRetrievalModel @@ -69,7 +69,10 @@ def main(cfg) -> None: save_restore_connector.model_extracted_dir = model_path model_cfg = MegatronRetrievalModel.restore_from( - model_path, trainer=trainer, return_config=True, save_restore_connector=save_restore_connector, + model_path, + trainer=trainer, + return_config=True, + save_restore_connector=save_restore_connector, ) with open_dict(model_cfg): @@ -89,7 +92,10 @@ def main(cfg) -> None: cfg.pipeline_model_parallel_split_rank = model_cfg.get('pipeline_model_parallel_split_rank', 0) model = MegatronRetrievalModel.restore_from( - model_path, trainer=trainer, save_restore_connector=save_restore_connector, override_config_path=model_cfg, + model_path, + trainer=trainer, + save_restore_connector=save_restore_connector, + override_config_path=model_cfg, ) length_params: LengthParam = { diff --git a/examples/nlp/language_modeling/megatron_retro_fine_tune.py b/examples/nlp/language_modeling/megatron_retro_fine_tune.py index 3fcaec156d9c..153a4b581135 100644 --- a/examples/nlp/language_modeling/megatron_retro_fine_tune.py +++ b/examples/nlp/language_modeling/megatron_retro_fine_tune.py @@ -15,12 +15,12 @@ import datetime import os +from lightning.pytorch import Trainer +from lightning.pytorch.callbacks.timer import Timer +from lightning.pytorch.plugins.environments import TorchElasticEnvironment +from lightning.pytorch.plugins.precision import MixedPrecisionPlugin +from lightning.pytorch.trainer.connectors.checkpoint_connector import _CheckpointConnector from omegaconf.omegaconf import OmegaConf, open_dict -from pytorch_lightning import Trainer -from pytorch_lightning.callbacks.timer import Timer -from pytorch_lightning.plugins.environments import TorchElasticEnvironment -from pytorch_lightning.plugins.precision import MixedPrecisionPlugin -from pytorch_lightning.trainer.connectors.checkpoint_connector import _CheckpointConnector from nemo.collections.nlp.models.language_modeling.megatron_retro_fine_tune_model import MegatronRetroFinetuneModel from nemo.collections.nlp.parts.nlp_overrides import ( @@ -87,7 +87,7 @@ def main(cfg) -> None: scaler = None if cfg.trainer.precision in [16, '16', '16-mixed']: scaler = GradScaler( - init_scale=cfg.model.get('native_amp_init_scale', 2 ** 32), + init_scale=cfg.model.get('native_amp_init_scale', 2**32), growth_interval=cfg.model.get('native_amp_growth_interval', 1000), hysteresis=cfg.model.get('hysteresis', 2), ) @@ -118,7 +118,9 @@ def main(cfg) -> None: # Override timer callback to a stateless one for idx, callback in enumerate(trainer.callbacks): if isinstance(callback, Timer): - trainer.callbacks[idx] = StatelessTimer(cfg.trainer.max_time,) + trainer.callbacks[idx] = StatelessTimer( + cfg.trainer.max_time, + ) # load existing or init new soft prompt GPT model if cfg.model.get("restore_path", None): diff --git a/examples/nlp/language_modeling/megatron_retro_mutransfer_pretrain.py b/examples/nlp/language_modeling/megatron_retro_mutransfer_pretrain.py index af6e22035def..775b75680ee9 100644 --- a/examples/nlp/language_modeling/megatron_retro_mutransfer_pretrain.py +++ b/examples/nlp/language_modeling/megatron_retro_mutransfer_pretrain.py @@ -12,11 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +from lightning.pytorch import Trainer +from lightning.pytorch.plugins.environments import TorchElasticEnvironment +from lightning.pytorch.plugins.precision import MixedPrecisionPlugin +from lightning.pytorch.trainer.connectors.checkpoint_connector import _CheckpointConnector from omegaconf.omegaconf import OmegaConf, open_dict -from pytorch_lightning import Trainer -from pytorch_lightning.plugins.environments import TorchElasticEnvironment -from pytorch_lightning.plugins.precision import MixedPrecisionPlugin -from pytorch_lightning.trainer.connectors.checkpoint_connector import _CheckpointConnector from nemo.collections.nlp.models.language_modeling.megatron_retrieval_model import MegatronRetrievalModel from nemo.collections.nlp.modules.common.megatron.mup.optim import MuAdam, MuAdamW @@ -52,7 +52,7 @@ def main(cfg) -> None: scaler = None if cfg.trainer.precision in [16, '16', '16-mixed']: scaler = GradScaler( - init_scale=cfg.model.get('native_amp_init_scale', 2 ** 32), + init_scale=cfg.model.get('native_amp_init_scale', 2**32), growth_interval=cfg.model.get('native_amp_growth_interval', 1000), hysteresis=cfg.model.get('hysteresis', 2), ) diff --git a/examples/nlp/language_modeling/megatron_retro_pretraining_legacy.py b/examples/nlp/language_modeling/megatron_retro_pretraining_legacy.py index 4653222b3438..298deafabc1c 100644 --- a/examples/nlp/language_modeling/megatron_retro_pretraining_legacy.py +++ b/examples/nlp/language_modeling/megatron_retro_pretraining_legacy.py @@ -14,11 +14,11 @@ import os +from lightning.pytorch import Trainer +from lightning.pytorch.plugins.environments import TorchElasticEnvironment +from lightning.pytorch.plugins.precision import MixedPrecisionPlugin +from lightning.pytorch.trainer.connectors.checkpoint_connector import _CheckpointConnector from omegaconf.omegaconf import OmegaConf, open_dict -from pytorch_lightning import Trainer -from pytorch_lightning.plugins.environments import TorchElasticEnvironment -from pytorch_lightning.plugins.precision import MixedPrecisionPlugin -from pytorch_lightning.trainer.connectors.checkpoint_connector import _CheckpointConnector from nemo.collections.nlp.models.language_modeling.megatron_retrieval_model import MegatronRetrievalModel from nemo.collections.nlp.modules.common.megatron.megatron_init import initialize_model_parallel_for_nemo @@ -51,7 +51,7 @@ def main(cfg) -> None: scaler = None if cfg.trainer.precision in [16, '16', '16-mixed']: scaler = GradScaler( - init_scale=cfg.model.get('native_amp_init_scale', 2 ** 32), + init_scale=cfg.model.get('native_amp_init_scale', 2**32), growth_interval=cfg.model.get('native_amp_growth_interval', 1000), hysteresis=cfg.model.get('hysteresis', 2), ) diff --git a/examples/nlp/language_modeling/megatron_retro_qatask_eval.py b/examples/nlp/language_modeling/megatron_retro_qatask_eval.py index b99bcafbab02..4e47157d5150 100644 --- a/examples/nlp/language_modeling/megatron_retro_qatask_eval.py +++ b/examples/nlp/language_modeling/megatron_retro_qatask_eval.py @@ -17,8 +17,8 @@ import os import torch +from lightning.pytorch.trainer.trainer import Trainer from omegaconf import OmegaConf -from pytorch_lightning.trainer.trainer import Trainer from torch.utils.data import DataLoader, Dataset from nemo.collections.nlp.data.question_answering.input_example.qa_input_example import QAExample @@ -63,7 +63,9 @@ def __init__(self, sentences, neighbors): self.sentences = sentences self.neighbors = neighbors - def __len__(self,): + def __len__( + self, + ): return len(self.sentences) def __getitem__(self, idx): diff --git a/examples/nlp/language_modeling/megatron_t5_eval.py b/examples/nlp/language_modeling/megatron_t5_eval.py index 0b6ea54b6b99..57b48134101f 100644 --- a/examples/nlp/language_modeling/megatron_t5_eval.py +++ b/examples/nlp/language_modeling/megatron_t5_eval.py @@ -17,8 +17,8 @@ from argparse import ArgumentParser import torch +from lightning.pytorch.trainer.trainer import Trainer from omegaconf.omegaconf import OmegaConf, open_dict -from pytorch_lightning.trainer.trainer import Trainer from torch.utils.data import DataLoader from nemo.collections.nlp.data.language_modeling.megatron.request_dataset import T5RequestDataset @@ -40,13 +40,22 @@ def main(): "--tokens_to_generate", type=int, default="16", required=False, help="How many tokens to add to prompt" ) parser.add_argument( - "--tensor_model_parallel_size", type=int, default=-1, required=False, + "--tensor_model_parallel_size", + type=int, + default=-1, + required=False, ) parser.add_argument( - "--pipeline_model_parallel_size", type=int, default=-1, required=False, + "--pipeline_model_parallel_size", + type=int, + default=-1, + required=False, ) parser.add_argument( - "--pipeline_model_parallel_split_rank", type=int, default=-1, required=False, + "--pipeline_model_parallel_split_rank", + type=int, + default=-1, + required=False, ) parser.add_argument("--precision", default="16", type=str, help="PyTorch Lightning Trainer precision flag") parser.add_argument("--decoder_starts_with_pad", action="store_true", help="Decoder starts with pad token") diff --git a/examples/nlp/language_modeling/megatron_t5_lm_adaptation_finetune.py b/examples/nlp/language_modeling/megatron_t5_lm_adaptation_finetune.py index 9e392d913171..4137213023ee 100644 --- a/examples/nlp/language_modeling/megatron_t5_lm_adaptation_finetune.py +++ b/examples/nlp/language_modeling/megatron_t5_lm_adaptation_finetune.py @@ -13,11 +13,11 @@ # limitations under the License. +from lightning.pytorch import Trainer +from lightning.pytorch.callbacks import ModelSummary +from lightning.pytorch.plugins.environments import TorchElasticEnvironment +from lightning.pytorch.trainer.connectors.checkpoint_connector import _CheckpointConnector from omegaconf.omegaconf import OmegaConf, open_dict -from pytorch_lightning import Trainer -from pytorch_lightning.callbacks import ModelSummary -from pytorch_lightning.plugins.environments import TorchElasticEnvironment -from pytorch_lightning.trainer.connectors.checkpoint_connector import _CheckpointConnector from nemo.collections.nlp.models.language_modeling.megatron_t5_model import MegatronT5Model from nemo.collections.nlp.parts.nlp_overrides import ( @@ -49,7 +49,7 @@ def main(cfg) -> None: scaler = None if cfg.trainer.precision in [16, '16', '16-mixed']: scaler = GradScaler( - init_scale=cfg.model.get('native_amp_init_scale', 2 ** 32), + init_scale=cfg.model.get('native_amp_init_scale', 2**32), growth_interval=cfg.model.get('native_amp_growth_interval', 1000), hysteresis=cfg.model.get('hysteresis', 2), ) diff --git a/examples/nlp/language_modeling/megatron_t5_seq2seq_eval.py b/examples/nlp/language_modeling/megatron_t5_seq2seq_eval.py index ba8ea6492da3..ae6e1744395d 100644 --- a/examples/nlp/language_modeling/megatron_t5_seq2seq_eval.py +++ b/examples/nlp/language_modeling/megatron_t5_seq2seq_eval.py @@ -12,11 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +from lightning.pytorch import Trainer +from lightning.pytorch.plugins.environments import TorchElasticEnvironment +from lightning.pytorch.plugins.precision import MixedPrecisionPlugin from megatron_t5_seq2seq_finetune import load_from_checkpoint_dir, load_from_nemo, validate_checkpoint_loading_args from omegaconf.omegaconf import OmegaConf, open_dict -from pytorch_lightning import Trainer -from pytorch_lightning.plugins.environments import TorchElasticEnvironment -from pytorch_lightning.plugins.precision import MixedPrecisionPlugin from nemo.collections.nlp.models.language_modeling.megatron_glue_model import MegatronT5GLUEModel from nemo.collections.nlp.models.language_modeling.megatron_t0_model import MegatronT0Model @@ -82,7 +82,7 @@ def main(cfg) -> None: scaler = None if cfg.trainer.precision in [16, '16', '16-mixed']: scaler = GradScaler( - init_scale=cfg.model.get('native_amp_init_scale', 2 ** 32), + init_scale=cfg.model.get('native_amp_init_scale', 2**32), growth_interval=cfg.model.get('native_amp_growth_interval', 1000), hysteresis=cfg.model.get('hysteresis', 2), ) diff --git a/examples/nlp/language_modeling/megatron_t5_seq2seq_finetune.py b/examples/nlp/language_modeling/megatron_t5_seq2seq_finetune.py index 2409e99ad951..5f63289be27a 100644 --- a/examples/nlp/language_modeling/megatron_t5_seq2seq_finetune.py +++ b/examples/nlp/language_modeling/megatron_t5_seq2seq_finetune.py @@ -16,10 +16,10 @@ import tempfile import torch.multiprocessing as mp +from lightning.pytorch import Trainer +from lightning.pytorch.plugins.environments import TorchElasticEnvironment +from lightning.pytorch.trainer.connectors.checkpoint_connector import _CheckpointConnector from omegaconf.omegaconf import OmegaConf, open_dict -from pytorch_lightning import Trainer -from pytorch_lightning.plugins.environments import TorchElasticEnvironment -from pytorch_lightning.trainer.connectors.checkpoint_connector import _CheckpointConnector from nemo.collections.nlp.models.language_modeling.megatron_glue_model import MegatronT5GLUEModel from nemo.collections.nlp.models.language_modeling.megatron_t0_model import MegatronT0Model diff --git a/examples/nlp/language_modeling/transformer_lm.py b/examples/nlp/language_modeling/transformer_lm.py index caaa0e0d2935..3e97e28bb35e 100644 --- a/examples/nlp/language_modeling/transformer_lm.py +++ b/examples/nlp/language_modeling/transformer_lm.py @@ -13,7 +13,7 @@ # limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf import DictConfig, OmegaConf from nemo.collections.nlp.models.language_modeling import TransformerLMModel diff --git a/examples/nlp/language_modeling/upcycle_dense_to_moe.py b/examples/nlp/language_modeling/upcycle_dense_to_moe.py index a1f4b6000b6f..f4a5fc017d97 100644 --- a/examples/nlp/language_modeling/upcycle_dense_to_moe.py +++ b/examples/nlp/language_modeling/upcycle_dense_to_moe.py @@ -26,7 +26,7 @@ import torch import torch.nn -from pytorch_lightning.trainer.trainer import Trainer +from lightning.pytorch.trainer.trainer import Trainer from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy, NLPSaveRestoreConnector diff --git a/examples/nlp/machine_translation/enc_dec_nmt-bottleneck.py b/examples/nlp/machine_translation/enc_dec_nmt-bottleneck.py index b1743e03188e..58c948f11458 100644 --- a/examples/nlp/machine_translation/enc_dec_nmt-bottleneck.py +++ b/examples/nlp/machine_translation/enc_dec_nmt-bottleneck.py @@ -15,8 +15,8 @@ from dataclasses import dataclass from typing import Optional +from lightning.pytorch import Trainer from omegaconf import OmegaConf -from pytorch_lightning import Trainer from nemo.collections.nlp.data.machine_translation.preproc_mt_data import MTDataPreproc from nemo.collections.nlp.models.machine_translation.mt_enc_dec_bottleneck_model import MTBottleneckModel @@ -29,7 +29,6 @@ from nemo.utils.config_utils import update_model_config from nemo.utils.exp_manager import ExpManagerConfig, exp_manager - """ Usage: 1. If you need to start docker and install NeMo, otherwise skip this step: diff --git a/examples/nlp/machine_translation/enc_dec_nmt.py b/examples/nlp/machine_translation/enc_dec_nmt.py index 57b9f84c39ce..b901ba28a4db 100644 --- a/examples/nlp/machine_translation/enc_dec_nmt.py +++ b/examples/nlp/machine_translation/enc_dec_nmt.py @@ -15,8 +15,8 @@ from dataclasses import dataclass from typing import Optional +from lightning.pytorch import Trainer from omegaconf import OmegaConf -from pytorch_lightning import Trainer from nemo.collections.nlp.data.machine_translation.preproc_mt_data import MTDataPreproc from nemo.collections.nlp.models.machine_translation.mt_enc_dec_config import MTEncDecModelConfig @@ -29,7 +29,6 @@ from nemo.utils.config_utils import update_model_config from nemo.utils.exp_manager import ExpManagerConfig, exp_manager - """ Usage: 1. If you need to start docker and install NeMo, otherwise skip this step: diff --git a/examples/nlp/machine_translation/enc_dec_nmt_finetune.py b/examples/nlp/machine_translation/enc_dec_nmt_finetune.py index 16a635d09dee..688461a7b491 100644 --- a/examples/nlp/machine_translation/enc_dec_nmt_finetune.py +++ b/examples/nlp/machine_translation/enc_dec_nmt_finetune.py @@ -15,9 +15,9 @@ from dataclasses import dataclass from typing import Optional +from lightning.pytorch import Trainer from omegaconf import OmegaConf from omegaconf.omegaconf import MISSING -from pytorch_lightning import Trainer from nemo.collections.nlp.models.machine_translation.mt_enc_dec_config import MTEncDecModelConfig from nemo.collections.nlp.models.machine_translation.mt_enc_dec_model import MTEncDecModel @@ -29,7 +29,6 @@ from nemo.utils.config_utils import update_model_config from nemo.utils.exp_manager import ExpManagerConfig, exp_manager - """ Usage: python enc_dec_nmt_finetune.py \ diff --git a/examples/nlp/machine_translation/megatron_nmt_training.py b/examples/nlp/machine_translation/megatron_nmt_training.py index 7946500f92e9..5ff70a7a863c 100644 --- a/examples/nlp/machine_translation/megatron_nmt_training.py +++ b/examples/nlp/machine_translation/megatron_nmt_training.py @@ -14,11 +14,11 @@ import torch.multiprocessing as mp +from lightning.pytorch import Trainer +from lightning.pytorch.callbacks import ModelSummary +from lightning.pytorch.plugins.environments import TorchElasticEnvironment +from lightning.pytorch.trainer.connectors.checkpoint_connector import _CheckpointConnector from omegaconf.omegaconf import OmegaConf, open_dict -from pytorch_lightning import Trainer -from pytorch_lightning.callbacks import ModelSummary -from pytorch_lightning.plugins.environments import TorchElasticEnvironment -from pytorch_lightning.trainer.connectors.checkpoint_connector import _CheckpointConnector from nemo.collections.nlp.models.language_modeling.megatron_bart_model import MegatronBARTModel from nemo.collections.nlp.models.language_modeling.megatron_t5_model import MegatronT5Model @@ -53,7 +53,7 @@ def main(cfg) -> None: scaler = None if cfg.trainer.precision in [16, '16', '16-mixed']: scaler = GradScaler( - init_scale=cfg.model.get('native_amp_init_scale', 2 ** 32), + init_scale=cfg.model.get('native_amp_init_scale', 2**32), growth_interval=cfg.model.get('native_amp_growth_interval', 1000), hysteresis=cfg.model.get('hysteresis', 2), ) diff --git a/examples/nlp/machine_translation/nmt_transformer_infer_megatron.py b/examples/nlp/machine_translation/nmt_transformer_infer_megatron.py index fcf1fb8d1796..349155101a5d 100644 --- a/examples/nlp/machine_translation/nmt_transformer_infer_megatron.py +++ b/examples/nlp/machine_translation/nmt_transformer_infer_megatron.py @@ -24,8 +24,8 @@ import os +from lightning.pytorch.trainer.trainer import Trainer from omegaconf.omegaconf import OmegaConf, open_dict -from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.nlp.models.machine_translation.megatron_nmt_model import MegatronNMTModel from nemo.collections.nlp.modules.common.megatron.megatron_init import fake_initialize_model_parallel diff --git a/examples/nlp/question_answering/question_answering.py b/examples/nlp/question_answering/question_answering.py index fcde03582e5c..37bd43a4b0fb 100644 --- a/examples/nlp/question_answering/question_answering.py +++ b/examples/nlp/question_answering/question_answering.py @@ -14,7 +14,7 @@ import os -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf import DictConfig, OmegaConf from nemo.collections.nlp.models.question_answering.qa_bert_model import BERTQAModel diff --git a/examples/nlp/spellchecking_asr_customization/helpers.py b/examples/nlp/spellchecking_asr_customization/helpers.py index 2db11b0e7d96..8e3957d34cc1 100644 --- a/examples/nlp/spellchecking_asr_customization/helpers.py +++ b/examples/nlp/spellchecking_asr_customization/helpers.py @@ -16,7 +16,7 @@ import os from typing import Tuple -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf import DictConfig from nemo.collections.nlp.models import SpellcheckingAsrCustomizationModel @@ -32,7 +32,7 @@ def instantiate_model_and_trainer( cfg: DictConfig, model_name: str, do_training: bool ) -> Tuple[pl.Trainer, SpellcheckingAsrCustomizationModel]: - """ Function for instantiating a model and a trainer + """Function for instantiating a model and a trainer Args: cfg: The config used to instantiate the model and the trainer. model_name: A str indicates the model direction, currently only 'itn'. diff --git a/examples/nlp/text2sparql/evaluate_text2sparql.py b/examples/nlp/text2sparql/evaluate_text2sparql.py index 52baa2a7e78c..774ced98e8ec 100644 --- a/examples/nlp/text2sparql/evaluate_text2sparql.py +++ b/examples/nlp/text2sparql/evaluate_text2sparql.py @@ -39,7 +39,7 @@ import os -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf import DictConfig, OmegaConf from nemo.collections.nlp.models.text2sparql import Text2SparqlModel diff --git a/examples/nlp/text2sparql/text2sparql.py b/examples/nlp/text2sparql/text2sparql.py index 1353a3967735..d70a7e616950 100644 --- a/examples/nlp/text2sparql/text2sparql.py +++ b/examples/nlp/text2sparql/text2sparql.py @@ -88,7 +88,7 @@ exp_manager.exp_dir=./NeMo_logs """ -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf import DictConfig, OmegaConf from nemo.collections.nlp.models.text2sparql import Text2SparqlModel diff --git a/examples/nlp/text_classification/model_parallel_text_classification_evaluation.py b/examples/nlp/text_classification/model_parallel_text_classification_evaluation.py index ab3322f552c1..cf9b6d8dd2e4 100644 --- a/examples/nlp/text_classification/model_parallel_text_classification_evaluation.py +++ b/examples/nlp/text_classification/model_parallel_text_classification_evaluation.py @@ -15,7 +15,7 @@ """ This script runs model parallel text classification evaluation. """ -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf import DictConfig, OmegaConf from nemo.collections.nlp.models.text_classification import TextClassificationModel diff --git a/examples/nlp/text_classification/text_classification_with_bert.py b/examples/nlp/text_classification/text_classification_with_bert.py index 01e8fae9bba5..a6c84b4e337a 100644 --- a/examples/nlp/text_classification/text_classification_with_bert.py +++ b/examples/nlp/text_classification/text_classification_with_bert.py @@ -95,7 +95,7 @@ eval_model.set_trainer(eval_trainer) eval_trainer.test(model=eval_model, verbose=False) """ -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf import DictConfig, OmegaConf from nemo.collections.nlp.models.text_classification import TextClassificationModel diff --git a/examples/nlp/text_normalization_as_tagging/helpers.py b/examples/nlp/text_normalization_as_tagging/helpers.py index 347b05b25fba..de74794f8f40 100644 --- a/examples/nlp/text_normalization_as_tagging/helpers.py +++ b/examples/nlp/text_normalization_as_tagging/helpers.py @@ -16,7 +16,7 @@ import os from typing import Tuple -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf import DictConfig from nemo.collections.nlp.models import ThutmoseTaggerModel @@ -31,7 +31,7 @@ def instantiate_model_and_trainer( cfg: DictConfig, model_name: str, do_training: bool ) -> Tuple[pl.Trainer, ThutmoseTaggerModel]: - """ Function for instantiating a model and a trainer + """Function for instantiating a model and a trainer Args: cfg: The config used to instantiate the model and the trainer. model_name: A str indicates the model direction, currently only 'itn'. diff --git a/examples/nlp/token_classification/punctuation_capitalization_lexical_audio_train_evaluate.py b/examples/nlp/token_classification/punctuation_capitalization_lexical_audio_train_evaluate.py index 149a9a4515e2..508e434bb598 100644 --- a/examples/nlp/token_classification/punctuation_capitalization_lexical_audio_train_evaluate.py +++ b/examples/nlp/token_classification/punctuation_capitalization_lexical_audio_train_evaluate.py @@ -14,7 +14,7 @@ import os -import pytorch_lightning as pl +import lightning.pytorch as pl import torch from omegaconf import DictConfig, OmegaConf diff --git a/examples/nlp/token_classification/punctuation_capitalization_train_evaluate.py b/examples/nlp/token_classification/punctuation_capitalization_train_evaluate.py index e983540a68b2..b16e1ecd0bdc 100644 --- a/examples/nlp/token_classification/punctuation_capitalization_train_evaluate.py +++ b/examples/nlp/token_classification/punctuation_capitalization_train_evaluate.py @@ -14,7 +14,7 @@ import os -import pytorch_lightning as pl +import lightning.pytorch as pl import torch from omegaconf import DictConfig, OmegaConf diff --git a/examples/nlp/token_classification/token_classification_evaluate.py b/examples/nlp/token_classification/token_classification_evaluate.py index b69212f59de4..764aa90c8593 100644 --- a/examples/nlp/token_classification/token_classification_evaluate.py +++ b/examples/nlp/token_classification/token_classification_evaluate.py @@ -14,7 +14,7 @@ import os -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf import DictConfig from nemo.collections.nlp.models import TokenClassificationModel diff --git a/examples/nlp/token_classification/token_classification_train.py b/examples/nlp/token_classification/token_classification_train.py index 56c1487cf9c5..536327aff6da 100644 --- a/examples/nlp/token_classification/token_classification_train.py +++ b/examples/nlp/token_classification/token_classification_train.py @@ -14,7 +14,7 @@ import os -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf import DictConfig, OmegaConf from nemo.collections.nlp.models import TokenClassificationModel diff --git a/examples/nlp/zero_shot_intent_recognition/zero_shot_intent_train.py b/examples/nlp/zero_shot_intent_recognition/zero_shot_intent_train.py index 5b91049e965d..4dbbf01c935e 100644 --- a/examples/nlp/zero_shot_intent_recognition/zero_shot_intent_train.py +++ b/examples/nlp/zero_shot_intent_recognition/zero_shot_intent_train.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf import DictConfig, OmegaConf from nemo.collections.nlp.models import ZeroShotIntentModel diff --git a/examples/slu/speech_intent_slot/eval_utils/inference.py b/examples/slu/speech_intent_slot/eval_utils/inference.py index 9bd76c76822d..241f6463ed76 100644 --- a/examples/slu/speech_intent_slot/eval_utils/inference.py +++ b/examples/slu/speech_intent_slot/eval_utils/inference.py @@ -21,7 +21,7 @@ from pathlib import Path from typing import List, Optional -import pytorch_lightning as pl +import lightning.pytorch as pl import torch from omegaconf import OmegaConf from tqdm.auto import tqdm diff --git a/examples/slu/speech_intent_slot/speech_intent_slot_train.py b/examples/slu/speech_intent_slot/speech_intent_slot_train.py index a9999d4d4682..f8732ec757e1 100644 --- a/examples/slu/speech_intent_slot/speech_intent_slot_train.py +++ b/examples/slu/speech_intent_slot/speech_intent_slot_train.py @@ -66,7 +66,7 @@ from pathlib import Path -import pytorch_lightning as pl +import lightning.pytorch as pl import torch from omegaconf import OmegaConf diff --git a/examples/speaker_tasks/diarization/clustering_diarizer/offline_diar_infer.py b/examples/speaker_tasks/diarization/clustering_diarizer/offline_diar_infer.py index 35077a5fe415..5c0f956c2e3c 100644 --- a/examples/speaker_tasks/diarization/clustering_diarizer/offline_diar_infer.py +++ b/examples/speaker_tasks/diarization/clustering_diarizer/offline_diar_infer.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from lightning.pytorch import seed_everything from omegaconf import OmegaConf -from pytorch_lightning import seed_everything from nemo.collections.asr.models import ClusteringDiarizer from nemo.core.config import hydra_runner diff --git a/examples/speaker_tasks/diarization/conf/neural_diarizer/sortformer_diarizer_hybrid_loss_4spk-v1.yaml b/examples/speaker_tasks/diarization/conf/neural_diarizer/sortformer_diarizer_hybrid_loss_4spk-v1.yaml new file mode 100644 index 000000000000..66cfc5fd1b61 --- /dev/null +++ b/examples/speaker_tasks/diarization/conf/neural_diarizer/sortformer_diarizer_hybrid_loss_4spk-v1.yaml @@ -0,0 +1,213 @@ +# Sortformer Diarizer is an end-to-end speaker diarization model that is solely based on Transformer-encoder type of architecture. +# Model name convention for Sortformer Diarizer: sortformer_diarizer__-.yaml +# (Example) `sortformer_diarizer_hybrid_loss_4spk-v1.yaml`. +# Sortformer Diarizer model checkpoint (.ckpt) and NeMo file (.nemo) contain Fast Conformer Encoder model (NEST Encoder) and the pre-trained NEST model is loaded along with the Transformer Encoder layers. +# Example: a manifest line for training +# {"audio_filepath": "/path/to/audio01.wav", "offset": 390.83, "duration": 90.00, "text": "-", "num_speakers": 2, "rttm_filepath": "/path/to/audio01.rttm"} +name: "SortFormerDiarizer" +num_workers: 18 +batch_size: 8 + +model: + sample_rate: 16000 + pil_weight: 0.5 # Weight for Permutation Invariant Loss (PIL) used in training the Sortformer diarizer model + ats_weight: 0.5 # Weight for Arrival Time Sort (ATS) loss in training the Sortformer diarizer model + max_num_of_spks: 4 # Maximum number of speakers per model; currently set to 4 + + model_defaults: + fc_d_model: 512 # Hidden dimension size of the Fast-conformer Encoder + tf_d_model: 192 # Hidden dimension size of the Transformer Encoder + + train_ds: + manifest_filepath: ??? + sample_rate: ${model.sample_rate} + num_spks: ${model.max_num_of_spks} + session_len_sec: 90 # Maximum session length in seconds + soft_label_thres: 0.5 # Threshold for binarizing target values; higher values make the model more conservative in predicting speaker activity. + soft_targets: False # If True, use continuous values as target values when calculating cross-entropy loss + labels: null + batch_size: ${batch_size} + shuffle: True + num_workers: ${num_workers} + validation_mode: False + # lhotse config + use_lhotse: False + use_bucketing: True + num_buckets: 10 + bucket_duration_bins: [10, 20, 30, 40, 50, 60, 70, 80, 90] + pin_memory: True + min_duration: 10 + max_duration: 90 + batch_duration: 400 + quadratic_duration: 1200 + bucket_buffer_size: 20000 + shuffle_buffer_size: 10000 + window_stride: ${model.preprocessor.window_stride} + subsampling_factor: ${model.encoder.subsampling_factor} + + validation_ds: + manifest_filepath: ??? + is_tarred: False + tarred_audio_filepaths: null + sample_rate: ${model.sample_rate} + num_spks: ${model.max_num_of_spks} + session_len_sec: 90 # Maximum session length in seconds + soft_label_thres: 0.5 # A threshold value for setting up the binarized labels. The higher the more conservative the model becomes. + soft_targets: False + labels: null + batch_size: ${batch_size} + shuffle: False + num_workers: ${num_workers} + validation_mode: True + # lhotse config + use_lhotse: False + use_bucketing: False + drop_last: False + pin_memory: True + window_stride: ${model.preprocessor.window_stride} + subsampling_factor: ${model.encoder.subsampling_factor} + + test_ds: + manifest_filepath: null + is_tarred: False + tarred_audio_filepaths: null + sample_rate: 16000 + num_spks: ${model.max_num_of_spks} + session_len_sec: 90 # Maximum session length in seconds + soft_label_thres: 0.5 + soft_targets: False + labels: null + batch_size: ${batch_size} + shuffle: False + seq_eval_mode: True + num_workers: ${num_workers} + validation_mode: True + # lhotse config + use_lhotse: False + use_bucketing: False + drop_last: False + pin_memory: True + window_stride: ${model.preprocessor.window_stride} + subsampling_factor: ${model.encoder.subsampling_factor} + + preprocessor: + _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor + normalize: "per_feature" + window_size: 0.025 + sample_rate: ${model.sample_rate} + window_stride: 0.01 + window: "hann" + features: 80 + n_fft: 512 + frame_splicing: 1 + dither: 0.00001 + + sortformer_modules: + _target_: nemo.collections.asr.modules.sortformer_modules.SortformerModules + num_spks: ${model.max_num_of_spks} # Number of speakers per model. This is currently fixed at 4. + dropout_rate: 0.5 # Dropout rate + fc_d_model: ${model.model_defaults.fc_d_model} + tf_d_model: ${model.model_defaults.tf_d_model} # Hidden layer size for linear layers in Sortformer Diarizer module + + encoder: + _target_: nemo.collections.asr.modules.ConformerEncoder + feat_in: ${model.preprocessor.features} + feat_out: -1 + n_layers: 18 + d_model: ${model.model_defaults.fc_d_model} + + # Sub-sampling parameters + subsampling: dw_striding # vggnet, striding, stacking or stacking_norm, dw_striding + subsampling_factor: 8 # must be power of 2 for striding and vggnet + subsampling_conv_channels: 256 # set to -1 to make it equal to the d_model + causal_downsampling: false + # Feed forward module's params + ff_expansion_factor: 4 + # Multi-headed Attention Module's params + self_attention_model: rel_pos # rel_pos or abs_pos + n_heads: 8 # may need to be lower for smaller d_models + # [left, right] specifies the number of steps to be seen from left and right of each step in self-attention + att_context_size: [-1, -1] # -1 means unlimited context + att_context_style: regular # regular or chunked_limited + xscaling: true # scales up the input embeddings by sqrt(d_model) + untie_biases: true # unties the biases of the TransformerXL layers + pos_emb_max_len: 5000 + # Convolution module's params + conv_kernel_size: 9 + conv_norm_type: 'batch_norm' # batch_norm or layer_norm or groupnormN (N specifies the number of groups) + conv_context_size: null + # Regularization + dropout: 0.1 # The dropout used in most of the Conformer Modules + dropout_pre_encoder: 0.1 # The dropout used before the encoder + dropout_emb: 0.0 # The dropout used for embeddings + dropout_att: 0.1 # The dropout for multi-headed attention modules + # Set to non-zero to enable stochastic depth + stochastic_depth_drop_prob: 0.0 + stochastic_depth_mode: linear # linear or uniform + stochastic_depth_start_layer: 1 + + transformer_encoder: + _target_: nemo.collections.asr.modules.transformer.transformer_encoders.TransformerEncoder + num_layers: 18 + hidden_size: ${model.model_defaults.tf_d_model} # Needs to be multiple of num_attention_heads + inner_size: 768 + num_attention_heads: 8 + attn_score_dropout: 0.5 + attn_layer_dropout: 0.5 + ffn_dropout: 0.5 + hidden_act: relu + pre_ln: False + pre_ln_final_layer_norm: True + + loss: + _target_: nemo.collections.asr.losses.bce_loss.BCELoss + weight: null # Weight for binary cross-entropy loss. Either `null` or list type input. (e.g. [0.5,0.5]) + reduction: mean + + lr: 0.0001 + optim: + name: adamw + lr: ${model.lr} + # optimizer arguments + betas: [0.9, 0.98] + weight_decay: 1e-3 + + sched: + name: InverseSquareRootAnnealing + warmup_steps: 2500 + warmup_ratio: null + min_lr: 1e-06 + +trainer: + devices: 1 # number of gpus (devices) + accelerator: gpu + max_epochs: 800 + max_steps: -1 # computed at runtime if not set + num_nodes: 1 + strategy: ddp_find_unused_parameters_true # Could be "ddp" + accumulate_grad_batches: 1 + deterministic: True + enable_checkpointing: False + logger: False + log_every_n_steps: 1 # Interval of logging. + val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations + +exp_manager: + use_datetime_version: False + exp_dir: null + name: ${name} + resume_if_exists: True + resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. + resume_ignore_no_checkpoint: True + create_tensorboard_logger: True + create_checkpoint_callback: True + create_wandb_logger: False + checkpoint_callback_params: + monitor: "val_f1_acc" + mode: "max" + save_top_k: 9 + every_n_epochs: 1 + wandb_logger_kwargs: + resume: True + name: null + project: null \ No newline at end of file diff --git a/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_4spk-v1_callhome-part1.yaml b/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_4spk-v1_callhome-part1.yaml new file mode 100644 index 000000000000..9b7a9701c4f2 --- /dev/null +++ b/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_4spk-v1_callhome-part1.yaml @@ -0,0 +1,13 @@ +# Postprocessing parameters for timestamp outputs from speaker diarization models. +# This speaker diarization postprocessing scheme is inspired by the postprocessing procedure in the following paper: +# Medennikov, Ivan, et al. "Target-Speaker Voice Activity Detection: a Novel Approach for Multi-Speaker Diarization in a Dinner Party Scenario." (2020). +# These parameters were optimized with hybrid-loss trained Sortformer model introduced in https://arxiv.org/pdf/2409.06656. +# These parameters were optimized on CallHome Dataset from the NIST SRE 2000 Disc8, especially from the part1 (callhome1) specified in: Kaldi, “Kaldi x-vector recipe v2,” https://github.com/kaldi-asr/kaldi/blob/master/egs/callhome_diarization/v2/run.sh +# Trial 24682 finished with value: 0.10257785779242055 and parameters: {'onset': 0.53, 'offset': 0.49, 'pad_onset': 0.23, 'pad_offset': 0.01, 'min_duration_on': 0.42, 'min_duration_off': 0.34}. Best is trial 24682 with value: 0.10257785779242055. +parameters: + onset: 0.53 # Onset threshold for detecting the beginning and end of a speech + offset: 0.49 # Offset threshold for detecting the end of a speech + pad_onset: 0.23 # Adding durations before each speech segment + pad_offset: 0.01 # Adding durations after each speech segment + min_duration_on: 0.42 # Threshold for small non-speech deletion + min_duration_off: 0.34 # Threshold for short speech segment deletion \ No newline at end of file diff --git a/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_4spk-v1_dihard3-dev.yaml b/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_4spk-v1_dihard3-dev.yaml new file mode 100644 index 000000000000..ebf994c10f2e --- /dev/null +++ b/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_4spk-v1_dihard3-dev.yaml @@ -0,0 +1,13 @@ +# Postprocessing parameters for timestamp outputs from speaker diarization models. +# This speaker diarization postprocessing scheme is inspired by the postprocessing procedure in the following paper: +# Medennikov, Ivan, et al. "Target-Speaker Voice Activity Detection: a Novel Approach for Multi-Speaker Diarization in a Dinner Party Scenario." (2020). +# These parameters were optimized with hybrid-loss trained Sortformer model introduced in https://arxiv.org/pdf/2409.06656. +# These parameters were optimized on the development split of DIHARD3 dataset (See https://arxiv.org/pdf/2012.01477). +# Trial 732 finished with value: 0.12171946949255649 and parameters: {'onset': 0.64, 'offset': 0.74, 'pad_onset': 0.06, 'pad_offset': 0.0, 'min_duration_on': 0.1, 'min_duration_off': 0.15}. Best is trial 732 with value: 0.12171946949255649. +parameters: + onset: 0.64 # Onset threshold for detecting the beginning and end of a speech + offset: 0.74 # Offset threshold for detecting the end of a speech + pad_onset: 0.06 # Adding durations before each speech segment + pad_offset: 0.0 # Adding durations after each speech segment + min_duration_on: 0.1 # Threshold for small non-speech deletion + min_duration_off: 0.15 # Threshold for short speech segment deletion \ No newline at end of file diff --git a/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py b/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py new file mode 100644 index 000000000000..1767a16cbe02 --- /dev/null +++ b/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py @@ -0,0 +1,443 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This script provides an inference and evaluation script for end-to-end speaker diarization models. +The performance of the diarization model is measured using the Diarization Error Rate (DER). +If you want to evaluate its performance, the manifest JSON file should contain the corresponding RTTM +(Rich Transcription Time Marked) file. +Please refer to the NeMo Library Documentation for more details on data preparation for diarization inference: +https://docs.nvidia.com/nemo-framework/user-guide/latest/nemotoolkit +/asr/speaker_diarization/datasets.html#data-preparation-for-inference + +Usage for diarization inference: + +The end-to-end speaker diarization model can be specified by either "model_path" or "pretrained_name". +Data for diarization is fed through the "dataset_manifest". +By default, post-processing is bypassed, and only binarization is performed. +If you want to reproduce DER scores reported on NeMo model cards, you need to apply post-processing steps. +Use batch_size = 1 to have the longest inference window and the highest possible accuracy. + +python $BASEPATH/neural_diarizer/e2e_diarize_speech.py \ + model_path=/path/to/diar_sortformer_4spk_v1.nemo \ + batch_size=1 \ + dataset_manifest=/path/to/diarization_manifest.json + +""" +import logging +import os +import tempfile +from dataclasses import dataclass, is_dataclass +from typing import Dict, List, Optional, Union + +import lightning.pytorch as pl +import optuna +import torch +import yaml +from omegaconf import OmegaConf +from pytorch_lightning import seed_everything +from tqdm import tqdm + +from nemo.collections.asr.metrics.der import score_labels +from nemo.collections.asr.models import SortformerEncLabelModel +from nemo.collections.asr.parts.utils.speaker_utils import audio_rttm_map, timestamps_to_pyannote_object +from nemo.collections.asr.parts.utils.vad_utils import ts_vad_post_processing +from nemo.core.config import hydra_runner + +seed_everything(42) +torch.backends.cudnn.deterministic = True + + +@dataclass +class PostProcessingParams: + """ + Postprocessing parameters for end-to-end speaker diarization models. + These parameters can significantly affect DER performance depending on the evaluation style and the dataset. + It is recommended to tune these parameters based on the evaluation style and the dataset + to achieve the desired DER performance. + """ + + onset: float = 0.5 # Onset threshold for detecting the beginning and end of a speech + offset: float = 0.5 # Offset threshold for detecting the end of a speech + pad_onset: float = 0.0 # Adding durations before each speech segment + pad_offset: float = 0.0 # Adding durations after each speech segment + min_duration_on: float = 0.0 # Threshold for small non-speech deletion + min_duration_off: float = 0.0 # Threshold for short speech segment deletion + + +@dataclass +class DiarizationConfig: + """Diarization configuration parameters for inference.""" + + model_path: Optional[str] = None # Path to a .nemo file + pretrained_name: Optional[str] = None # Name of a pretrained model + audio_dir: Optional[str] = None # Path to a directory which contains audio files + dataset_manifest: Optional[str] = None # Path to dataset's JSON manifest + + postprocessing_yaml: Optional[str] = None # Path to a yaml file for postprocessing configurations + no_der: bool = False + out_rttm_dir: Optional[str] = None + + # General configs + session_len_sec: float = -1 # End-to-end diarization session length in seconds + batch_size: int = 1 + num_workers: int = 0 + random_seed: Optional[int] = None # seed number going to be used in seed_everything() + bypass_postprocessing: bool = True # If True, postprocessing will be bypassed + + # Eval Settings: (0.25, False) should be default setting for sortformer eval. + collar: float = 0.25 # Collar in seconds for DER calculation + ignore_overlap: bool = False # If True, DER will be calculated only for non-overlapping segments + + # If `cuda` is a negative number, inference will be on CPU only. + cuda: Optional[int] = None + matmul_precision: str = "highest" # Literal["highest", "high", "medium"] + + # Optuna Config + launch_pp_optim: bool = False # If True, launch optimization process for postprocessing parameters + optuna_study_name: str = "optim_postprocessing" + optuna_temp_dir: str = "/tmp/optuna" + optuna_storage: str = f"sqlite:///{optuna_study_name}.db" + optuna_log_file: str = f"{optuna_study_name}.log" + optuna_n_trials: int = 100000 + + +def load_postprocessing_from_yaml(postprocessing_yaml: PostProcessingParams = None) -> PostProcessingParams: + """ + Load postprocessing parameters from a YAML file. + + Args: + postprocessing_yaml (str): + Path to a YAML file for postprocessing configurations. + + Returns: + postprocessing_params (dataclass): + Postprocessing parameters loaded from the YAML file. + """ + # Add PostProcessingParams as a field + postprocessing_params = OmegaConf.structured(PostProcessingParams()) + if postprocessing_yaml is None: + logging.info( + f"No postprocessing YAML file has been provided. Default postprocessing configurations will be applied." + ) + else: + # Load postprocessing params from the provided YAML file + with open(postprocessing_yaml, 'r') as file: + yaml_params = yaml.safe_load(file)['parameters'] + # Update the postprocessing_params with the loaded values + logging.info(f"Postprocessing YAML file '{postprocessing_yaml}' has been loaded.") + for key, value in yaml_params.items(): + if hasattr(postprocessing_params, key): + setattr(postprocessing_params, key, value) + return postprocessing_params + + +def optuna_suggest_params(postprocessing_cfg: PostProcessingParams, trial: optuna.Trial) -> PostProcessingParams: + """ + Suggests hyperparameters for postprocessing using Optuna. + See the following link for `trial` instance in Optuna framework. + https://optuna.readthedocs.io/en/stable/reference/generated/optuna.trial.Trial.html#optuna.trial.Trial + + Args: + postprocessing_cfg (PostProcessingParams): The current postprocessing configuration. + trial (optuna.Trial): The Optuna trial object used to suggest hyperparameters. + + Returns: + PostProcessingParams: The updated postprocessing configuration with suggested hyperparameters. + """ + postprocessing_cfg.onset = trial.suggest_float("onset", 0.4, 0.8, step=0.01) + postprocessing_cfg.offset = trial.suggest_float("offset", 0.4, 0.9, step=0.01) + postprocessing_cfg.pad_onset = trial.suggest_float("pad_onset", 0.1, 0.5, step=0.01) + postprocessing_cfg.pad_offset = trial.suggest_float("pad_offset", 0.0, 0.2, step=0.01) + postprocessing_cfg.min_duration_on = trial.suggest_float("min_duration_on", 0.0, 0.75, step=0.01) + postprocessing_cfg.min_duration_off = trial.suggest_float("min_duration_off", 0.0, 0.75, step=0.01) + return postprocessing_cfg + + +def get_tensor_path(cfg: DiarizationConfig) -> str: + """ + Constructs the file path for saving or loading prediction tensors based on the configuration. + + Args: + cfg (DiarizationConfig): The configuration object containing model and dataset details. + + Returns: + str: The constructed file path for the prediction tensor. + """ + tensor_filename = os.path.basename(cfg.dataset_manifest).replace("manifest.", "").replace(".json", "") + model_base_path = os.path.dirname(cfg.model_path) + model_id = os.path.basename(cfg.model_path).replace(".ckpt", "").replace(".nemo", "") + bpath = f"{model_base_path}/pred_tensors" + if not os.path.exists(bpath): + os.makedirs(bpath) + tensor_path = f"{bpath}/__{model_id}__{tensor_filename}.pt" + return tensor_path + + +def diarization_objective( + trial: optuna.Trial, + postprocessing_cfg: PostProcessingParams, + temp_out_dir: str, + infer_audio_rttm_dict: Dict[str, Dict[str, str]], + diar_model_preds_total_list: List[torch.Tensor], + collar: float = 0.25, + ignore_overlap: bool = False, +) -> float: + """ + Objective function for Optuna hyperparameter optimization in speaker diarization. + + This function evaluates the diarization performance using a set of postprocessing parameters + suggested by Optuna. It converts prediction matrices to time-stamp segments, scores the + diarization results, and returns the Diarization Error Rate (DER) as the optimization metric. + + Args: + trial (optuna.Trial): The Optuna trial object used to suggest hyperparameters. + postprocessing_cfg (PostProcessingParams): The current postprocessing configuration. + temp_out_dir (str): Temporary directory for storing intermediate outputs. + infer_audio_rttm_dict (Dict[str, Dict[str, str]]): Dictionary containing audio file paths, + offsets, durations, and RTTM file paths. + diar_model_preds_total_list (List[torch.Tensor]): List of prediction matrices containing + sigmoid values for each speaker. + Dimension: [(1, num_frames, num_speakers), ..., (1, num_frames, num_speakers)] + collar (float, optional): Collar in seconds for DER calculation. Defaults to 0.25. + ignore_overlap (bool, optional): If True, DER will be calculated only for non-overlapping segments. + Defaults to False. + + Returns: + float: The Diarization Error Rate (DER) for the given set of postprocessing parameters. + """ + with tempfile.TemporaryDirectory(dir=temp_out_dir, prefix="Diar_PostProcessing_") as local_temp_out_dir: + if trial is not None: + postprocessing_cfg = optuna_suggest_params(postprocessing_cfg, trial) + all_hyps, all_refs, all_uems = convert_pred_mat_to_segments( + audio_rttm_map_dict=infer_audio_rttm_dict, + postprocessing_cfg=postprocessing_cfg, + batch_preds_list=diar_model_preds_total_list, + unit_10ms_frame_count=8, + bypass_postprocessing=False, + ) + metric, mapping_dict, itemized_errors = score_labels( + AUDIO_RTTM_MAP=infer_audio_rttm_dict, + all_reference=all_refs, + all_hypothesis=all_hyps, + all_uem=all_uems, + collar=collar, + ignore_overlap=ignore_overlap, + ) + der = abs(metric) + return der + + +def run_optuna_hyperparam_search( + cfg: DiarizationConfig, # type: DiarizationConfig + postprocessing_cfg: PostProcessingParams, + infer_audio_rttm_dict: Dict[str, Dict[str, str]], + preds_list: List[torch.Tensor], + temp_out_dir: str, +): + """ + Run Optuna hyperparameter optimization for speaker diarization. + + Args: + cfg (DiarizationConfig): The configuration object containing model and dataset details. + postprocessing_cfg (PostProcessingParams): The current postprocessing configuration. + infer_audio_rttm_dict (dict): dictionary of audio file path, offset, duration and RTTM filepath. + preds_list (List[torch.Tensor]): list of prediction matrices containing sigmoid values for each speaker. + Dimension: [(1, num_frames, num_speakers), ..., (1, num_frames, num_speakers)] + temp_out_dir (str): temporary directory for storing intermediate outputs. + """ + worker_function = lambda trial: diarization_objective( + trial=trial, + postprocessing_cfg=postprocessing_cfg, + temp_out_dir=temp_out_dir, + infer_audio_rttm_dict=infer_audio_rttm_dict, + diar_model_preds_total_list=preds_list, + collar=cfg.collar, + ) + study = optuna.create_study( + direction="minimize", study_name=cfg.optuna_study_name, storage=cfg.optuna_storage, load_if_exists=True + ) + logger = logging.getLogger() + logger.setLevel(logging.INFO) # Setup the root logger. + if cfg.optuna_log_file is not None: + logger.addHandler(logging.FileHandler(cfg.optuna_log_file, mode="a")) + logger.addHandler(logging.StreamHandler()) + optuna.logging.enable_propagation() # Propagate logs to the root logger. + study.optimize(worker_function, n_trials=cfg.optuna_n_trials) + + +def convert_pred_mat_to_segments( + audio_rttm_map_dict: Dict[str, Dict[str, str]], + postprocessing_cfg, + batch_preds_list: List[torch.Tensor], + unit_10ms_frame_count: int = 8, + bypass_postprocessing: bool = False, + out_rttm_dir: str | None = None, +): + """ + Convert prediction matrix to time-stamp segments. + + Args: + audio_rttm_map_dict (dict): dictionary of audio file path, offset, duration and RTTM filepath. + batch_preds_list (List[torch.Tensor]): list of prediction matrices containing sigmoid values for each speaker. + Dimension: [(1, num_frames, num_speakers), ..., (1, num_frames, num_speakers)] + unit_10ms_frame_count (int, optional): number of 10ms segments in a frame. Defaults to 8. + bypass_postprocessing (bool, optional): if True, postprocessing will be bypassed. Defaults to False. + + Returns: + all_hypothesis (list): list of pyannote objects for each audio file. + all_reference (list): list of pyannote objects for each audio file. + all_uems (list): list of pyannote objects for each audio file. + """ + batch_pred_ts_segs, all_hypothesis, all_reference, all_uems = [], [], [], [] + cfg_vad_params = OmegaConf.structured(postprocessing_cfg) + pp_message = "Bypass PP, Running Binarization" if bypass_postprocessing else "Running post-processing" + for sample_idx, (uniq_id, audio_rttm_values) in tqdm( + enumerate(audio_rttm_map_dict.items()), total=len(audio_rttm_map_dict), desc=pp_message + ): + spk_ts = [] + offset, duration = audio_rttm_values['offset'], audio_rttm_values['duration'] + speaker_assign_mat = batch_preds_list[sample_idx].squeeze(dim=0) + speaker_timestamps = [[] for _ in range(speaker_assign_mat.shape[-1])] + for spk_id in range(speaker_assign_mat.shape[-1]): + ts_mat = ts_vad_post_processing( + speaker_assign_mat[:, spk_id], + cfg_vad_params=cfg_vad_params, + unit_10ms_frame_count=unit_10ms_frame_count, + bypass_postprocessing=bypass_postprocessing, + ) + ts_mat = ts_mat + offset + ts_mat = torch.clamp(ts_mat, min=offset, max=(offset + duration)) + ts_seg_list = ts_mat.tolist() + speaker_timestamps[spk_id].extend(ts_seg_list) + spk_ts.append(ts_seg_list) + all_hypothesis, all_reference, all_uems = timestamps_to_pyannote_object( + speaker_timestamps, + uniq_id, + audio_rttm_values, + all_hypothesis, + all_reference, + all_uems, + out_rttm_dir, + ) + batch_pred_ts_segs.append(spk_ts) + return all_hypothesis, all_reference, all_uems + + +@hydra_runner(config_name="DiarizationConfig", schema=DiarizationConfig) +def main(cfg: DiarizationConfig) -> Union[DiarizationConfig]: + """Main function for end-to-end speaker diarization inference.""" + for key in cfg: + cfg[key] = None if cfg[key] == 'None' else cfg[key] + + if is_dataclass(cfg): + cfg = OmegaConf.structured(cfg) + + if cfg.random_seed: + pl.seed_everything(cfg.random_seed) + + if cfg.model_path is None and cfg.pretrained_name is None: + raise ValueError("Both cfg.model_path and cfg.pretrained_name cannot be None!") + if cfg.audio_dir is None and cfg.dataset_manifest is None: + raise ValueError("Both cfg.audio_dir and cfg.dataset_manifest cannot be None!") + + # setup GPU + torch.set_float32_matmul_precision(cfg.matmul_precision) + if cfg.cuda is None: + if torch.cuda.is_available(): + device = [0] # use 0th CUDA device + accelerator = 'gpu' + map_location = torch.device('cuda:0') + else: + device = 1 + accelerator = 'cpu' + map_location = torch.device('cpu') + else: + device = [cfg.cuda] + accelerator = 'gpu' + map_location = torch.device(f'cuda:{cfg.cuda}') + + if cfg.model_path.endswith(".ckpt"): + diar_model = SortformerEncLabelModel.load_from_checkpoint( + checkpoint_path=cfg.model_path, map_location=map_location, strict=False + ) + elif cfg.model_path.endswith(".nemo"): + diar_model = SortformerEncLabelModel.restore_from(restore_path=cfg.model_path, map_location=map_location) + else: + raise ValueError("cfg.model_path must end with.ckpt or.nemo!") + + diar_model._cfg.test_ds.session_len_sec = cfg.session_len_sec + trainer = pl.Trainer(devices=device, accelerator=accelerator) + diar_model.set_trainer(trainer) + + diar_model = diar_model.eval() + diar_model._cfg.test_ds.manifest_filepath = cfg.dataset_manifest + infer_audio_rttm_dict = audio_rttm_map(cfg.dataset_manifest) + diar_model._cfg.test_ds.batch_size = cfg.batch_size + + # Model setup for inference + diar_model._cfg.test_ds.num_workers = cfg.num_workers + diar_model.setup_test_data(test_data_config=diar_model._cfg.test_ds) + + postprocessing_cfg = load_postprocessing_from_yaml(cfg.postprocessing_yaml) + tensor_path = get_tensor_path(cfg) + + if os.path.exists(tensor_path): + logging.info( + f"A saved prediction tensor has been found. Loading the saved prediction tensors from {tensor_path}..." + ) + diar_model_preds_total_list = torch.load(tensor_path) + else: + logging.info(f"No saved prediction tensors found. Running inference on the dataset...") + diar_model.test_batch() + diar_model_preds_total_list = diar_model.preds_total_list + torch.save(diar_model.preds_total_list, tensor_path) + + if cfg.launch_pp_optim: + # Launch a hyperparameter optimization process if launch_pp_optim is True + run_optuna_hyperparam_search( + cfg=cfg, + postprocessing_cfg=postprocessing_cfg, + infer_audio_rttm_dict=infer_audio_rttm_dict, + preds_list=diar_model_preds_total_list, + temp_out_dir=cfg.optuna_temp_dir, + ) + + # Evaluation + if not cfg.no_der: + if cfg.out_rttm_dir is not None and not os.path.exists(cfg.out_rttm_dir): + os.mkdir(cfg.out_rttm_dir) + all_hyps, all_refs, all_uems = convert_pred_mat_to_segments( + infer_audio_rttm_dict, + postprocessing_cfg=postprocessing_cfg, + batch_preds_list=diar_model_preds_total_list, + unit_10ms_frame_count=8, + bypass_postprocessing=cfg.bypass_postprocessing, + out_rttm_dir=cfg.out_rttm_dir, + ) + logging.info(f"Evaluating the model on the {len(diar_model_preds_total_list)} audio segments...") + score_labels( + AUDIO_RTTM_MAP=infer_audio_rttm_dict, + all_reference=all_refs, + all_hypothesis=all_hyps, + all_uem=all_uems, + collar=cfg.collar, + ignore_overlap=cfg.ignore_overlap, + ) + logging.info(f"PostProcessingParams: {postprocessing_cfg}") + + +if __name__ == '__main__': + main() diff --git a/examples/speaker_tasks/diarization/neural_diarizer/multiscale_diar_decoder.py b/examples/speaker_tasks/diarization/neural_diarizer/multiscale_diar_decoder.py index 984b5ce93464..bc1db4dc1126 100644 --- a/examples/speaker_tasks/diarization/neural_diarizer/multiscale_diar_decoder.py +++ b/examples/speaker_tasks/diarization/neural_diarizer/multiscale_diar_decoder.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl +from lightning.pytorch import seed_everything from omegaconf import OmegaConf -from pytorch_lightning import seed_everything from nemo.collections.asr.models import EncDecDiarLabelModel from nemo.core.config import hydra_runner diff --git a/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py b/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py new file mode 100644 index 000000000000..ab6e418b1072 --- /dev/null +++ b/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py @@ -0,0 +1,55 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import lightning.pytorch as pl +from omegaconf import OmegaConf +from pytorch_lightning import seed_everything + +from nemo.collections.asr.models import SortformerEncLabelModel +from nemo.core.config import hydra_runner +from nemo.utils import logging +from nemo.utils.exp_manager import exp_manager + +""" +Example training session (single node training) + +python ./sortformer_diar_train.py --config-path='../conf/neural_diarizer' \ + --config-name='sortformer_diarizer_hybrid_loss_4spk-v1.yaml' \ + trainer.devices=1 \ + model.train_ds.manifest_filepath="" \ + model.validation_ds.manifest_filepath="" \ + exp_manager.name='sample_train' \ + exp_manager.exp_dir='./sortformer_diar_train' +""" + +seed_everything(42) + + +@hydra_runner(config_path="../conf/neural_diarizer", config_name="sortformer_diarizer_hybrid_loss_4spk-v1.yaml") +def main(cfg): + """Main function for training the sortformer diarizer model.""" + logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}') + trainer = pl.Trainer(**cfg.trainer) + exp_manager(trainer, cfg.get("exp_manager", None)) + sortformer_model = SortformerEncLabelModel(cfg=cfg.model, trainer=trainer) + sortformer_model.maybe_init_from_pretrained_checkpoint(cfg) + trainer.fit(sortformer_model) + + if hasattr(cfg.model, 'test_ds') and cfg.model.test_ds.manifest_filepath is not None: + if sortformer_model.prepare_test(trainer): + trainer.test(sortformer_model) + + +if __name__ == '__main__': + main() diff --git a/examples/speaker_tasks/recognition/speaker_identification_infer.py b/examples/speaker_tasks/recognition/speaker_identification_infer.py index 90f930fcbfa6..7075a9f1f92a 100644 --- a/examples/speaker_tasks/recognition/speaker_identification_infer.py +++ b/examples/speaker_tasks/recognition/speaker_identification_infer.py @@ -16,8 +16,8 @@ import numpy as np import torch +from lightning.pytorch import seed_everything from omegaconf import OmegaConf -from pytorch_lightning import seed_everything from nemo.collections.asr.data.audio_to_label import AudioToSpeechLabelDataset from nemo.collections.asr.models import EncDecSpeakerLabelModel @@ -55,10 +55,18 @@ def main(cfg): speaker_model = EncDecSpeakerLabelModel.from_pretrained(model_path) enroll_embs, _, enroll_truelabels, _ = speaker_model.batch_inference( - enrollment_manifest, batch_size, sample_rate, device=device, + enrollment_manifest, + batch_size, + sample_rate, + device=device, ) - test_embs, _, _, _ = speaker_model.batch_inference(test_manifest, batch_size, sample_rate, device=device,) + test_embs, _, _, _ = speaker_model.batch_inference( + test_manifest, + batch_size, + sample_rate, + device=device, + ) # length normalize enroll_embs = enroll_embs / (np.linalg.norm(enroll_embs, ord=2, axis=-1, keepdims=True)) @@ -91,7 +99,12 @@ def main(cfg): "number of labels mis match. Make sure you trained or finetuned neural classifier with labels from enrollement manifest_filepath" ) - _, test_logits, _, _ = speaker_model.batch_inference(test_manifest, batch_size, sample_rate, device=device,) + _, test_logits, _, _ = speaker_model.batch_inference( + test_manifest, + batch_size, + sample_rate, + device=device, + ) matched_labels = test_logits.argmax(axis=-1) with open(test_manifest, 'rb') as f1, open(out_manifest, 'w', encoding='utf-8') as f2: diff --git a/examples/speaker_tasks/recognition/speaker_reco.py b/examples/speaker_tasks/recognition/speaker_reco.py index a8acd4de4a3f..ac5cb12ac836 100644 --- a/examples/speaker_tasks/recognition/speaker_reco.py +++ b/examples/speaker_tasks/recognition/speaker_reco.py @@ -14,10 +14,10 @@ import os -import pytorch_lightning as pl +import lightning.pytorch as pl import torch +from lightning.pytorch import seed_everything from omegaconf import OmegaConf -from pytorch_lightning import seed_everything from nemo.collections.asr.models import EncDecSpeakerLabelModel from nemo.core.config import hydra_runner diff --git a/examples/speaker_tasks/recognition/speaker_reco_finetune.py b/examples/speaker_tasks/recognition/speaker_reco_finetune.py index 884e5a60bc59..502d016a920d 100644 --- a/examples/speaker_tasks/recognition/speaker_reco_finetune.py +++ b/examples/speaker_tasks/recognition/speaker_reco_finetune.py @@ -14,10 +14,10 @@ import os -import pytorch_lightning as pl +import lightning.pytorch as pl import torch +from lightning.pytorch import seed_everything from omegaconf import OmegaConf -from pytorch_lightning import seed_everything from nemo.collections.asr.models import EncDecSpeakerLabelModel from nemo.core.config import hydra_runner diff --git a/examples/tts/aligner.py b/examples/tts/aligner.py index e32c0444ca68..939b8dbcf11f 100644 --- a/examples/tts/aligner.py +++ b/examples/tts/aligner.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl from nemo.collections.common.callbacks import LogEpochTimeCallback from nemo.collections.tts.models import AlignerModel diff --git a/examples/tts/audio_codec.py b/examples/tts/audio_codec.py index 5fc4b6fd0afd..d875a3037ba3 100644 --- a/examples/tts/audio_codec.py +++ b/examples/tts/audio_codec.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf import OmegaConf from nemo.collections.tts.models import AudioCodecModel diff --git a/examples/tts/fastpitch.py b/examples/tts/fastpitch.py index a8e6ecdc902d..7fd584b773e4 100644 --- a/examples/tts/fastpitch.py +++ b/examples/tts/fastpitch.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl from nemo.collections.common.callbacks import LogEpochTimeCallback from nemo.collections.tts.models import FastPitchModel diff --git a/examples/tts/fastpitch_finetune.py b/examples/tts/fastpitch_finetune.py index 64b5e8b90625..9bdf704c514c 100644 --- a/examples/tts/fastpitch_finetune.py +++ b/examples/tts/fastpitch_finetune.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl from nemo.collections.common.callbacks import LogEpochTimeCallback from nemo.collections.tts.models import FastPitchModel diff --git a/examples/tts/fastpitch_finetune_adapters.py b/examples/tts/fastpitch_finetune_adapters.py index 1361d63fb4cf..9b50d70ab15e 100644 --- a/examples/tts/fastpitch_finetune_adapters.py +++ b/examples/tts/fastpitch_finetune_adapters.py @@ -15,7 +15,7 @@ import os from dataclasses import is_dataclass -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf import DictConfig, OmegaConf, open_dict from nemo.collections.common.callbacks import LogEpochTimeCallback diff --git a/examples/tts/fastpitch_ssl.py b/examples/tts/fastpitch_ssl.py index 1101ac1eeaf7..b92983a4bfb1 100644 --- a/examples/tts/fastpitch_ssl.py +++ b/examples/tts/fastpitch_ssl.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl from nemo.collections.common.callbacks import LogEpochTimeCallback from nemo.collections.tts.models import fastpitch_ssl, hifigan diff --git a/examples/tts/g2p/g2p_heteronym_classification_inference.py b/examples/tts/g2p/g2p_heteronym_classification_inference.py index 61262c41a340..89a563e9b683 100644 --- a/examples/tts/g2p/g2p_heteronym_classification_inference.py +++ b/examples/tts/g2p/g2p_heteronym_classification_inference.py @@ -18,7 +18,7 @@ from dataclasses import dataclass, is_dataclass from typing import Optional -import pytorch_lightning as pl +import lightning.pytorch as pl import torch from omegaconf import OmegaConf @@ -56,9 +56,9 @@ class TranscriptionConfig: # path to .json manifest inference, if not provided, interactive mode will be enabled manifest: Optional[str] = None # Path to .json manifest - output_manifest: Optional[ - str - ] = "predictions.json" # Path to .json manifest to save prediction, will be saved in "pred_text" field + output_manifest: Optional[str] = ( + "predictions.json" # Path to .json manifest to save prediction, will be saved in "pred_text" field + ) grapheme_field: str = "text_graphemes" # name of the field in .json manifest for input grapheme text # mapping from wordid predicted by the model to phonemes, e.g., @@ -132,9 +132,10 @@ def main(cfg): save_errors = True correct = 0 total = 0 - with open(cfg.output_manifest, "r", encoding="utf-8") as f_preds, open( - cfg.errors_file, "w", encoding="utf-8" - ) as f_errors: + with ( + open(cfg.output_manifest, "r", encoding="utf-8") as f_preds, + open(cfg.errors_file, "w", encoding="utf-8") as f_errors, + ): for line in f_preds: line = json.loads(line) predictions = line["pred_wordid"] diff --git a/examples/tts/g2p/g2p_heteronym_classification_train_and_evaluate.py b/examples/tts/g2p/g2p_heteronym_classification_train_and_evaluate.py index 613865618501..f86a0a3934e4 100644 --- a/examples/tts/g2p/g2p_heteronym_classification_train_and_evaluate.py +++ b/examples/tts/g2p/g2p_heteronym_classification_train_and_evaluate.py @@ -14,7 +14,7 @@ import os -import pytorch_lightning as pl +import lightning.pytorch as pl import torch from nemo.collections.common.callbacks import LogEpochTimeCallback diff --git a/examples/tts/g2p/g2p_inference.py b/examples/tts/g2p/g2p_inference.py index e7bffa888653..a9da11fcffdb 100644 --- a/examples/tts/g2p/g2p_inference.py +++ b/examples/tts/g2p/g2p_inference.py @@ -16,7 +16,7 @@ from dataclasses import dataclass, is_dataclass from typing import Optional -import pytorch_lightning as pl +import lightning.pytorch as pl import torch from omegaconf import OmegaConf from utils import get_metrics @@ -41,23 +41,23 @@ class TranscriptionConfig: # Required configs pretrained_model: str # Path to a .nemo file or Name of a pretrained model manifest_filepath: str # Path to .json manifest file - phoneme_field: Optional[ - str - ] = None # name of the field in manifest_filepath for ground truth phonemes, default during training "text" + phoneme_field: Optional[str] = ( + None # name of the field in manifest_filepath for ground truth phonemes, default during training "text" + ) grapheme_field: Optional[str] = "text_graphemes" # name of the field in manifest_filepath for input grapheme text # General configs - output_file: Optional[ - str - ] = None # Path to .json manifest file to save predictions, will be saved in "target_field" + output_file: Optional[str] = ( + None # Path to .json manifest file to save predictions, will be saved in "target_field" + ) pred_field: Optional[str] = "pred_text" # name of the field in the output_file to save predictions batch_size: int = 32 # Batch size to use for inference num_workers: int = 0 # Number of workers to use for DataLoader during inference # Config for heteronyms correction - pretrained_heteronyms_model: Optional[ - str - ] = None # Path to a .nemo file or a Name of a pretrained model to disambiguate heteronyms (Optional) + pretrained_heteronyms_model: Optional[str] = ( + None # Path to a .nemo file or a Name of a pretrained model to disambiguate heteronyms (Optional) + ) @hydra_runner(config_name="TranscriptionConfig", schema=TranscriptionConfig) diff --git a/examples/tts/g2p/g2p_train_and_evaluate.py b/examples/tts/g2p/g2p_train_and_evaluate.py index ff7b2b0675ea..319e1fb6a776 100644 --- a/examples/tts/g2p/g2p_train_and_evaluate.py +++ b/examples/tts/g2p/g2p_train_and_evaluate.py @@ -14,7 +14,7 @@ import os -import pytorch_lightning as pl +import lightning.pytorch as pl import torch from utils import get_model diff --git a/examples/tts/hifigan.py b/examples/tts/hifigan.py index 5c3406a2f24c..6cf5c7a5aac4 100644 --- a/examples/tts/hifigan.py +++ b/examples/tts/hifigan.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl from nemo.collections.tts.models import HifiGanModel from nemo.core.config import hydra_runner diff --git a/examples/tts/hifigan_finetune.py b/examples/tts/hifigan_finetune.py index f0e2513404fd..328e1f423903 100644 --- a/examples/tts/hifigan_finetune.py +++ b/examples/tts/hifigan_finetune.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl from nemo.collections.tts.models import HifiGanModel from nemo.core.config import hydra_runner diff --git a/examples/tts/mixer_tts.py b/examples/tts/mixer_tts.py index 61a188f53969..53f55d93bcda 100644 --- a/examples/tts/mixer_tts.py +++ b/examples/tts/mixer_tts.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl from nemo.collections.common.callbacks import LogEpochTimeCallback from nemo.collections.tts.models import MixerTTSModel diff --git a/examples/tts/radtts.py b/examples/tts/radtts.py index 09bf69a2d6e5..4b3b0e62da87 100644 --- a/examples/tts/radtts.py +++ b/examples/tts/radtts.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl from nemo.collections.common.callbacks import LogEpochTimeCallback from nemo.collections.tts.models.radtts import RadTTSModel diff --git a/examples/tts/spectrogram_enhancer.py b/examples/tts/spectrogram_enhancer.py index 336729236d74..cd91ef3cb815 100644 --- a/examples/tts/spectrogram_enhancer.py +++ b/examples/tts/spectrogram_enhancer.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl from nemo.collections.tts.models import SpectrogramEnhancerModel from nemo.core.config import hydra_runner diff --git a/examples/tts/speechllm/conf/megatron_t5_speechllm_inference.yaml b/examples/tts/speechllm/conf/megatron_t5_speechllm_inference.yaml new file mode 100644 index 000000000000..8b37077bfdd5 --- /dev/null +++ b/examples/tts/speechllm/conf/megatron_t5_speechllm_inference.yaml @@ -0,0 +1,160 @@ +name: megatron_t5_speechllm_tts_inference +checkpoint_path: ??? + +trainer: + devices: 1 + accelerator: gpu + num_nodes: 1 + precision: 32 + logger: False + enable_checkpointing: False + use_distributed_sampler: False + max_epochs: 10000 + max_steps: -1 + log_every_n_steps: 10 + val_check_interval: null + check_val_every_n_epoch: 3 + gradient_clip_val: 1.0 + +exp_manager: + exp_dir: null + name: ${name} + create_wandb_logger: False + resume_if_exists: False + resume_ignore_no_checkpoint: True + create_checkpoint_callback: True + checkpoint_callback_params: + monitor: val_loss + save_top_k: 2 + mode: min + save_nemo_on_train_end: False # Should be false, correct prompt learning model file is saved at model.nemo_path set below + filename: "megatron_t5_speechllm_tts--{${exp_manager.checkpoint_callback_params.monitor}:.3f}-{step}" + model_parallel_size: ${model.tensor_model_parallel_size} + save_best_model: True + create_early_stopping_callback: False + early_stopping_callback_params: + monitor: "val_loss" + mode: "min" + min_delta: 0.001 + patience: 10 + verbose: True + +model: + seed: 1234 + nemo_path: ${name}.nemo # .nemo filename/absolute path to where the virtual prompt model parameters will be saved + virtual_prompt_style: "p-tuning" # one of 'prompt-tuning', 'p-tuning', or 'inference' + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + global_batch_size: 16 + micro_batch_size: 16 # micro batch size should equal global batch size when pipeline parallel = 1 + validation_global_batch_size: ${model.global_batch_size} + validation_micro_batch_size: ${model.micro_batch_size} + validation_drop_last: False + report_validation_metric: False + validation_metric: accuracy + num_speech_tokens: 10112 # Vocabulary size pertaining to speech + seq_pattern: "parallel" # parallel, delay_parallel, flatten + temperature: 0.85 # Temperature to be used for inference + top_k: 80 # Top k to be used for inference + max_inference_timesteps: 1000 # Maximum number of timesteps to run inference for + + restore_path: null # Path to an existing p-tuned/prompt tuned .nemo model you wish to add new tasks to or run inference with + language_model_path: ??? # Path to the pretrained T5 language model .nemo file, always required + save_nemo_on_validation_end: True # Saves an inference ready .nemo file every time a checkpoint is saved during training. + existing_tasks: [] + new_tasks: ["squad"] + codecmodel_type: nemo_codec + codecmodel_path: ??? + english_only_model: true + context_conditioning: decoder + use_flash_attention: false + lm_vocab_size: 30000 + task_templates: + - taskname: "squad" + prompt_template: "<|VIRTUAL_PROMPT_0|> {context} {question} {answer}" + total_virtual_tokens: 3 + virtual_token_splits: [3] + truncate_field: context + answer_field: answer + + p_tuning: # P-tuning specific params + encoder_type: "mlp" # Either "mlp" or "lstm", mlp is default + num_layers: 2 # 2 recommended for MLP, 1 recommended for LSTM, must be at least 2 for mlp + dropout: 0.0 + + prompt_tuning: # Prompt tunin specific params + new_prompt_init_methods: ['text'] # List of 'text' or 'random', should correspond to tasks listed in new tasks + new_prompt_init_text: ['some init text goes here'] # some init text if init method is text, or None if init method is random + + data: + grapheme_prefix: null + train_ds: null + validation_ds: null + test_ds: ??? + max_seq_length: 1536 + sample_rate: 24000 + add_eos: true + add_bos: false + decoder_starts_with_pad: False + add_eos_to_decoder_output: True + add_sentinel_to_input: True + ul2_prompt_token: null # , , + shuffle: true + num_workers: 4 + pin_memory: true + speech_offset: 30000 + train_task: all + sup_data_path: None + num_speech_codebooks: 8 + codebook_fps: 86 + context_duration_min: 2.9 + context_duration_max: 2.9 + context_slice_method: "fixed" + phoneme_probability: 1.0 + g2p: + english: + _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p + phoneme_dict: "scripts/tts_dataset_files/ipa_cmudict-0.7b_nv23.01.txt" + heteronyms: "scripts/tts_dataset_files/heteronyms-052722" + phoneme_probability: 0.8 + ignore_ambiguous_words: False + use_chars: True + use_stresses: True + grapheme_prefix: ${model.data.grapheme_prefix} + spanish: + _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p + phoneme_dict: "scripts/tts_dataset_files/es_ES/es_ES_nv230301.dict" + phoneme_probability: 0.8 + use_chars: True + use_stresses: True + ignore_ambiguous_words: False + grapheme_prefix: ${model.data.grapheme_prefix} + locale: "es-ES" + mandarin: + _target_: nemo.collections.tts.g2p.models.zh_cn_pinyin.ChineseG2p + phoneme_dict: "scripts/tts_dataset_files/zh/36finals/ipa_dict_nv23.05.txt" + word_segmenter: "jieba" + phoneme_prefix: "" + phoneme_case: "lower" + tone_prefix: "#" + ascii_letter_prefix: ${model.data.grapheme_prefix} + ascii_letter_case: "upper" + german: + _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p + phoneme_dict: "scripts/tts_dataset_files/de/de_nv230119.dict" + heteronyms: "scripts/tts_dataset_files/de/de_nv230119.heteronym" + phoneme_probability: 0.8 + ignore_ambiguous_words: False + use_chars: True + use_stresses: True + grapheme_case: mixed + grapheme_prefix: ${model.data.grapheme_prefix} + locale: "de-DE" + + optim: + name: fused_adam + lr: 5e-5 + weight_decay: 0.01 + betas: + - 0.9 + - 0.98 \ No newline at end of file diff --git a/examples/tts/speechllm/conf/megatron_t5_speechllm_inference_model.yaml b/examples/tts/speechllm/conf/megatron_t5_speechllm_inference_model.yaml new file mode 100644 index 000000000000..1858edf9e667 --- /dev/null +++ b/examples/tts/speechllm/conf/megatron_t5_speechllm_inference_model.yaml @@ -0,0 +1,213 @@ +name: megatron_t5_speechllm_tts_inference +checkpoint_path: ??? + +trainer: + devices: 1 + accelerator: gpu + num_nodes: 1 + precision: 32 + logger: False + enable_checkpointing: False + use_distributed_sampler: False + max_epochs: 10000 + max_steps: -1 + log_every_n_steps: 10 + val_check_interval: null + check_val_every_n_epoch: 3 + gradient_clip_val: 1.0 + +exp_manager: + exp_dir: null + name: ${name} + create_wandb_logger: False + resume_if_exists: False + resume_ignore_no_checkpoint: True + create_checkpoint_callback: True + checkpoint_callback_params: + monitor: val_loss + save_top_k: 2 + mode: min + save_nemo_on_train_end: False # Should be false, correct prompt learning model file is saved at model.nemo_path set below + filename: "megatron_t5_speechllm_tts--{${exp_manager.checkpoint_callback_params.monitor}:.3f}-{step}" + model_parallel_size: ${model.tensor_model_parallel_size} + save_best_model: True + create_early_stopping_callback: False + early_stopping_callback_params: + monitor: "val_loss" + mode: "min" + min_delta: 0.001 + patience: 10 + verbose: True + +model: + seed: 1234 + nemo_path: ${name}.nemo # .nemo filename/absolute path to where the virtual prompt model parameters will be saved + virtual_prompt_style: "p-tuning" # one of 'prompt-tuning', 'p-tuning', or 'inference' + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + global_batch_size: 16 + micro_batch_size: 16 # micro batch size should equal global batch size when pipeline parallel = 1 + validation_global_batch_size: ${model.global_batch_size} + validation_micro_batch_size: ${model.micro_batch_size} + validation_drop_last: False + report_validation_metric: False + validation_metric: accuracy + num_speech_tokens: 10112 # Vocabulary size pertaining to speech + seq_pattern: "parallel" # parallel, delay_parallel, flatten + temperature: 0.85 # Temperature to be used for inference + top_k: 80 # Top k to be used for inference + max_inference_timesteps: 1000 # Maximum number of timesteps to run inference for + + restore_path: null # Path to an existing p-tuned/prompt tuned .nemo model you wish to add new tasks to or run inference with + save_nemo_on_validation_end: True # Saves an inference ready .nemo file every time a checkpoint is saved during training. + existing_tasks: [] + new_tasks: ["squad"] + codecmodel_type: nemo_codec + codecmodel_path: ??? + english_only_model: true + context_conditioning: decoder + train_from_scratch: true + override_tokenizer_vocab_file: ??? + use_flash_attention: false + lm_vocab_size: 30000 + + frozen_model: + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + pipeline_model_parallel_split_rank: 0 + make_vocab_size_divisible_by: 128 + pre_process: true + post_process: true + gradient_as_bucket_view: true + native_amp_init_scale: 4294967296 + native_amp_growth_interval: 1000 + fp16_lm_cross_entropy: false + seed: 1234 + use_cpu_initialization: false + apex_transformer_log_level: 30 + tokenizer: + library: megatron + type: BertWordPieceCase + model: null + vocab_file: null + merge_file: null + optim: + name: null + data: + dataset_type: t5 + encoder: + arch: transformer + bias_activation_fusion: false + use_flash_attention: ${model.use_flash_attention} + num_layers: 12 + hidden_size: 768 + ffn_hidden_size: 2048 + num_attention_heads: 12 + init_method_std: 0.015 + hidden_dropout: 0.1 + attention_dropout: 0.1 + kv_channels: 64 + activation: geglu + decoder: + arch: transformer + bias_activation_fusion: false + use_flash_attention: ${model.use_flash_attention} + num_layers: 12 + hidden_size: 768 + ffn_hidden_size: 2048 + num_attention_heads: 12 + init_method_std: 0.015 + hidden_dropout: 0.1 + attention_dropout: 0.1 + kv_channels: 64 + activation: geglu + + task_templates: + - taskname: "squad" + prompt_template: "<|VIRTUAL_PROMPT_0|> {context} {question} {answer}" + total_virtual_tokens: 3 + virtual_token_splits: [3] + truncate_field: context + answer_field: answer + + p_tuning: # P-tuning specific params + encoder_type: "mlp" # Either "mlp" or "lstm", mlp is default + num_layers: 2 # 2 recommended for MLP, 1 recommended for LSTM, must be at least 2 for mlp + dropout: 0.0 + + prompt_tuning: # Prompt tunin specific params + new_prompt_init_methods: ['text'] # List of 'text' or 'random', should correspond to tasks listed in new tasks + new_prompt_init_text: ['some init text goes here'] # some init text if init method is text, or None if init method is random + + data: + grapheme_prefix: null + train_ds: null + validation_ds: null + test_ds: ??? + max_seq_length: 1536 + sample_rate: 24000 + add_eos: true + add_bos: false + decoder_starts_with_pad: False + add_eos_to_decoder_output: True + add_sentinel_to_input: True + ul2_prompt_token: null # , , + shuffle: true + num_workers: 4 + pin_memory: true + speech_offset: 30000 + train_task: all + sup_data_path: None + num_speech_codebooks: 8 + codebook_fps: 86 + context_duration_min: 2.9 + context_duration_max: 2.9 + context_slice_method: "fixed" + phoneme_probability: 1.0 + g2p: + english: + _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p + phoneme_dict: "scripts/tts_dataset_files/ipa_cmudict-0.7b_nv23.01.txt" + heteronyms: "scripts/tts_dataset_files/heteronyms-052722" + phoneme_probability: 0.8 + ignore_ambiguous_words: False + use_chars: True + use_stresses: True + grapheme_prefix: ${model.data.grapheme_prefix} + spanish: + _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p + phoneme_dict: "scripts/tts_dataset_files/es_ES/es_ES_nv230301.dict" + phoneme_probability: 0.8 + use_chars: True + use_stresses: True + ignore_ambiguous_words: False + grapheme_prefix: ${model.data.grapheme_prefix} + locale: "es-ES" + mandarin: + _target_: nemo.collections.tts.g2p.models.zh_cn_pinyin.ChineseG2p + phoneme_dict: "scripts/tts_dataset_files/zh/36finals/ipa_dict_nv23.05.txt" + word_segmenter: "jieba" + phoneme_prefix: "" + phoneme_case: "lower" + tone_prefix: "#" + ascii_letter_prefix: ${model.data.grapheme_prefix} + ascii_letter_case: "upper" + german: + _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p + phoneme_dict: "scripts/tts_dataset_files/de/de_nv230119.dict" + heteronyms: "scripts/tts_dataset_files/de/de_nv230119.heteronym" + phoneme_probability: 0.8 + ignore_ambiguous_words: False + use_chars: True + use_stresses: True + grapheme_case: mixed + grapheme_prefix: ${model.data.grapheme_prefix} + locale: "de-DE" + + optim: + name: fused_adam + lr: 5e-5 + weight_decay: 0.01 + betas: + - 0.9 + - 0.98 diff --git a/examples/tts/speechllm/conf/megatron_t5_speechllm_inference_multiencoder.yaml b/examples/tts/speechllm/conf/megatron_t5_speechllm_inference_multiencoder.yaml new file mode 100644 index 000000000000..8ad967d20538 --- /dev/null +++ b/examples/tts/speechllm/conf/megatron_t5_speechllm_inference_multiencoder.yaml @@ -0,0 +1,218 @@ +name: megatron_t5_speechllm +checkpoint_path: ??? + +trainer: + devices: 1 + accelerator: gpu + num_nodes: 1 + precision: 32 + logger: False + enable_checkpointing: False + use_distributed_sampler: False + max_epochs: -1 + max_steps: 250000 + log_every_n_steps: 10 + val_check_interval: null + check_val_every_n_epoch: 1 + +exp_manager: + explicit_log_dir: null + exp_dir: null + name: ${name} + create_wandb_logger: False + resume_if_exists: False + resume_ignore_no_checkpoint: True + create_checkpoint_callback: True + checkpoint_callback_params: + monitor: val_loss + save_top_k: 3 + mode: min + save_nemo_on_train_end: False # Should be false, correct prompt learning model file is saved at model.nemo_path set below + filename: "megatron_t5_speechllm_tts--{${exp_manager.checkpoint_callback_params.monitor}:.3f}-{step}" + model_parallel_size: ${model.tensor_model_parallel_size} + save_best_model: True + create_early_stopping_callback: False + early_stopping_callback_params: + monitor: "val_loss" + mode: "min" + min_delta: 0.001 + patience: 10 + verbose: True + +model: + seed: 1234 + nemo_path: ${name}.nemo # .nemo filename/absolute path to where the virtual prompt model parameters will be saved + virtual_prompt_style: "p-tuning" # one of 'prompt-tuning', 'p-tuning', or 'inference' + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + global_batch_size: 2 + micro_batch_size: 2 # micro batch size should equal global batch size when pipeline parallel = 1 + validation_global_batch_size: ${model.global_batch_size} + validation_micro_batch_size: ${model.micro_batch_size} + validation_drop_last: False + report_validation_metric: False + validation_metric: accuracy + num_speech_tokens: 10112 # Vocabulary size pertaining to speech + seq_pattern: "parallel" # parallel, delay_parallel, flatten + temperature: 0.85 # Temperature to be used for inference + top_k: 80 # Top k to be used for inference + max_inference_timesteps: 2000 # Maximum number of timesteps to run inference for + restore_path: null # Path to an existing p-tuned/prompt tuned .nemo model you wish to add new tasks to or run inference with + save_nemo_on_validation_end: True # Saves an inference ready .nemo file every time a checkpoint is saved during training. + existing_tasks: [] + new_tasks: ["squad"] + codecmodel_type: nemo_codec + codecmodel_path: ??? + english_only_model: true + context_conditioning: encoder + train_from_scratch: true + override_tokenizer_vocab_file: ??? + use_flash_attention: false + lm_vocab_size: 30000 + enc_output_to_layers: [[0,1,2],[3,4,5,6,7,8]] + + frozen_model: + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + pipeline_model_parallel_split_rank: 0 + make_vocab_size_divisible_by: 128 + pre_process: true + post_process: true + gradient_as_bucket_view: true + native_amp_init_scale: 4294967296 + native_amp_growth_interval: 1000 + fp16_lm_cross_entropy: false + seed: 1234 + use_cpu_initialization: false + apex_transformer_log_level: 30 + tokenizer: + library: megatron + type: BertWordPieceCase + model: null + vocab_file: null + merge_file: null + optim: + name: null + data: + dataset_type: t5 + encoder: + arch: multi_transformer + n_transformers: 2 + bias_activation_fusion: false + use_flash_attention: ${model.use_flash_attention} + num_layers: 6 + hidden_size: 768 + ffn_hidden_size: 2048 + num_attention_heads: 12 + init_method_std: 0.015 + hidden_dropout: 0.1 + attention_dropout: 0.1 + kv_channels: 64 + activation: geglu + decoder: + arch: transformer + bias_activation_fusion: false + use_flash_attention: ${model.use_flash_attention} + num_layers: 12 + hidden_size: 768 + ffn_hidden_size: 2048 + num_attention_heads: 12 + init_method_std: 0.015 + hidden_dropout: 0.1 + attention_dropout: 0.1 + kv_channels: 64 + activation: geglu + + task_templates: + - taskname: "squad" + prompt_template: "<|VIRTUAL_PROMPT_0|> {context} {question} {answer}" + total_virtual_tokens: 3 + virtual_token_splits: [3] + truncate_field: context + answer_field: answer + + p_tuning: # P-tuning specific params + encoder_type: "mlp" # Either "mlp" or "lstm", mlp is default + num_layers: 2 # 2 recommended for MLP, 1 recommended for LSTM, must be at least 2 for mlp + dropout: 0.0 + + prompt_tuning: # Prompt tunin specific params + new_prompt_init_methods: ['text'] # List of 'text' or 'random', should correspond to tasks listed in new tasks + new_prompt_init_text: ['some init text goes here'] # some init text if init method is text, or None if init method is random + + data: + grapheme_prefix: null + train_ds: null + validation_ds: null + test_ds: ??? + max_seq_length: 2048 + sample_rate: 24000 + add_eos: true + add_bos: false + use_attention_prior: false + attention_prior_scaling_factor: 0.05 + cross_attention_epsilon: 0.0 + decoder_starts_with_pad: False + add_eos_to_decoder_output: True + add_sentinel_to_input: True + ul2_prompt_token: null # , , + shuffle: true + num_workers: 4 + pin_memory: true + speech_offset: 30128 + train_task: all + sup_data_path: None + num_speech_codebooks: 8 + codebook_fps: 86 + context_duration_min: 2.9 + context_duration_max: 2.9 + context_slice_method: "fixed" + phoneme_probability: 1.0 + encoder_type: ${model.frozen_model.encoder.arch} + g2p: + english: + _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p + phoneme_dict: "scripts/tts_dataset_files/ipa_cmudict-0.7b_nv23.01.txt" + heteronyms: "scripts/tts_dataset_files/heteronyms-052722" + phoneme_probability: 0.8 + ignore_ambiguous_words: False + use_chars: True + use_stresses: True + grapheme_prefix: ${model.data.grapheme_prefix} + spanish: + _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p + phoneme_dict: "scripts/tts_dataset_files/es_ES/es_ES_nv230301.dict" + phoneme_probability: 0.8 + use_chars: True + use_stresses: True + ignore_ambiguous_words: False + grapheme_prefix: ${model.data.grapheme_prefix} + locale: "es-ES" + mandarin: + _target_: nemo.collections.tts.g2p.models.zh_cn_pinyin.ChineseG2p + phoneme_dict: "scripts/tts_dataset_files/zh/36finals/ipa_dict_nv23.05.txt" + word_segmenter: "jieba" + phoneme_prefix: "" + phoneme_case: "lower" + tone_prefix: "#" + ascii_letter_prefix: ${model.data.grapheme_prefix} + ascii_letter_case: "upper" + german: + _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p + phoneme_dict: "scripts/tts_dataset_files/de/de_nv230119.dict" + heteronyms: "scripts/tts_dataset_files/de/de_nv230119.heteronym" + phoneme_probability: 0.8 + ignore_ambiguous_words: False + use_chars: True + use_stresses: True + grapheme_case: mixed + grapheme_prefix: ${model.data.grapheme_prefix} + locale: "de-DE" + + optim: + name: fused_adam + lr: 1e-4 + weight_decay: 0.01 + betas: + - 0.9 + - 0.98 diff --git a/examples/tts/speechllm/conf/megatron_t5_speechllm_medium.yaml b/examples/tts/speechllm/conf/megatron_t5_speechllm_medium.yaml new file mode 100644 index 000000000000..bd31f0712fdf --- /dev/null +++ b/examples/tts/speechllm/conf/megatron_t5_speechllm_medium.yaml @@ -0,0 +1,161 @@ +name: megatron_t5_speechllm_medium + +trainer: + devices: 1 + accelerator: gpu + num_nodes: 1 + precision: 32 + logger: False + enable_checkpointing: False + use_distributed_sampler: False + max_epochs: -1 + max_steps: 1000000 + log_every_n_steps: 10 + val_check_interval: null + check_val_every_n_epoch: 1 + +exp_manager: + explicit_log_dir: null + exp_dir: null + name: ${name} + create_wandb_logger: False + resume_if_exists: True + resume_ignore_no_checkpoint: True + create_checkpoint_callback: True + checkpoint_callback_params: + monitor: val_loss + save_top_k: 3 + mode: min + save_nemo_on_train_end: False + filename: "megatron_t5_speechllm_tts--{${exp_manager.checkpoint_callback_params.monitor}:.3f}-{step}" + model_parallel_size: ${model.tensor_model_parallel_size} + save_best_model: True + create_early_stopping_callback: False + early_stopping_callback_params: + monitor: "val_loss" + mode: "min" + min_delta: 0.001 + patience: 10 + verbose: True + +model: + seed: 1234 + nemo_path: ${name}.nemo # .nemo filename/absolute path to where the virtual prompt model parameters will be saved + virtual_prompt_style: "p-tuning" # one of 'prompt-tuning', 'p-tuning', or 'inference' + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + global_batch_size: 2 + micro_batch_size: 2 # micro batch size should equal global batch size when pipeline parallel = 1 + validation_global_batch_size: ${model.global_batch_size} + validation_micro_batch_size: ${model.micro_batch_size} + validation_drop_last: False + report_validation_metric: False + validation_metric: accuracy + num_speech_tokens: 10112 # Vocabulary size pertaining to speech + seq_pattern: "parallel" # parallel, delay_parallel, flatten + attn_prior_scaledown_start_step: 10000 + attn_prior_end_step: 11000 + return_all_crossattention_probs: True + num_cross_attention_heads: 12 # 12 for 220m, 16 for 3b. + restore_path: null # Path to an existing p-tuned/prompt tuned .nemo model you wish to add new tasks to or run inference with + language_model_path: ??? # Path to the pretrained T5 language model .nemo file, always required + save_nemo_on_validation_end: True # Saves an inference ready .nemo file every time a checkpoint is saved during training. + existing_tasks: [] + new_tasks: ["squad"] + freeze_model: false + use_alignment_loss: true + codecmodel_type: nemo_codec + codecmodel_path: ??? + english_only_model: true + context_conditioning: decoder + use_flash_attention: false + lm_vocab_size: 30000 + + task_templates: + - taskname: "squad" + prompt_template: "<|VIRTUAL_PROMPT_0|> {context} {question} {answer}" + total_virtual_tokens: 3 + virtual_token_splits: [3] + truncate_field: context + answer_field: answer + + p_tuning: # P-tuning specific params + encoder_type: "mlp" # Either "mlp" or "lstm", mlp is default + num_layers: 2 # 2 recommended for MLP, 1 recommended for LSTM, must be at least 2 for mlp + dropout: 0.0 + + prompt_tuning: # Prompt tunin specific params + new_prompt_init_methods: ['text'] # List of 'text' or 'random', should correspond to tasks listed in new tasks + new_prompt_init_text: ['some init text goes here'] # some init text if init method is text, or None if init method is random + + data: + grapheme_prefix: null + train_ds: ??? + validation_ds: ??? + max_seq_length: 2048 + sample_rate: 24000 + add_eos: true + add_bos: false + use_attention_prior: true + attention_prior_scaling_factor: 0.05 + cross_attention_epsilon: 0.0 + decoder_starts_with_pad: False + add_eos_to_decoder_output: True + add_sentinel_to_input: True + ul2_prompt_token: null # , , + shuffle: true + num_workers: 4 + pin_memory: true + speech_offset: 30128 + train_task: all + num_speech_codebooks: 8 + codebook_fps: 86 + context_duration_min: 2.9 + context_duration_max: 2.9 + g2p: + english: + _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p + phoneme_dict: "scripts/tts_dataset_files/ipa_cmudict-0.7b_nv23.01.txt" + heteronyms: "scripts/tts_dataset_files/heteronyms-052722" + phoneme_probability: 0.8 + ignore_ambiguous_words: False + use_chars: True + use_stresses: True + grapheme_prefix: ${model.data.grapheme_prefix} + spanish: + _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p + phoneme_dict: "scripts/tts_dataset_files/es_ES/es_ES_nv230301.dict" + phoneme_probability: 0.8 + use_chars: True + use_stresses: True + ignore_ambiguous_words: False + grapheme_prefix: ${model.data.grapheme_prefix} + locale: "es-ES" + mandarin: + _target_: nemo.collections.tts.g2p.models.zh_cn_pinyin.ChineseG2p + phoneme_dict: "scripts/tts_dataset_files/zh/36finals/ipa_dict_nv23.05.txt" + word_segmenter: "jieba" + phoneme_prefix: "" + phoneme_case: "lower" + tone_prefix: "#" + ascii_letter_prefix: ${model.data.grapheme_prefix} + ascii_letter_case: "upper" + german: + _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p + phoneme_dict: "scripts/tts_dataset_files/de/de_nv230119.dict" + heteronyms: "scripts/tts_dataset_files/de/de_nv230119.heteronym" + phoneme_probability: 0.8 + ignore_ambiguous_words: False + use_chars: True + use_stresses: True + grapheme_case: mixed + grapheme_prefix: ${model.data.grapheme_prefix} + locale: "de-DE" + + optim: + name: fused_adam + lr: 1e-4 + weight_decay: 0.01 + betas: + - 0.9 + - 0.98 \ No newline at end of file diff --git a/examples/tts/speechllm/conf/megatron_t5_speechllm_multiencoder.yaml b/examples/tts/speechllm/conf/megatron_t5_speechllm_multiencoder.yaml new file mode 100644 index 000000000000..bf3f65ff9e00 --- /dev/null +++ b/examples/tts/speechllm/conf/megatron_t5_speechllm_multiencoder.yaml @@ -0,0 +1,223 @@ +name: megatron_t5_speechllm + +trainer: + devices: 1 + accelerator: gpu + num_nodes: 1 + precision: 32 + logger: False + enable_checkpointing: False + use_distributed_sampler: False + max_epochs: -1 + max_steps: 250000 + log_every_n_steps: 10 + val_check_interval: null + check_val_every_n_epoch: 1 + +exp_manager: + explicit_log_dir: null + exp_dir: null + name: ${name} + create_wandb_logger: False + resume_if_exists: True + resume_ignore_no_checkpoint: True + create_checkpoint_callback: True + checkpoint_callback_params: + monitor: val_loss + save_top_k: 3 + mode: min + save_nemo_on_train_end: False # Should be false, correct prompt learning model file is saved at model.nemo_path set below + filename: "megatron_t5_speechllm_tts--{${exp_manager.checkpoint_callback_params.monitor}:.3f}-{step}" + model_parallel_size: ${model.tensor_model_parallel_size} + save_best_model: True + create_early_stopping_callback: False + early_stopping_callback_params: + monitor: "val_loss" + mode: "min" + min_delta: 0.001 + patience: 10 + verbose: True + +model: + seed: 1234 + nemo_path: ${name}.nemo # .nemo filename/absolute path to where the virtual prompt model parameters will be saved + virtual_prompt_style: "p-tuning" # one of 'prompt-tuning', 'p-tuning', or 'inference' + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + global_batch_size: 2 + micro_batch_size: 2 # micro batch size should equal global batch size when pipeline parallel = 1 + validation_global_batch_size: ${model.global_batch_size} + validation_micro_batch_size: ${model.micro_batch_size} + validation_drop_last: False + report_validation_metric: False + validation_metric: accuracy + num_speech_tokens: 10112 # Vocabulary size pertaining to speech + seq_pattern: "parallel" # parallel, delay_parallel, flatten + attn_prior_scaledown_start_step: 10000 + attn_prior_end_step: 11000 + return_all_crossattention_probs: True + num_cross_attention_heads: 12 # 12 for 220m, 16 for 3b. + restore_path: null # Path to an existing p-tuned/prompt tuned .nemo model you wish to add new tasks to or run inference with + save_nemo_on_validation_end: True # Saves an inference ready .nemo file every time a checkpoint is saved during training. + existing_tasks: [] + new_tasks: ["squad"] + freeze_model: false + use_alignment_loss: true + codecmodel_type: nemo_codec + codecmodel_path: ??? + english_only_model: true + context_conditioning: encoder + train_from_scratch: true + override_tokenizer_vocab_file: ??? + use_flash_attention: false + lm_vocab_size: 30000 + enc_output_to_layers: [[0,1,2],[3,4,5,6,7,8]] + + frozen_model: + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + pipeline_model_parallel_split_rank: 0 + make_vocab_size_divisible_by: 128 + pre_process: true + post_process: true + gradient_as_bucket_view: true + native_amp_init_scale: 4294967296 + native_amp_growth_interval: 1000 + fp16_lm_cross_entropy: false + seed: 1234 + use_cpu_initialization: false + apex_transformer_log_level: 30 + tokenizer: + library: megatron + type: BertWordPieceCase + model: null + vocab_file: null + merge_file: null + optim: + name: null + data: + dataset_type: t5 + encoder: + arch: multi_transformer + n_transformers: 2 + bias_activation_fusion: false + use_flash_attention: ${model.use_flash_attention} + num_layers: 6 + hidden_size: 768 + ffn_hidden_size: 2048 + num_attention_heads: 12 + init_method_std: 0.015 + hidden_dropout: 0.1 + attention_dropout: 0.1 + kv_channels: 64 + activation: geglu + decoder: + arch: transformer + bias_activation_fusion: false + use_flash_attention: ${model.use_flash_attention} + num_layers: 12 + hidden_size: 768 + ffn_hidden_size: 2048 + num_attention_heads: 12 + init_method_std: 0.015 + hidden_dropout: 0.1 + attention_dropout: 0.1 + kv_channels: 64 + activation: geglu + + task_templates: + - taskname: "squad" + prompt_template: "<|VIRTUAL_PROMPT_0|> {context} {question} {answer}" + total_virtual_tokens: 3 + virtual_token_splits: [3] + truncate_field: context + answer_field: answer + + p_tuning: # P-tuning specific params + encoder_type: "mlp" # Either "mlp" or "lstm", mlp is default + num_layers: 2 # 2 recommended for MLP, 1 recommended for LSTM, must be at least 2 for mlp + dropout: 0.0 + + prompt_tuning: # Prompt tunin specific params + new_prompt_init_methods: ['text'] # List of 'text' or 'random', should correspond to tasks listed in new tasks + new_prompt_init_text: ['some init text goes here'] # some init text if init method is text, or None if init method is random + + data: + grapheme_prefix: null + train_ds: ??? + validation_ds: ??? + max_seq_length: 2048 + sample_rate: 24000 + add_eos: true + add_bos: false + use_attention_prior: true + attention_prior_scaling_factor: 0.05 + cross_attention_epsilon: 0.0 + decoder_starts_with_pad: False + add_eos_to_decoder_output: True + add_sentinel_to_input: True + ul2_prompt_token: null # , , + shuffle: true + num_workers: 4 + pin_memory: true + speech_offset: 30128 + train_task: all + num_speech_codebooks: 8 + codebook_fps: 86 + context_duration_min: 2.9 + context_duration_max: 2.9 + encoder_type: ${model.frozen_model.encoder.arch} + g2p: + english: + _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p + phoneme_dict: "scripts/tts_dataset_files/ipa_cmudict-0.7b_nv23.01.txt" + heteronyms: "scripts/tts_dataset_files/heteronyms-052722" + phoneme_probability: 0.8 + ignore_ambiguous_words: False + use_chars: True + use_stresses: True + grapheme_prefix: ${model.data.grapheme_prefix} + spanish: + _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p + phoneme_dict: "scripts/tts_dataset_files/es_ES/es_ES_nv230301.dict" + phoneme_probability: 0.8 + use_chars: True + use_stresses: True + ignore_ambiguous_words: False + grapheme_prefix: ${model.data.grapheme_prefix} + locale: "es-ES" + mandarin: + _target_: nemo.collections.tts.g2p.models.zh_cn_pinyin.ChineseG2p + phoneme_dict: "scripts/tts_dataset_files/zh/36finals/ipa_dict_nv23.05.txt" + word_segmenter: "jieba" + phoneme_prefix: "" + phoneme_case: "lower" + tone_prefix: "#" + ascii_letter_prefix: ${model.data.grapheme_prefix} + ascii_letter_case: "upper" + german: + _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p + phoneme_dict: "scripts/tts_dataset_files/de/de_nv230119.dict" + heteronyms: "scripts/tts_dataset_files/de/de_nv230119.heteronym" + phoneme_probability: 0.8 + ignore_ambiguous_words: False + use_chars: True + use_stresses: True + grapheme_case: mixed + grapheme_prefix: ${model.data.grapheme_prefix} + locale: "de-DE" + + optim: + name: fused_adam + lr: 1e-4 + weight_decay: 0.01 + betas: + - 0.9 + - 0.98 + sched: + name: CosineAnnealing + warmup_steps: 1000 + constant_steps: 0 + min_lr: 1e-5 + monitor: val_loss + reduce_on_plateau: false diff --git a/examples/tts/speechllm/conf/megatron_t5_speechlm_model.yaml b/examples/tts/speechllm/conf/megatron_t5_speechlm_model.yaml new file mode 100644 index 000000000000..d69bfb979182 --- /dev/null +++ b/examples/tts/speechllm/conf/megatron_t5_speechlm_model.yaml @@ -0,0 +1,221 @@ +name: megatron_t5_speechllm + +trainer: + devices: 1 + accelerator: gpu + num_nodes: 1 + precision: 32 + logger: False + enable_checkpointing: False + use_distributed_sampler: False + max_epochs: -1 + max_steps: 250000 + log_every_n_steps: 10 + val_check_interval: null + check_val_every_n_epoch: 1 + +exp_manager: + explicit_log_dir: null + exp_dir: null + name: ${name} + create_wandb_logger: False + resume_if_exists: True + resume_ignore_no_checkpoint: True + create_checkpoint_callback: True + checkpoint_callback_params: + monitor: val_loss + save_top_k: 3 + mode: min + save_nemo_on_train_end: False # Should be false, correct prompt learning model file is saved at model.nemo_path set below + filename: "megatron_t5_speechllm_tts--{${exp_manager.checkpoint_callback_params.monitor}:.3f}-{step}" + model_parallel_size: ${model.tensor_model_parallel_size} + save_best_model: True + create_early_stopping_callback: False + early_stopping_callback_params: + monitor: "val_loss" + mode: "min" + min_delta: 0.001 + patience: 10 + verbose: True + +model: + seed: 1234 + nemo_path: ${name}.nemo # .nemo filename/absolute path to where the virtual prompt model parameters will be saved + virtual_prompt_style: "p-tuning" # one of 'prompt-tuning', 'p-tuning', or 'inference' + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + global_batch_size: 2 + micro_batch_size: 2 # micro batch size should equal global batch size when pipeline parallel = 1 + validation_global_batch_size: ${model.global_batch_size} + validation_micro_batch_size: ${model.micro_batch_size} + validation_drop_last: False + report_validation_metric: False + validation_metric: accuracy + num_speech_tokens: 10112 # Vocabulary size pertaining to speech + seq_pattern: "parallel" # parallel, delay_parallel, flatten + attn_prior_scaledown_start_step: 10000 + attn_prior_end_step: 11000 + return_all_crossattention_probs: True + num_cross_attention_heads: 12 # 12 for 220m, 16 for 3b. + restore_path: null # Path to an existing p-tuned/prompt tuned .nemo model you wish to add new tasks to or run inference with + save_nemo_on_validation_end: True # Saves an inference ready .nemo file every time a checkpoint is saved during training. + existing_tasks: [] + new_tasks: ["squad"] + freeze_model: false + use_alignment_loss: true + codecmodel_type: nemo_codec + codecmodel_path: ??? + english_only_model: true + context_conditioning: decoder + train_from_scratch: true + override_tokenizer_vocab_file: ??? + use_flash_attention: false + lm_vocab_size: 30000 + + frozen_model: + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + pipeline_model_parallel_split_rank: 0 + make_vocab_size_divisible_by: 128 + pre_process: true + post_process: true + gradient_as_bucket_view: true + native_amp_init_scale: 4294967296 + native_amp_growth_interval: 1000 + fp16_lm_cross_entropy: false + seed: 1234 + use_cpu_initialization: false + apex_transformer_log_level: 30 + tokenizer: + library: megatron + type: BertWordPieceCase + model: null + vocab_file: null + merge_file: null + optim: + name: null + data: + dataset_type: t5 + encoder: + arch: transformer + bias_activation_fusion: false + use_flash_attention: ${model.use_flash_attention} + num_layers: 12 + hidden_size: 768 + ffn_hidden_size: 2048 + num_attention_heads: 12 + init_method_std: 0.015 + hidden_dropout: 0.1 + attention_dropout: 0.1 + kv_channels: 64 + activation: geglu + decoder: + arch: transformer + bias_activation_fusion: false + use_flash_attention: ${model.use_flash_attention} + num_layers: 12 + hidden_size: 768 + ffn_hidden_size: 2048 + num_attention_heads: 12 + init_method_std: 0.015 + hidden_dropout: 0.1 + attention_dropout: 0.1 + kv_channels: 64 + activation: geglu + + task_templates: + - taskname: "squad" + prompt_template: "<|VIRTUAL_PROMPT_0|> {context} {question} {answer}" + total_virtual_tokens: 3 + virtual_token_splits: [3] + truncate_field: context + answer_field: answer + + p_tuning: # P-tuning specific params + encoder_type: "mlp" # Either "mlp" or "lstm", mlp is default + num_layers: 2 # 2 recommended for MLP, 1 recommended for LSTM, must be at least 2 for mlp + dropout: 0.0 + + prompt_tuning: # Prompt tunin specific params + new_prompt_init_methods: ['text'] # List of 'text' or 'random', should correspond to tasks listed in new tasks + new_prompt_init_text: ['some init text goes here'] # some init text if init method is text, or None if init method is random + + data: + use_ipa: false + grapheme_prefix: null + train_ds: ??? + validation_ds: ??? + max_seq_length: 2048 + sample_rate: 24000 + add_eos: true + add_bos: false + use_attention_prior: true + attention_prior_scaling_factor: 0.05 + cross_attention_epsilon: 0.0 + decoder_starts_with_pad: False + add_eos_to_decoder_output: True + add_sentinel_to_input: True + ul2_prompt_token: null # , , + shuffle: true + num_workers: 4 + pin_memory: true + speech_offset: 30128 + train_task: all + num_speech_codebooks: 8 + codebook_fps: 86 + context_duration_min: 2.9 + context_duration_max: 2.9 + g2p: + english: + _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p + phoneme_dict: "scripts/tts_dataset_files/ipa_cmudict-0.7b_nv23.01.txt" + heteronyms: "scripts/tts_dataset_files/heteronyms-052722" + phoneme_probability: 0.8 + ignore_ambiguous_words: False + use_chars: True + use_stresses: True + grapheme_prefix: ${model.data.grapheme_prefix} + spanish: + _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p + phoneme_dict: "scripts/tts_dataset_files/es_ES/es_ES_nv230301.dict" + phoneme_probability: 0.8 + use_chars: True + use_stresses: True + ignore_ambiguous_words: False + grapheme_prefix: ${model.data.grapheme_prefix} + locale: "es-ES" + mandarin: + _target_: nemo.collections.tts.g2p.models.zh_cn_pinyin.ChineseG2p + phoneme_dict: "scripts/tts_dataset_files/zh/36finals/ipa_dict_nv23.05.txt" + word_segmenter: "jieba" + phoneme_prefix: "" + phoneme_case: "lower" + tone_prefix: "#" + ascii_letter_prefix: ${model.data.grapheme_prefix} + ascii_letter_case: "upper" + german: + _target_: nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p + phoneme_dict: "scripts/tts_dataset_files/de/de_nv230119.dict" + heteronyms: "scripts/tts_dataset_files/de/de_nv230119.heteronym" + phoneme_probability: 0.8 + ignore_ambiguous_words: False + use_chars: True + use_stresses: True + grapheme_case: mixed + grapheme_prefix: ${model.data.grapheme_prefix} + locale: "de-DE" + + optim: + name: fused_adam + lr: 1e-4 + weight_decay: 0.01 + betas: + - 0.9 + - 0.98 + sched: + name: CosineAnnealing + warmup_steps: 1000 + constant_steps: 0 + min_lr: 1e-5 + monitor: val_loss + reduce_on_plateau: false diff --git a/examples/tts/speechllm/megatron_t5_speechllm.py b/examples/tts/speechllm/megatron_t5_speechllm.py new file mode 100644 index 000000000000..c4ec1a77f944 --- /dev/null +++ b/examples/tts/speechllm/megatron_t5_speechllm.py @@ -0,0 +1,57 @@ +# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from omegaconf.omegaconf import OmegaConf, open_dict + +from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronTrainerBuilder +from nemo.collections.nlp.parts.nlp_overrides import NLPSaveRestoreConnector +from nemo.collections.tts.models.speechllm.megatron_t5_speechllm_model import MegatronT5SpeechLMModel +from nemo.core.config import hydra_runner +from nemo.utils import logging +from nemo.utils.exp_manager import exp_manager + + +@hydra_runner(config_path="conf", config_name="megatron_t5_speechllm_medium.yaml") +def main(cfg) -> None: + logging.info("\n\n************** Experiment configuration ***********") + logging.info(f'\n{OmegaConf.to_yaml(cfg)}') + + # MegatronTrainerBuilder compat checks + if "gradient_as_bucket_view" not in cfg.model: + with open_dict(cfg): + cfg.model.gradient_as_bucket_view = False + + trainer = MegatronTrainerBuilder(cfg).create_trainer() + exp_manager(trainer, cfg.exp_manager) + + # hydra interpolation does not work here as the interpolation key is lost when PTL saves hparams + with open_dict(cfg): + cfg.model.precision = cfg.trainer.precision + + # load existing or init new soft prompt T5 model + if cfg.model.get("restore_path", None) is not None: + logging.info(f"cfg.model.restore_path {cfg.model.restore_path}") + model = MegatronT5SpeechLMModel.restore_from( + cfg.model.restore_path, cfg.model, trainer=trainer, save_restore_connector=NLPSaveRestoreConnector() + ) + else: + logging.info(f"cfg.model.restore_path is None") + model = MegatronT5SpeechLMModel(cfg.model, trainer=trainer) + model.maybe_init_from_pretrained_checkpoint(cfg=cfg) + + trainer.fit(model) + + +if __name__ == '__main__': + main() diff --git a/examples/tts/speechllm/megatron_t5_speechllm_inference.py b/examples/tts/speechllm/megatron_t5_speechllm_inference.py new file mode 100644 index 000000000000..48d46952a993 --- /dev/null +++ b/examples/tts/speechllm/megatron_t5_speechllm_inference.py @@ -0,0 +1,53 @@ +# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from omegaconf.omegaconf import OmegaConf, open_dict + +from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronTrainerBuilder +from nemo.collections.tts.models.speechllm.megatron_t5_speechllm_model import MegatronT5SpeechLMModel +from nemo.core.config import hydra_runner +from nemo.utils import logging +from nemo.utils.exp_manager import exp_manager + + +@hydra_runner(config_path="conf", config_name="megatron_t5_speechllm_inference.yaml") +def main(cfg) -> None: + logging.info("\n\n************** Experiment configuration ***********") + logging.info(f'\n{OmegaConf.to_yaml(cfg)}') + + # MegatronTrainerBuilder compat checks + if "gradient_as_bucket_view" not in cfg.model: + with open_dict(cfg): + cfg.model.gradient_as_bucket_view = False + + trainer = MegatronTrainerBuilder(cfg).create_trainer() + exp_manager(trainer, cfg.exp_manager) + + # hydra interpolation does not work here as the interpolation key is lost when PTL saves hparams + with open_dict(cfg): + cfg.model.precision = cfg.trainer.precision + + # load existing or init new soft prompt T5 model + checkpoint_path = cfg.get('checkpoint_path', None) + assert checkpoint_path is not None, "Please specify checkpoint_path in the config file" + model = MegatronT5SpeechLMModel.load_from_checkpoint( + checkpoint_path=checkpoint_path, trainer=trainer, cfg=cfg.model + ) + model.eval() + model = model.cuda() + trainer.test(model) + + +if __name__ == '__main__': + main() diff --git a/examples/tts/ssl_tts.py b/examples/tts/ssl_tts.py index a96dccb930ab..a50997a8f432 100644 --- a/examples/tts/ssl_tts.py +++ b/examples/tts/ssl_tts.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl from nemo.collections.common.callbacks import LogEpochTimeCallback from nemo.collections.tts.models import ssl_tts diff --git a/examples/tts/tacotron2.py b/examples/tts/tacotron2.py index a5446c35f775..6c4a15d98ef2 100755 --- a/examples/tts/tacotron2.py +++ b/examples/tts/tacotron2.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl from nemo.collections.common.callbacks import LogEpochTimeCallback from nemo.collections.tts.models import Tacotron2Model diff --git a/examples/tts/tacotron2_finetune.py b/examples/tts/tacotron2_finetune.py index a0531f1f2801..f8d4d1dcaad0 100644 --- a/examples/tts/tacotron2_finetune.py +++ b/examples/tts/tacotron2_finetune.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl from nemo.collections.common.callbacks import LogEpochTimeCallback from nemo.collections.tts.models import Tacotron2Model diff --git a/examples/tts/univnet.py b/examples/tts/univnet.py index 91aafa661842..ac6949405fd5 100644 --- a/examples/tts/univnet.py +++ b/examples/tts/univnet.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl from nemo.collections.common.callbacks import LogEpochTimeCallback from nemo.collections.tts.models import UnivNetModel diff --git a/examples/tts/vits.py b/examples/tts/vits.py index 75e0d827018a..6eeebd3ea15a 100644 --- a/examples/tts/vits.py +++ b/examples/tts/vits.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl from nemo.collections.tts.models.vits import VitsModel from nemo.core.config import hydra_runner diff --git a/examples/tts/waveglow.py b/examples/tts/waveglow.py index 66b13491abd4..3bcd008ab5e0 100755 --- a/examples/tts/waveglow.py +++ b/examples/tts/waveglow.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl from nemo.collections.common.callbacks import LogEpochTimeCallback from nemo.collections.tts.models import WaveGlowModel diff --git a/examples/vision/convert_ckpt_to_nemo.py b/examples/vision/convert_ckpt_to_nemo.py index 14876f6931f9..e0cf773f98c2 100644 --- a/examples/vision/convert_ckpt_to_nemo.py +++ b/examples/vision/convert_ckpt_to_nemo.py @@ -28,8 +28,8 @@ from argparse import ArgumentParser import torch -from pytorch_lightning.plugins.environments import TorchElasticEnvironment -from pytorch_lightning.trainer.trainer import Trainer +from lightning.pytorch.plugins.environments import TorchElasticEnvironment +from lightning.pytorch.trainer.trainer import Trainer from nemo.collections.nlp.parts.nlp_overrides import NLPSaveRestoreConnector from nemo.collections.vision.models.megatron_vit_classification_models import MegatronVitClassificationModel diff --git a/examples/vision/vision_transformer/megatron_vit_classification_evaluate.py b/examples/vision/vision_transformer/megatron_vit_classification_evaluate.py index e827e4db73c7..f7c384809702 100644 --- a/examples/vision/vision_transformer/megatron_vit_classification_evaluate.py +++ b/examples/vision/vision_transformer/megatron_vit_classification_evaluate.py @@ -15,9 +15,9 @@ import os import torch +from lightning.pytorch import Trainer +from lightning.pytorch.plugins.environments import TorchElasticEnvironment from omegaconf.omegaconf import OmegaConf, open_dict -from pytorch_lightning import Trainer -from pytorch_lightning.plugins.environments import TorchElasticEnvironment from torch.utils.data import DataLoader from tqdm import tqdm @@ -38,7 +38,8 @@ def main(cfg) -> None: plugins = [] strategy = NLPDDPStrategy( - no_ddp_communication_hook=True, find_unused_parameters=False, # we don't use DDP for async grad allreduce + no_ddp_communication_hook=True, + find_unused_parameters=False, # we don't use DDP for async grad allreduce ) if cfg.get('cluster_type', None) == 'BCP': plugins.append(TorchElasticEnvironment()) @@ -82,7 +83,10 @@ def main(cfg) -> None: model.eval() val_transform = ClassificationTransform(model.cfg, (model.cfg.img_h, model.cfg.img_w), train=False) - val_data = ImageFolder(root=cfg.model.data.imagenet_val, transform=val_transform,) + val_data = ImageFolder( + root=cfg.model.data.imagenet_val, + transform=val_transform, + ) def dummy(): return @@ -91,12 +95,20 @@ def dummy(): trainer.strategy.launcher.launch(dummy, trainer=trainer) trainer.strategy.setup_environment() - test_loader = DataLoader(val_data, batch_size=cfg.model.micro_batch_size, num_workers=cfg.model.data.num_workers,) + test_loader = DataLoader( + val_data, + batch_size=cfg.model.micro_batch_size, + num_workers=cfg.model.data.num_workers, + ) autocast_dtype = torch_dtype_from_precision(trainer.precision) - with torch.no_grad(), torch.cuda.amp.autocast( - enabled=autocast_dtype in (torch.half, torch.bfloat16), dtype=autocast_dtype, + with ( + torch.no_grad(), + torch.cuda.amp.autocast( + enabled=autocast_dtype in (torch.half, torch.bfloat16), + dtype=autocast_dtype, + ), ): total = correct = 0.0 for tokens, labels in tqdm(test_loader): diff --git a/examples/vision/vision_transformer/megatron_vit_classification_infer.py b/examples/vision/vision_transformer/megatron_vit_classification_infer.py index a757eb7a1c1f..f50ccf1c325c 100644 --- a/examples/vision/vision_transformer/megatron_vit_classification_infer.py +++ b/examples/vision/vision_transformer/megatron_vit_classification_infer.py @@ -16,10 +16,10 @@ import os import torch +from lightning.pytorch import Trainer +from lightning.pytorch.plugins.environments import TorchElasticEnvironment from omegaconf.omegaconf import OmegaConf, open_dict from PIL import Image -from pytorch_lightning import Trainer -from pytorch_lightning.plugins.environments import TorchElasticEnvironment from torch.utils.data import DataLoader, Dataset from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy, NLPSaveRestoreConnector @@ -63,7 +63,8 @@ def main(cfg) -> None: plugins = [] strategy = NLPDDPStrategy( - no_ddp_communication_hook=True, find_unused_parameters=False, # we don't use DDP for async grad allreduce + no_ddp_communication_hook=True, + find_unused_parameters=False, # we don't use DDP for async grad allreduce ) if cfg.get('cluster_type', None) == 'BCP': plugins.append(TorchElasticEnvironment()) @@ -107,7 +108,10 @@ def main(cfg) -> None: model.eval() test_transform = ClassificationTransform(cfg.model, (model_cfg.img_h, model_cfg.img_w), train=False) - test_data = ImageFolderDataset(folder_path=cfg.data_path, transform=test_transform,) + test_data = ImageFolderDataset( + folder_path=cfg.data_path, + transform=test_transform, + ) test_loader = DataLoader(test_data, batch_size=8) def dummy(): @@ -119,8 +123,12 @@ def dummy(): autocast_dtype = torch_dtype_from_precision(trainer.precision) - with torch.no_grad(), torch.cuda.amp.autocast( - enabled=autocast_dtype in (torch.half, torch.bfloat16), dtype=autocast_dtype, + with ( + torch.no_grad(), + torch.cuda.amp.autocast( + enabled=autocast_dtype in (torch.half, torch.bfloat16), + dtype=autocast_dtype, + ), ): class_names = [] for tokens in test_loader: diff --git a/nemo/README.md b/nemo/README.md index a6025e77822a..ebc23f4d5803 100644 --- a/nemo/README.md +++ b/nemo/README.md @@ -2,7 +2,12 @@ NeMo (**Ne**ural **Mo**dules) is a toolkit for creating AI applications built ar **NeMo Core** provides common APIs all modules and models have to implement. -**NeMo Collections** +**NeMo 2.0 Collections** + +* LLM - A collection of data modules, models, configurations, and recipes for building training and parameter-efficient fine-tuning (PEFT) pipelines, including decoder-only models like those in the Llama, Gemma, and Mamba families. +* VLM - A collection of data modules, models, configurations, and recipes for training and PEFT pipelines in vision-language models. + +**NeMo 1.0 Collections** * ASR - collection of modules and models for building speech recognition networks * TTS - collection of modules and models for building speech synthesis networks diff --git a/nemo/collections/asr/data/audio_to_diar_label.py b/nemo/collections/asr/data/audio_to_diar_label.py index a1cb6d0f1bdc..0824c9c6ab51 100644 --- a/nemo/collections/asr/data/audio_to_diar_label.py +++ b/nemo/collections/asr/data/audio_to_diar_label.py @@ -15,15 +15,20 @@ import os from collections import OrderedDict from statistics import mode -from typing import Dict, Optional +from typing import Dict, List, Optional, Tuple +import numpy as np import torch from nemo.collections.asr.parts.utils.offline_clustering import get_argmin_mat -from nemo.collections.asr.parts.utils.speaker_utils import convert_rttm_line, prepare_split_data -from nemo.collections.common.parts.preprocessing.collections import DiarizationSpeechLabel +from nemo.collections.asr.parts.utils.speaker_utils import convert_rttm_line, get_subsegments, prepare_split_data +from nemo.collections.common.parts.preprocessing.collections import ( + DiarizationSpeechLabel, + EndtoEndDiarizationSpeechLabel, +) from nemo.core.classes import Dataset from nemo.core.neural_types import AudioSignal, EncodedRepresentation, LengthsType, NeuralType, ProbsType +from nemo.utils import logging def get_scale_mapping_list(uniq_timestamps): @@ -62,7 +67,7 @@ def get_scale_mapping_list(uniq_timestamps): return scale_mapping_argmat -def extract_seg_info_from_rttm(uniq_id, rttm_lines, mapping_dict=None, target_spks=None): +def extract_seg_info_from_rttm(rttm_lines, mapping_dict=None, target_spks=None): """ Get RTTM lines containing speaker labels, start time and end time. target_spks contains two targeted speaker indices for creating groundtruth label files. Only speakers in target_spks variable will be @@ -76,7 +81,8 @@ def extract_seg_info_from_rttm(uniq_id, rttm_lines, mapping_dict=None, target_sp mapping_dict (dict): Mapping between the estimated speakers and the speakers in the ground-truth annotation. `mapping_dict` variable is only provided when the inference mode is running in sequence-eval mode. - Sequence eval mode uses the mapping between the estimated speakers and the speakers in ground-truth annotation. + Sequence eval mode uses the mapping between the estimated speakers and the speakers + in ground-truth annotation. Returns: rttm_tup (tuple): Tuple containing lists of start time, end time and speaker labels. @@ -108,12 +114,14 @@ def assign_frame_level_spk_vector(rttm_timestamps, round_digits, frame_per_sec, Args: rttm_timestamps (list): List containing start and end time for each speaker segment label. - stt_list, end_list and speaker_list are contained. + `stt_list`, `end_list` and `speaker_list` are contained. frame_per_sec (int): - Number of feature frames per second. This quantity is determined by window_stride variable in preprocessing module. + Number of feature frames per second. This quantity is determined by + `window_stride` variable in preprocessing module. target_spks (tuple): - Speaker indices that are generated from combinations. If there are only one or two speakers, - only a single target_spks variable is generated. + Speaker indices that are generated from combinations. + If there are only one or two speakers, + only a single `target_spks` variable is generated. Returns: fr_level_target (torch.tensor): @@ -124,7 +132,7 @@ def assign_frame_level_spk_vector(rttm_timestamps, round_digits, frame_per_sec, return None else: sorted_speakers = sorted(list(set(speaker_list))) - total_fr_len = int(max(end_list) * (10 ** round_digits)) + total_fr_len = int(max(end_list) * (10**round_digits)) spk_num = max(len(sorted_speakers), min_spks) speaker_mapping_dict = {rttm_key: x_int for x_int, rttm_key in enumerate(sorted_speakers)} fr_level_target = torch.zeros(total_fr_len, spk_num) @@ -139,6 +147,140 @@ def assign_frame_level_spk_vector(rttm_timestamps, round_digits, frame_per_sec, return fr_level_target +def get_subsegments_to_timestamps( + subsegments: List[Tuple[float, float]], feat_per_sec: int = 100, max_end_ts: float = None, decimals=2 +): + """ + Convert subsegment timestamps to scale timestamps by multiplying with the feature rate (`feat_per_sec`) + and rounding. Segment is consisted of many subsegments and sugsegments are equivalent to `frames` + in end-to-end speaker diarization models. + + Args: + subsegments (List[Tuple[float, float]]): + A list of tuples where each tuple contains the start and end times of a subsegment + (frames in end-to-end models). + >>> subsegments = [[t0_start, t0_duration], [t1_start, t1_duration],..., [tN_start, tN_duration]] + feat_per_sec (int, optional): + The number of feature frames per second. Defaults to 100. + max_end_ts (float, optional): + The maximum end timestamp to clip the results. If None, no clipping is applied. Defaults to None. + decimals (int, optional): + The number of decimal places to round the timestamps. Defaults to 2. + + Example: + Segments starting from 0.0 and ending at 69.2 seconds. + If hop-length is 0.08 and the subsegment (frame) length is 0.16 seconds, + there are 864 = (69.2 - 0.16)/0.08 + 1 subsegments (frames in end-to-end models) in this segment. + >>> subsegments = [[[0.0, 0.16], [0.08, 0.16], ..., [69.04, 0.16], [69.12, 0.08]] + + Returns: + ts (torch.tensor): + A tensor containing the scaled and rounded timestamps for each subsegment. + """ + seg_ts = (torch.tensor(subsegments) * feat_per_sec).float() + ts_round = torch.round(seg_ts, decimals=decimals) + ts = ts_round.long() + ts[:, 1] = ts[:, 0] + ts[:, 1] + if max_end_ts is not None: + ts = np.clip(ts, 0, int(max_end_ts * feat_per_sec)) + return ts + + +def extract_frame_info_from_rttm(offset, duration, rttm_lines, round_digits=3): + """ + Extracts RTTM lines containing speaker labels, start time, and end time for a given audio segment. + + Args: + uniq_id (str): Unique identifier for the audio file and corresponding RTTM file. + offset (float): The starting time offset for the segment of interest. + duration (float): The duration of the segment of interest. + rttm_lines (list): List of RTTM lines in string format. + round_digits (int, optional): Number of decimal places to round the start and end times. Defaults to 3. + + Returns: + rttm_mat (tuple): A tuple containing lists of start times, end times, and speaker labels. + sess_to_global_spkids (dict): A mapping from session-specific speaker indices to global speaker identifiers. + """ + rttm_stt, rttm_end = offset, offset + duration + stt_list, end_list, speaker_list, speaker_set = [], [], [], [] + sess_to_global_spkids = dict() + + for rttm_line in rttm_lines: + start, end, speaker = convert_rttm_line(rttm_line) + + # Skip invalid RTTM lines where the start time is greater than the end time. + if start > end: + continue + + # Check if the RTTM segment overlaps with the specified segment of interest. + if (end > rttm_stt and start < rttm_end) or (start < rttm_end and end > rttm_stt): + # Adjust the start and end times to fit within the segment of interest. + start, end = max(start, rttm_stt), min(end, rttm_end) + else: + continue + + # Round the start and end times to the specified number of decimal places. + end_list.append(round(end, round_digits)) + stt_list.append(round(start, round_digits)) + + # Assign a unique index to each speaker and maintain a mapping. + if speaker not in speaker_set: + speaker_set.append(speaker) + speaker_list.append(speaker_set.index(speaker)) + sess_to_global_spkids.update({speaker_set.index(speaker): speaker}) + + rttm_mat = (stt_list, end_list, speaker_list) + return rttm_mat, sess_to_global_spkids + + +def get_frame_targets_from_rttm( + rttm_timestamps: list, + offset: float, + duration: float, + round_digits: int, + feat_per_sec: int, + max_spks: int, +): + """ + Create a multi-dimensional vector sequence containing speaker timestamp information in RTTM. + The unit-length is the frame shift length of the acoustic feature. The feature-level annotations + `feat_level_target` will later be converted to base-segment level diarization label. + + Args: + rttm_timestamps (list): + List containing start and end time for each speaker segment label. + stt_list, end_list and speaker_list are contained. + feat_per_sec (int): + Number of feature frames per second. + This quantity is determined by window_stride variable in preprocessing module. + target_spks (tuple): + Speaker indices that are generated from combinations. If there are only one or two speakers, + only a single target_spks variable is generated. + + Returns: + feat_level_target (torch.tensor): + Tensor containing label for each feature level frame. + """ + stt_list, end_list, speaker_list = rttm_timestamps + sorted_speakers = sorted(list(set(speaker_list))) + total_fr_len = int(duration * feat_per_sec) + if len(sorted_speakers) > max_spks: + logging.warning( + f"Number of speakers in RTTM file {len(sorted_speakers)} exceeds the maximum number of speakers: " + f"{max_spks}! Only {max_spks} first speakers remain, and this will affect frame metrics!" + ) + feat_level_target = torch.zeros(total_fr_len, max_spks) + for count, (stt, end, spk_rttm_key) in enumerate(zip(stt_list, end_list, speaker_list)): + if end < offset or stt > offset + duration: + continue + stt, end = max(offset, stt), min(offset + duration, end) + spk = spk_rttm_key + if spk < max_spks: + stt_fr, end_fr = int((stt - offset) * feat_per_sec), int((end - offset) * feat_per_sec) + feat_level_target[stt_fr:end_fr, spk] = 1 + return feat_level_target + + class _AudioMSDDTrainDataset(Dataset): """ Dataset class that loads a json file containing paths to audio files, @@ -214,7 +356,7 @@ def __init__( self.multiscale_args_dict = multiscale_args_dict self.emb_dir = emb_dir self.round_digits = 2 - self.decim = 10 ** self.round_digits + self.decim = 10**self.round_digits self.soft_label_thres = soft_label_thres self.pairwise_infer = pairwise_infer self.max_spks = 2 @@ -224,7 +366,10 @@ def __init__( self.global_rank = global_rank self.manifest_filepath = manifest_filepath self.multiscale_timestamp_dict = prepare_split_data( - self.manifest_filepath, self.emb_dir, self.multiscale_args_dict, self.global_rank, + self.manifest_filepath, + self.emb_dir, + self.multiscale_args_dict, + self.global_rank, ) def __len__(self): @@ -241,7 +386,7 @@ def assign_labels_to_longer_segs(self, uniq_id, base_scale_clus_label): Unique sample ID for training. base_scale_clus_label (torch.tensor): Tensor variable containing the speaker labels for the base-scale segments. - + Returns: per_scale_clus_label (torch.tensor): Tensor variable containing the speaker labels for each segment in each scale. @@ -270,15 +415,17 @@ def assign_labels_to_longer_segs(self, uniq_id, base_scale_clus_label): def get_diar_target_labels(self, uniq_id, sample, fr_level_target): """ - Convert frame-level diarization target variable into segment-level target variable. Since the granularity is reduced - from frame level (10ms) to segment level (100ms~500ms), we need a threshold value, `soft_label_thres`, which determines - the label of each segment based on the overlap between a segment range (start and end time) and the frame-level target variable. + Convert frame-level diarization target variable into segment-level target variable. + Since the granularity is reduced from frame level (10ms) to segment level (100ms~500ms), + we need a threshold value, `soft_label_thres`, which determines the label of each segment + based on the overlap between a segment range (start and end time) and the frame-level target variable. Args: uniq_id (str): Unique file ID that refers to an input audio file and corresponding RTTM (Annotation) file. sample: - `DiarizationSpeechLabel` instance containing sample information such as audio filepath and RTTM filepath. + `DiarizationSpeechLabel` instance containing sample information such as + audio filepath and RTTM filepath. fr_level_target (torch.tensor): Tensor containing label for each feature-level frame. @@ -286,13 +433,14 @@ def get_diar_target_labels(self, uniq_id, sample, fr_level_target): seg_target (torch.tensor): Tensor containing binary speaker labels for base-scale segments. base_clus_label (torch.tensor): - Representative speaker label for each segment. This variable only has one speaker label for each base-scale segment. + Representative speaker label for each segment. This variable only has one speaker label + for each base-scale segment. -1 means that there is no corresponding speaker in the target_spks tuple. """ seg_target_list, base_clus_label = [], [] self.scale_n = len(self.multiscale_timestamp_dict[uniq_id]['scale_dict']) subseg_time_stamp_list = self.multiscale_timestamp_dict[uniq_id]["scale_dict"][self.scale_n - 1]["time_stamps"] - for (seg_stt, seg_end) in subseg_time_stamp_list: + for seg_stt, seg_end in subseg_time_stamp_list: seg_stt_fr, seg_end_fr = int(seg_stt * self.frame_per_sec), int(seg_end * self.frame_per_sec) soft_label_vec_sess = torch.sum(fr_level_target[seg_stt_fr:seg_end_fr, :], axis=0) / ( seg_end_fr - seg_stt_fr @@ -321,7 +469,8 @@ def parse_rttm_for_ms_targets(self, sample): Args: sample: - `DiarizationSpeechLabel` instance containing sample information such as audio filepath and RTTM filepath. + `DiarizationSpeechLabel` instance containing sample information such as + audio filepath and RTTM filepath. target_spks (tuple): Speaker indices that are generated from combinations. If there are only one or two speakers, only a single target_spks tuple is generated. @@ -336,9 +485,10 @@ def parse_rttm_for_ms_targets(self, sample): multiscale embeddings to form an input matrix for the MSDD model. """ - rttm_lines = open(sample.rttm_file).readlines() + with open(sample.rttm_file, 'r') as file: + rttm_lines = file.readlines() uniq_id = self.get_uniq_id_with_range(sample) - rttm_timestamps = extract_seg_info_from_rttm(uniq_id, rttm_lines) + rttm_timestamps = extract_seg_info_from_rttm(rttm_lines) fr_level_target = assign_frame_level_spk_vector( rttm_timestamps, self.round_digits, self.frame_per_sec, target_spks=sample.target_spks ) @@ -370,14 +520,14 @@ def get_uniq_id_with_range(self, sample, deci=3): def get_ms_seg_timestamps(self, sample): """ - Get start and end time of segments in each scale. + Get start and end time of each diarization frame. Args: sample: `DiarizationSpeechLabel` instance from preprocessing.collections Returns: ms_seg_timestamps (torch.tensor): - Tensor containing Multiscale segment timestamps. + Tensor containing timestamps for each frame. ms_seg_counts (torch.tensor): Number of segments for each scale. This information is used for reshaping embedding batch during forward propagation. @@ -441,7 +591,8 @@ class _AudioMSDDInferDataset(Dataset): emb_dict (dict): Dictionary containing cluster-average embeddings and speaker mapping information. emb_seq (dict): - Dictionary containing multiscale speaker embedding sequence, scale mapping and corresponding segment timestamps. + Dictionary containing multiscale speaker embedding sequence, + scale mapping and corresponding segment timestamps. clus_label_dict (dict): Subsegment-level (from base-scale) speaker labels from clustering results. soft_label_thres (float): @@ -496,7 +647,7 @@ def __init__( self.emb_seq = emb_seq self.clus_label_dict = clus_label_dict self.round_digits = 2 - self.decim = 10 ** self.round_digits + self.decim = 10**self.round_digits self.frame_per_sec = int(1 / window_stride) self.soft_label_thres = soft_label_thres self.pairwise_infer = pairwise_infer @@ -529,20 +680,20 @@ def parse_rttm_multiscale(self, sample): rttm_lines = open(sample.rttm_file).readlines() uniq_id = os.path.splitext(os.path.basename(sample.rttm_file))[0] mapping_dict = self.emb_dict[max(self.emb_dict.keys())][uniq_id]['mapping'] - rttm_timestamps = extract_seg_info_from_rttm(uniq_id, rttm_lines, mapping_dict, sample.target_spks) + rttm_timestamps = extract_seg_info_from_rttm(rttm_lines, mapping_dict, sample.target_spks) fr_level_target = assign_frame_level_spk_vector( rttm_timestamps, self.round_digits, self.frame_per_sec, sample.target_spks ) seg_target = self.get_diar_target_labels_from_fr_target(uniq_id, fr_level_target) return seg_target - def get_diar_target_labels_from_fr_target(self, uniq_id, fr_level_target): + def get_diar_target_labels_from_fr_target(self, uniq_id: str, fr_level_target: torch.Tensor) -> torch.Tensor: """ Generate base-scale level binary diarization label from frame-level target matrix. For the given frame-level speaker target matrix fr_level_target, we count the number of frames that belong to each speaker and calculate - ratios for each speaker into the `soft_label_vec` variable. Finally, `soft_label_vec` variable is compared with `soft_label_thres` - to determine whether a label vector should contain 0 or 1 for each speaker bin. Note that seg_target variable has - dimension of (number of base-scale segments x 2) dimension. + ratios for each speaker into the `soft_label_vec` variable. Finally, `soft_label_vec` variable is compared + with `soft_label_thres` to determine whether a label vector should contain 0 or 1 for each speaker bin. + Note that seg_target variable has dimension of (number of base-scale segments x 2) dimension. Example of seg_target: [[0., 1.], [0., 1.], [1., 1.], [1., 0.], [1., 0.], ..., [0., 1.]] @@ -562,7 +713,7 @@ def get_diar_target_labels_from_fr_target(self, uniq_id, fr_level_target): return None else: seg_target_list = [] - for (seg_stt, seg_end, label_int) in self.clus_label_dict[uniq_id]: + for seg_stt, seg_end, label_int in self.clus_label_dict[uniq_id]: seg_stt_fr, seg_end_fr = int(seg_stt * self.frame_per_sec), int(seg_end * self.frame_per_sec) soft_label_vec = torch.sum(fr_level_target[seg_stt_fr:seg_end_fr, :], axis=0) / ( seg_end_fr - seg_stt_fr @@ -588,7 +739,8 @@ def __getitem__(self, index): if avg_embs.shape[2] > self.max_spks: raise ValueError( - f" avg_embs.shape[2] {avg_embs.shape[2]} should be less than or equal to self.max_num_speakers {self.max_spks}" + f" avg_embs.shape[2] {avg_embs.shape[2]} should be less than or equal to " + f"self.max_num_speakers {self.max_spks}" ) feats = [] @@ -682,7 +834,8 @@ def _msdd_train_collate_fn(self, batch): def _msdd_infer_collate_fn(self, batch): """ - Collate batch of feats (speaker embeddings), feature lengths, target label sequences and cluster-average embeddings. + Collate batch of feats (speaker embeddings), feature lengths, target label sequences + and cluster-average embeddings. Args: batch (tuple): @@ -784,6 +937,7 @@ def __init__( ) def msdd_train_collate_fn(self, batch): + """Collate batch of audio features, feature lengths, target label sequences for training.""" return _msdd_train_collate_fn(self, batch) @@ -805,11 +959,13 @@ class AudioToSpeechMSDDInferDataset(_AudioMSDDInferDataset): emb_dict (dict): Dictionary containing cluster-average embeddings and speaker mapping information. emb_seq (dict): - Dictionary containing multiscale speaker embedding sequence, scale mapping and corresponding segment timestamps. + Dictionary containing multiscale speaker embedding sequence, scale mapping + and corresponding segment timestamps. clus_label_dict (dict): Subsegment-level (from base-scale) speaker labels from clustering results. soft_label_thres (float): - Threshold that determines speaker labels of segments depending on the overlap with groundtruth speaker timestamps. + Threshold that determines speaker labels of segments depending on the overlap + with groundtruth speaker timestamps. featurizer: Featurizer instance for generating features from raw waveform. use_single_scale_clus (bool): @@ -817,11 +973,12 @@ class AudioToSpeechMSDDInferDataset(_AudioMSDDInferDataset): seq_eval_mode (bool): If True, F1 score will be calculated for each speaker pair during inference mode. window_stride (float): - Window stride for acoustic feature. This value is used for calculating the numbers of feature-level frames. + Window stride for acoustic feature. This value is used for calculating the numbers of + feature-level frames. pairwise_infer (bool): - If True, this Dataset class operates in inference mode. In inference mode, a set of speakers in the input audio - is split into multiple pairs of speakers and speaker tuples (e.g. 3 speakers: [(0,1), (1,2), (0,2)]) and then - fed into the MSDD to merge the individual results. + If True, this Dataset class operates in inference mode. In inference mode, a set of speakers + in the input audio is split into multiple pairs of speakers and speaker tuples + (e.g. 3 speakers: [(0,1), (1,2), (0,2)]) and then fed into the MSDD to merge the individual results. """ def __init__( @@ -850,4 +1007,366 @@ def __init__( ) def msdd_infer_collate_fn(self, batch): + """Collate batch of audio features, feature lengths, target label sequences for inference.""" return _msdd_infer_collate_fn(self, batch) + + +class _AudioToSpeechE2ESpkDiarDataset(Dataset): + """ + Dataset class that loads a json file containing paths to audio files, + RTTM files and number of speakers. This Dataset class is designed for + training or fine-tuning speaker embedding extractor and diarization decoder + at the same time. + + Example: + {"audio_filepath": "/path/to/audio_0.wav", "num_speakers": 2, + "rttm_filepath": "/path/to/diar_label_0.rttm} + ... + {"audio_filepath": "/path/to/audio_n.wav", "num_speakers": 2, + "rttm_filepath": "/path/to/diar_label_n.rttm} + + Args: + manifest_filepath (str): + Path to input manifest json files. + multiargs_dict (dict): + Dictionary containing the parameters for multiscale segmentation and clustering. + soft_label_thres (float): + Threshold that determines the label of each segment based on RTTM file information. + featurizer: + Featurizer instance for generating audio_signal from the raw waveform. + window_stride (float): + Window stride for acoustic feature. This value is used for calculating the numbers of feature-level frames. + """ + + @property + def output_types(self) -> Optional[Dict[str, NeuralType]]: + """Returns definitions of module output ports.""" + output_types = { + "audio_signal": NeuralType(('B', 'T'), AudioSignal()), + "audio_length": NeuralType(('B'), LengthsType()), + "targets": NeuralType(('B', 'T', 'C'), ProbsType()), + "target_len": NeuralType(('B'), LengthsType()), + } + + return output_types + + def __init__( + self, + *, + manifest_filepath: str, + soft_label_thres: float, + session_len_sec: float, + num_spks: int, + featurizer, + window_stride: float, + min_subsegment_duration: float = 0.03, + global_rank: int = 0, + dtype=torch.float16, + round_digits: int = 2, + soft_targets: bool = False, + subsampling_factor: int = 8, + ): + super().__init__() + self.collection = EndtoEndDiarizationSpeechLabel( + manifests_files=manifest_filepath.split(','), + round_digits=round_digits, + ) + self.featurizer = featurizer + self.round_digits = round_digits + self.feat_per_sec = int(1 / window_stride) + self.diar_frame_length = round(subsampling_factor * window_stride, round_digits) + self.session_len_sec = session_len_sec + self.soft_label_thres = soft_label_thres + self.max_spks = num_spks + self.min_subsegment_duration = min_subsegment_duration + self.dtype = dtype + self.use_asr_style_frame_count = True + self.soft_targets = soft_targets + self.round_digits = 2 + self.floor_decimal = 10**self.round_digits + + def __len__(self): + return len(self.collection) + + def get_uniq_id_with_range(self, sample, deci=3): + """ + Generate unique training sample ID from unique file ID, offset and duration. The start-end time added + unique ID is required for identifying the sample since multiple short audio samples are generated from a single + audio file. The start time and end time of the audio stream uses millisecond units if `deci=3`. + + Args: + sample: + `DiarizationSpeechLabel` instance from collections. + + Returns: + uniq_id (str): + Unique sample ID which includes start and end time of the audio stream. + Example: abc1001_3122_6458 + """ + bare_uniq_id = os.path.splitext(os.path.basename(sample.rttm_file))[0] + offset = str(int(round(sample.offset, deci) * pow(10, deci))) + endtime = str(int(round(sample.offset + sample.duration, deci) * pow(10, deci))) + uniq_id = f"{bare_uniq_id}_{offset}_{endtime}" + return uniq_id + + def parse_rttm_for_targets_and_lens(self, rttm_file, offset, duration, target_len): + """ + Generate target tensor variable by extracting groundtruth diarization labels from an RTTM file. + This function converts (start, end, speaker_id) format into base-scale (the finest scale) segment level + diarization label in a matrix form. + + Example of seg_target: + [[0., 1.], [0., 1.], [1., 1.], [1., 0.], [1., 0.], ..., [0., 1.]] + """ + with open(rttm_file, 'r') as f: + rttm_lines = f.readlines() + + rttm_timestamps, sess_to_global_spkids = extract_frame_info_from_rttm(offset, duration, rttm_lines) + + fr_level_target = get_frame_targets_from_rttm( + rttm_timestamps=rttm_timestamps, + offset=offset, + duration=duration, + round_digits=self.round_digits, + feat_per_sec=self.feat_per_sec, + max_spks=self.max_spks, + ) + + soft_target_seg = self.get_soft_targets_seg(feat_level_target=fr_level_target, target_len=target_len) + if self.soft_targets: + step_target = soft_target_seg + else: + step_target = (soft_target_seg >= self.soft_label_thres).float() + return step_target + + def get_soft_targets_seg(self, feat_level_target, target_len): + """ + Generate the final targets for the actual diarization step. + Here, frame level means step level which is also referred to as segments. + We follow the original paper and refer to the step level as "frames". + + Args: + feat_level_target (torch.tensor): + Tensor variable containing hard-labels of speaker activity in each feature-level segment. + target_len (torch.tensor): + Numbers of ms segments + + Returns: + soft_target_seg (torch.tensor): + Tensor variable containing soft-labels of speaker activity in each step-level segment. + """ + num_seg = torch.max(target_len) + targets = torch.zeros(num_seg, self.max_spks) + stride = int(self.feat_per_sec * self.diar_frame_length) + for index in range(num_seg): + if index == 0: + seg_stt_feat = 0 + else: + seg_stt_feat = stride * index - 1 - int(stride / 2) + if index == num_seg - 1: + seg_end_feat = feat_level_target.shape[0] + else: + seg_end_feat = stride * index - 1 + int(stride / 2) + targets[index] = torch.mean(feat_level_target[seg_stt_feat : seg_end_feat + 1, :], axis=0) + return targets + + def get_segment_timestamps( + self, + duration: float, + offset: float = 0, + sample_rate: int = 16000, + ): + """ + Get start and end time of segments in each scale. + + Args: + sample: + `DiarizationSpeechLabel` instance from preprocessing.collections + Returns: + segment_timestamps (torch.tensor): + Tensor containing Multiscale segment timestamps. + target_len (torch.tensor): + Number of segments for each scale. This information is used for reshaping embedding batch + during forward propagation. + """ + subsegments = get_subsegments( + offset=offset, + window=round(self.diar_frame_length * 2, self.round_digits), + shift=self.diar_frame_length, + duration=duration, + min_subsegment_duration=self.min_subsegment_duration, + use_asr_style_frame_count=self.use_asr_style_frame_count, + sample_rate=sample_rate, + feat_per_sec=self.feat_per_sec, + ) + if self.use_asr_style_frame_count: + effective_dur = ( + np.ceil((1 + duration * sample_rate) / int(sample_rate / self.feat_per_sec)).astype(int) + / self.feat_per_sec + ) + else: + effective_dur = duration + ts_tensor = get_subsegments_to_timestamps( + subsegments, self.feat_per_sec, decimals=2, max_end_ts=(offset + effective_dur) + ) + target_len = torch.tensor([ts_tensor.shape[0]]) + return target_len + + def __getitem__(self, index): + sample = self.collection[index] + if sample.offset is None: + sample.offset = 0 + offset = sample.offset + if self.session_len_sec < 0: + session_len_sec = sample.duration + else: + session_len_sec = min(sample.duration, self.session_len_sec) + + audio_signal = self.featurizer.process(sample.audio_file, offset=offset, duration=session_len_sec) + + # We should resolve the length mis-match from the round-off errors between these two variables: + # `session_len_sec` and `audio_signal.shape[0]` + session_len_sec = ( + np.floor(audio_signal.shape[0] / self.featurizer.sample_rate * self.floor_decimal) / self.floor_decimal + ) + audio_signal = audio_signal[: round(self.featurizer.sample_rate * session_len_sec)] + + audio_signal_length = torch.tensor(audio_signal.shape[0]).long() + audio_signal, audio_signal_length = audio_signal.to('cpu'), audio_signal_length.to('cpu') + target_len = self.get_segment_timestamps(duration=session_len_sec, sample_rate=self.featurizer.sample_rate) + targets = self.parse_rttm_for_targets_and_lens( + rttm_file=sample.rttm_file, offset=offset, duration=session_len_sec, target_len=target_len + ) + return audio_signal, audio_signal_length, targets, target_len + + +def _eesd_train_collate_fn(self, batch): + """ + Collate a batch of variables needed for training the end-to-end speaker diarization (EESD) model + from raw waveforms to diarization labels. The following variables are included in the training/validation batch: + + Args: + batch (tuple): + A tuple containing the variables for diarization training. + + Returns: + audio_signal (torch.Tensor): + A tensor containing the raw waveform samples (time series) loaded from the `audio_filepath` + in the input manifest file. + feature_length (torch.Tensor): + A tensor containing the lengths of the raw waveform samples. + targets (torch.Tensor): + Groundtruth speaker labels for the given input embedding sequence. + target_lens (torch.Tensor): + A tensor containing the number of segments for each sample in the batch, necessary for + reshaping inputs to the EESD model. + """ + packed_batch = list(zip(*batch)) + audio_signal, feature_length, targets, target_len = packed_batch + audio_signal_list, feature_length_list = [], [] + target_len_list, targets_list = [], [] + + max_raw_feat_len = max([x.shape[0] for x in audio_signal]) + max_target_len = max([x.shape[0] for x in targets]) + if max([len(feat.shape) for feat in audio_signal]) > 1: + max_ch = max([feat.shape[1] for feat in audio_signal]) + else: + max_ch = 1 + for feat, feat_len, tgt, segment_ct in batch: + seq_len = tgt.shape[0] + if len(feat.shape) > 1: + pad_feat = (0, 0, 0, max_raw_feat_len - feat.shape[0]) + else: + pad_feat = (0, max_raw_feat_len - feat.shape[0]) + if feat.shape[0] < feat_len: + feat_len_pad = feat_len - feat.shape[0] + feat = torch.nn.functional.pad(feat, (0, feat_len_pad)) + pad_tgt = (0, 0, 0, max_target_len - seq_len) + padded_feat = torch.nn.functional.pad(feat, pad_feat) + padded_tgt = torch.nn.functional.pad(tgt, pad_tgt) + if max_ch > 1 and padded_feat.shape[1] < max_ch: + feat_ch_pad = max_ch - padded_feat.shape[1] + padded_feat = torch.nn.functional.pad(padded_feat, (0, feat_ch_pad)) + audio_signal_list.append(padded_feat) + feature_length_list.append(feat_len.clone().detach()) + target_len_list.append(segment_ct.clone().detach()) + targets_list.append(padded_tgt) + audio_signal = torch.stack(audio_signal_list) + feature_length = torch.stack(feature_length_list) + target_lens = torch.stack(target_len_list).squeeze(1) + targets = torch.stack(targets_list) + return audio_signal, feature_length, targets, target_lens + + +class AudioToSpeechE2ESpkDiarDataset(_AudioToSpeechE2ESpkDiarDataset): + """ + Dataset class for loading a JSON file containing paths to audio files, + RTTM (Rich Transcription Time Marked) files, and the number of speakers. + This class is designed for training or fine-tuning a speaker embedding + extractor and diarization decoder simultaneously. + + The JSON manifest file should have entries in the following format: + + Example: + { + "audio_filepath": "/path/to/audio_0.wav", + "num_speakers": 2, + "rttm_filepath": "/path/to/diar_label_0.rttm" + } + ... + { + "audio_filepath": "/path/to/audio_n.wav", + "num_speakers": 2, + "rttm_filepath": "/path/to/diar_label_n.rttm" + } + + Args: + manifest_filepath (str): + Path to the input manifest JSON file containing paths to audio and RTTM files. + soft_label_thres (float): + Threshold for assigning soft labels to segments based on RTTM file information. + session_len_sec (float): + Duration of each session (in seconds) for training or fine-tuning. + num_spks (int): + Number of speakers in the audio files. + featurizer: + Instance of a featurizer for generating features from the raw waveform. + window_stride (float): + Window stride (in seconds) for extracting acoustic features, used to calculate + the number of feature frames. + global_rank (int): + Global rank of the current process (used for distributed training). + soft_targets (bool): + Whether or not to use soft targets during training. + + Methods: + eesd_train_collate_fn(batch): + Collates a batch of data for end-to-end speaker diarization training. + """ + + def __init__( + self, + *, + manifest_filepath: str, + soft_label_thres: float, + session_len_sec: float, + num_spks: int, + featurizer, + window_stride, + global_rank: int, + soft_targets: bool, + ): + super().__init__( + manifest_filepath=manifest_filepath, + soft_label_thres=soft_label_thres, + session_len_sec=session_len_sec, + num_spks=num_spks, + featurizer=featurizer, + window_stride=window_stride, + global_rank=global_rank, + soft_targets=soft_targets, + ) + + def eesd_train_collate_fn(self, batch): + """Collate a batch of data for end-to-end speaker diarization training.""" + return _eesd_train_collate_fn(self, batch) diff --git a/nemo/collections/asr/data/audio_to_diar_label_lhotse.py b/nemo/collections/asr/data/audio_to_diar_label_lhotse.py new file mode 100644 index 000000000000..927e3887de78 --- /dev/null +++ b/nemo/collections/asr/data/audio_to_diar_label_lhotse.py @@ -0,0 +1,82 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, Optional, Tuple + +import torch.utils.data +from lhotse.dataset import AudioSamples +from lhotse.dataset.collation import collate_matrices + +from nemo.collections.asr.parts.utils.asr_multispeaker_utils import ( + get_hidden_length_from_sample_length, + speaker_to_target, +) +from nemo.core.neural_types import AudioSignal, LabelsType, LengthsType, NeuralType + + +class LhotseAudioToSpeechE2ESpkDiarDataset(torch.utils.data.Dataset): + """ + This dataset is a Lhotse version of diarization dataset in audio_to_diar_label.py. + Unlike native NeMo datasets, Lhotse dataset defines only the mapping from + a CutSet (meta-data) to a mini-batch with PyTorch tensors. + Specifically, it performs tokenization, I/O, augmentation, and feature extraction (if any). + Managing data, sampling, de-duplication across workers/nodes etc. is all handled + by Lhotse samplers instead. + """ + + @property + def output_types(self) -> Optional[Dict[str, NeuralType]]: + """Define the output types of the dataset.""" + return { + 'audio_signal': NeuralType(('B', 'T'), AudioSignal()), + 'a_sig_length': NeuralType(tuple('B'), LengthsType()), + 'targets': NeuralType(('B', 'T', 'N'), LabelsType()), + 'target_length': NeuralType(tuple('B'), LengthsType()), + 'sample_id': NeuralType(tuple('B'), LengthsType(), optional=True), + } + + def __init__(self, cfg): + super().__init__() + self.load_audio = AudioSamples(fault_tolerant=True) + self.cfg = cfg + self.num_speakers = self.cfg.get('num_speakers', 4) + self.num_sample_per_mel_frame = int( + self.cfg.get('window_stride', 0.01) * self.cfg.get('sample_rate', 16000) + ) # 160 samples for every 1ms by default + self.num_mel_frame_per_target_frame = int(self.cfg.get('subsampling_factor', 8)) + self.spk_tar_all_zero = self.cfg.get('spk_tar_all_zero', False) + + def __getitem__(self, cuts) -> Tuple[torch.Tensor, ...]: + audio, audio_lens, cuts = self.load_audio(cuts) + speaker_activities = [] + for cut in cuts: + speaker_activity = speaker_to_target( + a_cut=cut, + num_speakers=self.num_speakers, + num_sample_per_mel_frame=self.num_sample_per_mel_frame, + num_mel_frame_per_asr_frame=self.num_mel_frame_per_target_frame, + spk_tar_all_zero=self.spk_tar_all_zero, + boundary_segments=True, + ) + speaker_activities.append(speaker_activity) + targets = collate_matrices(speaker_activities).to(audio.dtype) + target_lens_list = [] + for audio_len in audio_lens: + target_fr_len = get_hidden_length_from_sample_length( + audio_len, self.num_sample_per_mel_frame, self.num_mel_frame_per_target_frame + ) + target_lens_list.append([target_fr_len]) + target_lens = torch.tensor(target_lens_list) + + return audio, audio_lens, targets, target_lens diff --git a/nemo/collections/asr/data/audio_to_text_dataset.py b/nemo/collections/asr/data/audio_to_text_dataset.py index c63c73323797..3e1301dd4d53 100644 --- a/nemo/collections/asr/data/audio_to_text_dataset.py +++ b/nemo/collections/asr/data/audio_to_text_dataset.py @@ -19,9 +19,9 @@ from typing import Any, List, Optional, Union import torch +from lightning.pytorch.callbacks import BasePredictionWriter from omegaconf import DictConfig, OmegaConf, open_dict from omegaconf.listconfig import ListConfig -from pytorch_lightning.callbacks import BasePredictionWriter from torch.utils.data import ChainDataset from nemo.collections.asr.data import audio_to_text, audio_to_text_dali @@ -867,10 +867,15 @@ def write_on_batch_end( sample = sample_id if isinstance(sample, lhotse.cut.MixedCut): sample = sample.first_non_padding_cut - item["audio_filepath"] = sample.recording.sources[0].source + if sample.recording.sources[0].source != '': + item["audio_filepath"] = sample.recording.sources[0].source + else: + item["audio_filepath"] = sample.id item["offset"] = sample.start item["duration"] = sample.duration - item["text"] = sample.supervisions[0].text + item["text"] = sample.supervisions[0].text or '' + if hasattr(sample, 'shard_id'): + item["shard_id"] = sample.shard_id item["pred_text"] = transcribed_text self.outf.write(json.dumps(item) + "\n") self.samples_num += 1 diff --git a/nemo/collections/asr/data/audio_to_text_lhotse.py b/nemo/collections/asr/data/audio_to_text_lhotse.py index f916ae1de56b..0ae3059a9296 100644 --- a/nemo/collections/asr/data/audio_to_text_lhotse.py +++ b/nemo/collections/asr/data/audio_to_text_lhotse.py @@ -43,17 +43,18 @@ def output_types(self) -> Optional[Dict[str, NeuralType]]: 'sample_id': NeuralType(tuple('B'), LengthsType(), optional=True), } - def __init__(self, tokenizer): + def __init__(self, tokenizer, return_cuts=False): super().__init__() self.tokenizer = TokenizerWrapper(tokenizer) self.load_audio = AudioSamples(fault_tolerant=True) + self.return_cuts = return_cuts def __getitem__(self, cuts) -> Tuple[torch.Tensor, ...]: audio, audio_lens, cuts = self.load_audio(cuts) tokens = [ torch.cat( [ - torch.as_tensor(s.tokens if hasattr(s, "tokens") else self.tokenizer(s.text, s.language)) + torch.as_tensor(s.tokens if hasattr(s, "tokens") else self.tokenizer(s.text or "", s.language)) for s in c.supervisions ], dim=0, @@ -62,6 +63,8 @@ def __getitem__(self, cuts) -> Tuple[torch.Tensor, ...]: ] token_lens = torch.tensor([t.size(0) for t in tokens], dtype=torch.long) tokens = collate_vectors(tokens, padding_value=0) + if self.return_cuts: + return audio, audio_lens, tokens, token_lens, cuts.drop_in_memory_data() return audio, audio_lens, tokens, token_lens diff --git a/nemo/collections/asr/losses/__init__.py b/nemo/collections/asr/losses/__init__.py index 756a071178d7..f88bd49d1f7b 100644 --- a/nemo/collections/asr/losses/__init__.py +++ b/nemo/collections/asr/losses/__init__.py @@ -13,6 +13,7 @@ # limitations under the License. from nemo.collections.asr.losses.angularloss import AngularSoftmaxLoss +from nemo.collections.asr.losses.bce_loss import BCELoss from nemo.collections.asr.losses.ctc import CTCLoss from nemo.collections.asr.losses.lattice_losses import LatticeLoss from nemo.collections.asr.losses.ssl_losses.contrastive import ContrastiveLoss diff --git a/nemo/collections/asr/losses/bce_loss.py b/nemo/collections/asr/losses/bce_loss.py index 30e31b8610ec..36a7a0166f26 100644 --- a/nemo/collections/asr/losses/bce_loss.py +++ b/nemo/collections/asr/losses/bce_loss.py @@ -28,12 +28,11 @@ class BCELoss(Loss, Typing): @property def input_types(self): - """Input types definitions for AnguarLoss. - """ + """Input types definitions for AnguarLoss.""" return { "probs": NeuralType(('B', 'T', 'C'), ProbsType()), 'labels': NeuralType(('B', 'T', 'C'), LabelsType()), - "signal_lengths": NeuralType(tuple('B'), LengthsType()), + "target_lens": NeuralType(('B'), LengthsType()), } @property @@ -43,31 +42,94 @@ def output_types(self): """ return {"loss": NeuralType(elements_type=LossType())} - def __init__(self, reduction='sum', alpha=1.0, weight=torch.tensor([0.5, 0.5])): + def __init__( + self, + reduction: str = 'mean', + alpha: float = 1.0, + weight: torch.Tensor = torch.tensor([0.1, 0.9]), + sorted_preds: bool = False, + sorted_loss: bool = False, + class_normalization: bool = False, + ): + """ + A custom loss function that supports class normalization, + weighted binary cross-entropy, and optional sorting. + + Args: + reduction (str): Specifies the reduction to apply to the output, + options are 'mean', 'sum', or 'none'. Default is 'mean'. + alpha (float): Scaling factor for loss (unused in this implementation). Default is 1.0. + weight (torch.Tensor): Class weights for the binary cross-entropy loss. Default is [0.1, 0.9]. + sorted_preds (bool): If True, assumes predictions are sorted. Default is False. + sorted_loss (bool): If True, sorts the loss before reduction. Default is False. + class_normalization (bool): If True, uses 'none' reduction for per-class loss. Default is False. + """ super().__init__() - self.reduction = reduction + self.class_normalization = class_normalization + if class_normalization: + self.reduction = 'none' + else: + self.reduction = 'mean' self.loss_weight = weight - self.loss_f = torch.nn.BCELoss(weight=self.loss_weight, reduction=self.reduction) + self.loss_f = torch.nn.BCELoss(reduction=self.reduction) + self.sorted_preds = sorted_preds + self.sorted_loss = sorted_loss + self.eps = 1e-6 @typecheck() - def forward(self, probs, labels, signal_lengths): + def forward(self, probs, labels, target_lens): """ - Calculate binary cross entropy loss based on probs, labels and signal_lengths variables. + Calculate binary cross entropy loss based on probs, labels and target_lens variables. Args: probs (torch.tensor) Predicted probability value which ranges from 0 to 1. Sigmoid output is expected. labels (torch.tensor) Groundtruth label for the predicted samples. - signal_lengths (torch.tensor): + target_lens (torch.tensor): The actual length of the sequence without zero-padding. Returns: loss (NeuralType) Binary cross entropy loss value. """ - probs_list = [probs[k, : signal_lengths[k], :] for k in range(probs.shape[0])] - targets_list = [labels[k, : signal_lengths[k], :] for k in range(labels.shape[0])] + probs_list = [probs[k, : target_lens[k], :] for k in range(probs.shape[0])] + targets_list = [labels[k, : target_lens[k], :] for k in range(labels.shape[0])] probs = torch.cat(probs_list, dim=0) labels = torch.cat(targets_list, dim=0) - return self.loss_f(probs, labels) + norm_weight = torch.zeros_like(labels).detach().clone() + loss = torch.tensor(0.0).to(labels.device) + + if self.class_normalization in ['class', 'class_binary', 'binary']: + if self.class_normalization in ['class', 'class_binary']: + # Normalize loss by number of classes + norm_weight = 1 / (labels.sum(dim=0) + self.eps) + norm_weight_norm = norm_weight / norm_weight.sum() + norm_weight_norm = torch.clamp(norm_weight_norm, min=0.05, max=1.0) + norm_weight_norm = norm_weight_norm / norm_weight_norm.max() + norm_weight = norm_weight_norm[None, :].expand_as(labels).detach().clone() + else: + norm_weight = torch.ones_like(labels).detach().clone() + + if self.class_normalization in ['binary', 'class_binary']: + binary_weight = torch.ones_like(labels).detach().clone() + one_weight = (labels.sum() / (labels.shape[0] * labels.shape[1])).to(labels.device) + binary_weight[labels == 0] = one_weight + binary_weight[labels == 1] = 1 - one_weight + else: + binary_weight = torch.ones_like(labels).detach().clone() + + elif self.class_normalization == 'none' or not self.class_normalization: + binary_weight = torch.ones_like(labels).detach().clone() + norm_weight = torch.ones_like(labels).detach().clone() + + if self.reduction == 'sum': + loss = self.loss_f(probs, labels) + elif self.reduction == 'mean': + loss = self.loss_f(probs, labels).mean() + elif self.reduction == 'none': + if self.class_normalization in ['class', 'class_binary', 'binary']: + loss = (binary_weight * norm_weight * self.loss_f(probs, labels)).sum() + else: + loss = self.loss_f(probs, labels) + return loss diff --git a/nemo/collections/asr/metrics/der.py b/nemo/collections/asr/metrics/der.py index fc5cded970d0..c8dec24eaaca 100644 --- a/nemo/collections/asr/metrics/der.py +++ b/nemo/collections/asr/metrics/der.py @@ -36,12 +36,12 @@ def get_partial_ref_labels(pred_labels: List[str], ref_labels: List[str]) -> List[str]: """ - For evaluation of online diarization performance, generate partial reference labels + For evaluation of online diarization performance, generate partial reference labels from the last prediction time. Args: pred_labels (list[str]): list of partial prediction labels - ref_labels (list[str]): list of full reference labels + ref_labels (list[str]): list of full reference labels Returns: ref_labels_out (list[str]): list of partial reference labels @@ -84,8 +84,8 @@ def get_online_DER_stats( For evaluation of online diarization performance, add cumulative, average, and maximum DER/CER. Args: - DER (float): Diarization Error Rate from the start to the current point - CER (float): Confusion Error Rate from the start to the current point + DER (float): Diarization Error Rate from the start to the current point + CER (float): Confusion Error Rate from the start to the current point FA (float): False Alarm from the start to the current point MISS (float): Miss rate from the start to the current point diar_eval_count (int): Number of evaluation sessions @@ -123,30 +123,45 @@ def uem_timeline_from_file(uem_file, uniq_name=''): lines = f.readlines() for line in lines: line = line.strip() - speaker_id, channel, start_time, end_time = line.split() + _, _, start_time, end_time = line.split() timeline.add(Segment(float(start_time), float(end_time))) return timeline def score_labels( - AUDIO_RTTM_MAP, all_reference, all_hypothesis, collar=0.25, ignore_overlap=True, verbose: bool = True + AUDIO_RTTM_MAP, + all_reference: list, + all_hypothesis: list, + all_uem: List[List[float]] = None, + collar: float = 0.25, + ignore_overlap: bool = True, + verbose: bool = True, ) -> Optional[Tuple[DiarizationErrorRate, Dict]]: """ Calculate DER, CER, FA and MISS rate from hypotheses and references. Hypothesis results are coming from Pyannote-formatted speaker diarization results and References are coming from Pyannote-formatted RTTM data. - Args: AUDIO_RTTM_MAP (dict): Dictionary containing information provided from manifestpath all_reference (list[uniq_name,Annotation]): reference annotations for score calculation all_hypothesis (list[uniq_name,Annotation]): hypothesis annotations for score calculation - verbose (bool): Warns if RTTM file is not found. + all_uem (list[list[float]]): List of UEM segments for each audio file. If UEM file is not provided, + it will be read from manifestpath + collar (float): Length of collar (in seconds) for diarization error rate calculation + ignore_overlap (bool): If True, overlapping segments in reference and hypothesis will be ignored + verbose (bool): If True, warning messages will be printed Returns: - metric (pyannote.DiarizationErrorRate): Pyannote Diarization Error Rate metric object. This object contains detailed scores of each audiofile. + metric (pyannote.DiarizationErrorRate): Pyannote Diarization Error Rate metric object. + This object contains detailed scores of each audiofile. mapping (dict): Mapping dict containing the mapping speaker label for each audio input + itemized_errors (tuple): Tuple containing (DER, CER, FA, MISS) for each audio file. + - DER: Diarization Error Rate, which is sum of all three errors, CER + FA + MISS. + - CER: Confusion Error Rate, which is sum of all errors + - FA: False Alarm Rate, which is the number of false alarm segments + - MISS: Missed Detection Rate, which is the number of missed detection segments < Caveat > Unlike md-eval.pl, "no score" collar in pyannote.metrics is the maximum length of @@ -157,33 +172,51 @@ def score_labels( if len(all_reference) == len(all_hypothesis): metric = DiarizationErrorRate(collar=2 * collar, skip_overlap=ignore_overlap) - mapping_dict = {} - for (reference, hypothesis) in zip(all_reference, all_hypothesis): + mapping_dict, correct_spk_count = {}, 0 + for idx, (reference, hypothesis) in enumerate(zip(all_reference, all_hypothesis)): ref_key, ref_labels = reference _, hyp_labels = hypothesis - uem = AUDIO_RTTM_MAP[ref_key].get('uem_filepath', None) - if uem is not None: - uem = uem_timeline_from_file(uem_file=uem, uniq_name=ref_key) - metric(ref_labels, hyp_labels, uem=uem, detailed=True) + if len(ref_labels.labels()) == len(hyp_labels.labels()): + correct_spk_count += 1 + if verbose and len(ref_labels.labels()) != len(hyp_labels.labels()): + logging.info( + f"Wrong Spk. Count with uniq_id:...{ref_key[-10:]}, " + f"Ref: {len(ref_labels.labels())}, Hyp: {len(hyp_labels.labels())}" + ) + uem_obj = None + if all_uem is not None: + metric(ref_labels, hyp_labels, uem=all_uem[idx], detailed=True) + elif AUDIO_RTTM_MAP[ref_key].get('uem_filepath', None) is not None: + uem_file = AUDIO_RTTM_MAP[ref_key].get('uem_filepath', None) + uem_obj = uem_timeline_from_file(uem_file=uem_file, uniq_name=ref_key) + metric(ref_labels, hyp_labels, uem=uem_obj, detailed=True) + else: + metric(ref_labels, hyp_labels, detailed=True) mapping_dict[ref_key] = metric.optimal_mapping(ref_labels, hyp_labels) + spk_count_acc = correct_spk_count / len(all_reference) DER = abs(metric) + if metric['total'] == 0: + raise ValueError("Total evaluation time is 0. Abort.") CER = metric['confusion'] / metric['total'] FA = metric['false alarm'] / metric['total'] MISS = metric['missed detection'] / metric['total'] + itemized_errors = (DER, CER, FA, MISS) + if verbose: + logging.info(f"\n{metric.report()}") logging.info( - "Cumulative Results for collar {} sec and ignore_overlap {}: \n FA: {:.4f}\t MISS {:.4f}\t \ - Diarization ER: {:.4f}\t, Confusion ER:{:.4f}".format( - collar, ignore_overlap, FA, MISS, DER, CER - ) + f"Cumulative Results for collar {collar} sec and ignore_overlap {ignore_overlap}: \n" + f"| FA: {FA:.4f} | MISS: {MISS:.4f} | CER: {CER:.4f} | DER: {DER:.4f} | " + f"Spk. Count Acc. {spk_count_acc:.4f}\n" ) return metric, mapping_dict, itemized_errors elif verbose: logging.warning( - "Check if each ground truth RTTMs were present in the provided manifest file. Skipping calculation of Diariazation Error Rate" + "Check if each ground truth RTTMs were present in the provided manifest file. " + "Skipping calculation of Diariazation Error Rate" ) return None @@ -365,7 +398,7 @@ def calculate_session_cpWER( # Calculate WER for each speaker in hypothesis with reference # There are (number of hyp speakers) x (number of ref speakers) combinations lsa_wer_list = [] - for (spk_hyp_trans, spk_ref_trans) in all_pairs: + for spk_hyp_trans, spk_ref_trans in all_pairs: spk_wer = word_error_rate(hypotheses=[spk_hyp_trans], references=[spk_ref_trans]) lsa_wer_list.append(spk_wer) @@ -419,7 +452,7 @@ def concat_perm_word_error_rate( f"{len(spk_hypotheses)} and {len(spk_references)} correspondingly" ) cpWER_values, hyps_spk, refs_spk = [], [], [] - for (spk_hypothesis, spk_reference) in zip(spk_hypotheses, spk_references): + for spk_hypothesis, spk_reference in zip(spk_hypotheses, spk_references): cpWER, min_hypothesis, concat_reference = calculate_session_cpWER(spk_hypothesis, spk_reference) cpWER_values.append(cpWER) hyps_spk.append(min_hypothesis) diff --git a/nemo/collections/asr/metrics/multi_binary_acc.py b/nemo/collections/asr/metrics/multi_binary_acc.py index 8cc21c53ad82..7b2b9148a74e 100644 --- a/nemo/collections/asr/metrics/multi_binary_acc.py +++ b/nemo/collections/asr/metrics/multi_binary_acc.py @@ -73,13 +73,26 @@ def on_validation_epoch_end(self): def __init__(self, dist_sync_on_step=False): super().__init__(dist_sync_on_step=dist_sync_on_step) - self.total_correct_counts = 0 - self.total_sample_counts = 0 - self.true_positive_count = 0 - self.false_positive_count = 0 - self.false_negative_count = 0 + self.add_state("total_correct_counts", default=torch.tensor(0), dist_reduce_fx='sum', persistent=False) + self.add_state("total_sample_counts", default=torch.tensor(0), dist_reduce_fx='sum', persistent=False) + self.add_state("true_positive_count", default=torch.tensor(0), dist_reduce_fx='sum', persistent=False) + self.add_state("false_positive_count", default=torch.tensor(0), dist_reduce_fx='sum', persistent=False) + self.add_state("false_negative_count", default=torch.tensor(0), dist_reduce_fx='sum', persistent=False) + self.add_state("positive_count", default=torch.tensor(0), dist_reduce_fx='sum', persistent=False) + self.eps = 1e-6 + + def update( + self, preds: torch.Tensor, targets: torch.Tensor, signal_lengths: torch.Tensor, cumulative=False + ) -> torch.Tensor: + """ + Update the metric with the given predictions, targets, and signal lengths to the metric instance. - def update(self, preds: torch.Tensor, targets: torch.Tensor, signal_lengths: torch.Tensor) -> torch.Tensor: + Args: + preds (torch.Tensor): Predicted values. + targets (torch.Tensor): Target values. + signal_lengths (torch.Tensor): Length of each sequence in the batch input. + cumulative (bool): Whether to accumulate the values over time. + """ with torch.no_grad(): preds_list = [preds[k, : signal_lengths[k], :] for k in range(preds.shape[0])] targets_list = [targets[k, : signal_lengths[k], :] for k in range(targets.shape[0])] @@ -91,22 +104,35 @@ def update(self, preds: torch.Tensor, targets: torch.Tensor, signal_lengths: tor self.positive = self.preds.round().bool() == 1 self.negative = self.preds.round().bool() == 0 - self.positive_count = torch.sum(self.preds.round().bool() == True) - self.true_positive_count += torch.sum(torch.logical_and(self.true, self.positive)) - self.false_positive_count += torch.sum(torch.logical_and(self.false, self.positive)) - self.false_negative_count += torch.sum(torch.logical_and(self.false, self.negative)) - - self.total_correct_counts += torch.sum(self.preds.round().bool() == self.targets.round().bool()) - self.total_sample_counts += torch.prod(torch.tensor(self.targets.shape)) + if cumulative: + self.positive_count += torch.sum(self.preds.round().bool() == True) + self.true_positive_count += torch.sum(torch.logical_and(self.true, self.positive)) + self.false_positive_count += torch.sum(torch.logical_and(self.false, self.positive)) + self.false_negative_count += torch.sum(torch.logical_and(self.false, self.negative)) + self.total_correct_counts += torch.sum(self.preds.round().bool() == self.targets.round().bool()) + self.total_sample_counts += torch.prod(torch.tensor(self.targets.shape)) + else: + self.positive_count = torch.sum(self.preds.round().bool() == True) + self.true_positive_count = torch.sum(torch.logical_and(self.true, self.positive)) + self.false_positive_count = torch.sum(torch.logical_and(self.false, self.positive)) + self.false_negative_count = torch.sum(torch.logical_and(self.false, self.negative)) + self.total_correct_counts = torch.sum(self.preds.round().bool() == self.targets.round().bool()) + self.total_sample_counts = torch.prod(torch.tensor(self.targets.shape)) def compute(self): """ Compute F1 score from the accumulated values. Return -1 if the F1 score is NaN. + + Returns: + f1_score (torch.Tensor): F1 score calculated from the accumulated values. + precision (torch.Tensor): Precision calculated from the accumulated values. + recall (torch.Tensor): Recall calculated from the accumulated values. """ - self.precision = self.true_positive_count / (self.true_positive_count + self.false_positive_count) - self.recall = self.true_positive_count / (self.true_positive_count + self.false_negative_count) - self.f1_score = 2 * self.precision * self.recall / (self.precision + self.recall) - if torch.isnan(self.f1_score): + precision = self.true_positive_count / (self.true_positive_count + self.false_positive_count + self.eps) + recall = self.true_positive_count / (self.true_positive_count + self.false_negative_count + self.eps) + f1_score = (2 * precision * recall / (precision + recall + self.eps)).detach().clone() + + if torch.isnan(f1_score): logging.warn("self.f1_score contains NaN value. Returning -1 instead of NaN value.") - self.f1_score = -1 - return self.f1_score + f1_score = -1 + return f1_score.float(), precision.float(), recall.float() diff --git a/nemo/collections/asr/models/__init__.py b/nemo/collections/asr/models/__init__.py index e4a1342b9c36..34dead15b33d 100644 --- a/nemo/collections/asr/models/__init__.py +++ b/nemo/collections/asr/models/__init__.py @@ -35,6 +35,7 @@ from nemo.collections.asr.models.rnnt_bpe_models import EncDecRNNTBPEModel from nemo.collections.asr.models.rnnt_models import EncDecRNNTModel from nemo.collections.asr.models.slu_models import SLUIntentSlotBPEModel +from nemo.collections.asr.models.sortformer_diar_models import SortformerEncLabelModel from nemo.collections.asr.models.ssl_models import ( EncDecDenoiseMaskedTokenPredModel, EncDecMaskedTokenPredModel, diff --git a/nemo/collections/asr/models/aed_multitask_models.py b/nemo/collections/asr/models/aed_multitask_models.py index f18fe02d2ed8..969966839dde 100644 --- a/nemo/collections/asr/models/aed_multitask_models.py +++ b/nemo/collections/asr/models/aed_multitask_models.py @@ -21,8 +21,8 @@ import numpy as np import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig, ListConfig, OmegaConf, open_dict -from pytorch_lightning import Trainer from torch.utils.data import DataLoader from nemo.collections.asr.data.audio_to_text_lhotse_prompted import ( @@ -62,7 +62,6 @@ from nemo.utils import logging, model_utils from nemo.utils.decorators import deprecated - __all__ = ['EncDecMultiTaskModel'] diff --git a/nemo/collections/asr/models/classification_models.py b/nemo/collections/asr/models/classification_models.py index b49ef50583a7..f84ece6d24ce 100644 --- a/nemo/collections/asr/models/classification_models.py +++ b/nemo/collections/asr/models/classification_models.py @@ -21,8 +21,8 @@ from typing import Any, Dict, List, Optional, Union import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig, ListConfig, OmegaConf -from pytorch_lightning import Trainer from torch.utils.data import DataLoader from torchmetrics import Accuracy from torchmetrics.regression import MeanAbsoluteError, MeanSquaredError diff --git a/nemo/collections/asr/models/clustering_diarizer.py b/nemo/collections/asr/models/clustering_diarizer.py index ddcc269bedcc..1f03cec59af7 100644 --- a/nemo/collections/asr/models/clustering_diarizer.py +++ b/nemo/collections/asr/models/clustering_diarizer.py @@ -22,8 +22,8 @@ from typing import Any, List, Optional, Union import torch +from lightning.pytorch.utilities import rank_zero_only from omegaconf import DictConfig, OmegaConf -from pytorch_lightning.utilities import rank_zero_only from tqdm import tqdm from nemo.collections.asr.metrics.der import score_labels @@ -49,7 +49,6 @@ from nemo.core.classes import Model from nemo.utils import logging, model_utils - __all__ = ['ClusteringDiarizer'] _MODEL_CONFIG_YAML = "model_config.yaml" diff --git a/nemo/collections/asr/models/confidence_ensemble.py b/nemo/collections/asr/models/confidence_ensemble.py index c6b2846085af..932d221be0f8 100644 --- a/nemo/collections/asr/models/confidence_ensemble.py +++ b/nemo/collections/asr/models/confidence_ensemble.py @@ -18,8 +18,8 @@ import joblib import numpy as np import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig, open_dict -from pytorch_lightning import Trainer from nemo.collections.asr.models.asr_model import ASRModel from nemo.collections.asr.models.hybrid_rnnt_ctc_models import EncDecHybridRNNTCTCModel diff --git a/nemo/collections/asr/models/configs/asr_models_config.py b/nemo/collections/asr/models/configs/asr_models_config.py index 29dbbe06d1f8..081233da5d32 100644 --- a/nemo/collections/asr/models/configs/asr_models_config.py +++ b/nemo/collections/asr/models/configs/asr_models_config.py @@ -41,6 +41,17 @@ class ASRDatasetConfig(nemo.core.classes.dataset.DatasetConfig): shard_manifests: bool = False shuffle_n: int = 0 + # lhotse support + use_lhotse: bool = False + tarred_random_access: bool = False + use_bucketing: bool = False + batch_duration: Optional[int] = None + quadratic_duration: Optional[int] = None + bucket_batch_size: Optional[int] = None + bucket_duration_bins: Optional[list] = None + num_buckets: Optional[int] = 0 + pin_memory: bool = False + # Optional int_values: Optional[int] = None augmentor: Optional[Dict[str, Any]] = None diff --git a/nemo/collections/asr/models/ctc_bpe_models.py b/nemo/collections/asr/models/ctc_bpe_models.py index 79c22794de01..1f84989c8ebe 100644 --- a/nemo/collections/asr/models/ctc_bpe_models.py +++ b/nemo/collections/asr/models/ctc_bpe_models.py @@ -97,9 +97,15 @@ def _setup_dataloader_from_config(self, config: Optional[Dict]): if config.get("use_lhotse"): return get_lhotse_dataloader_from_config( config, - global_rank=self.global_rank, - world_size=self.world_size, - dataset=LhotseSpeechToTextBpeDataset(tokenizer=self.tokenizer), + # During transcription, the model is initially loaded on the CPU. + # To ensure the correct global_rank and world_size are set, + # these values must be passed from the configuration. + global_rank=self.global_rank if not config.get("do_transcribe", False) else config.get("global_rank"), + world_size=self.world_size if not config.get("do_transcribe", False) else config.get("world_size"), + dataset=LhotseSpeechToTextBpeDataset( + tokenizer=self.tokenizer, + return_cuts=config.get("do_transcribe", False), + ), tokenizer=self.tokenizer, ) diff --git a/nemo/collections/asr/models/ctc_models.py b/nemo/collections/asr/models/ctc_models.py index 993c7dc6b298..ae8c35220931 100644 --- a/nemo/collections/asr/models/ctc_models.py +++ b/nemo/collections/asr/models/ctc_models.py @@ -18,8 +18,8 @@ import numpy as np import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig, OmegaConf, open_dict -from pytorch_lightning import Trainer from torch.utils.data import DataLoader from nemo.collections.asr.data import audio_to_text_dataset @@ -43,7 +43,6 @@ from nemo.utils import logging from nemo.utils.decorators import deprecated - __all__ = ['EncDecCTCModel'] @@ -161,6 +160,7 @@ def transcribe( A list of transcriptions (or raw log probabilities if logprobs is True) in the same order as paths2audio_files """ + timestamps = timestamps or (override_config.timestamps if override_config is not None else None) if timestamps is not None: # else retain the decoder state (users can set it using change_decoding_strategy) if timestamps or (override_config is not None and override_config.timestamps): @@ -309,8 +309,11 @@ def _setup_dataloader_from_config(self, config: Optional[Dict]): if config.get("use_lhotse"): return get_lhotse_dataloader_from_config( config, - global_rank=self.global_rank, - world_size=self.world_size, + # During transcription, the model is initially loaded on the CPU. + # To ensure the correct global_rank and world_size are set, + # these values must be passed from the configuration. + global_rank=self.global_rank if not config.get("do_transcribe", False) else config.get("global_rank"), + world_size=self.world_size if not config.get("do_transcribe", False) else config.get("world_size"), dataset=LhotseSpeechToTextBpeDataset( tokenizer=make_parser( labels=config.get('labels', None), @@ -319,6 +322,7 @@ def _setup_dataloader_from_config(self, config: Optional[Dict]): blank_id=config.get('blank_index', -1), do_normalize=config.get('normalize_transcripts', False), ), + return_cuts=config.get("do_transcribe", False), ), ) @@ -614,7 +618,8 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0): return_hypotheses=False, ) - sample_id = sample_id.cpu().detach().numpy() + if isinstance(sample_id, torch.Tensor): + sample_id = sample_id.cpu().detach().numpy() return list(zip(sample_id, transcribed_texts)) def validation_pass(self, batch, batch_idx, dataloader_idx=0): diff --git a/nemo/collections/asr/models/hybrid_asr_tts_models.py b/nemo/collections/asr/models/hybrid_asr_tts_models.py index 628395e04f94..89a7e1289675 100644 --- a/nemo/collections/asr/models/hybrid_asr_tts_models.py +++ b/nemo/collections/asr/models/hybrid_asr_tts_models.py @@ -19,8 +19,8 @@ from typing import Any, Dict, List, Optional, Tuple, Union, cast import torch +from lightning.pytorch import Trainer from omegaconf import MISSING, DictConfig, OmegaConf, open_dict -from pytorch_lightning import Trainer from torch.nn.utils.rnn import pad_sequence from nemo.collections.asr.data.audio_to_text_dali import DALIOutputs @@ -324,7 +324,9 @@ def __setattr__(self, name, value): return super().__setattr__(name, value) def setup_optimization( - self, optim_config: Optional[Union[DictConfig, Dict]] = None, optim_kwargs: Optional[Dict[str, Any]] = None, + self, + optim_config: Optional[Union[DictConfig, Dict]] = None, + optim_kwargs: Optional[Dict[str, Any]] = None, ): """ Setup optimizer and scheduler. Ensure tts model is frozen. @@ -430,7 +432,8 @@ def _get_batch_spect(self, batch: Union[TextToTextBatch, TextOrAudioToTextBatch, elif isinstance(batch, TextOrAudioToTextBatch): tts_spectrogram, tts_spectrogram_len = self._get_tts_spectrogram(batch.tts_texts, batch.speakers) asr_spectrogram, asr_spectrogram_len = self.asr_model.preprocessor( - input_signal=batch.audio_signals, length=batch.audio_signal_lengths, + input_signal=batch.audio_signals, + length=batch.audio_signal_lengths, ) spectrogram = pad_sequence( diff --git a/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py b/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py index 1d437a19a86b..cd04a5ad2462 100644 --- a/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py +++ b/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py @@ -17,8 +17,8 @@ from typing import Dict, List, Optional, Union import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig, ListConfig, OmegaConf, open_dict -from pytorch_lightning import Trainer from nemo.collections.asr.data import audio_to_text_dataset from nemo.collections.asr.data.audio_to_text import _AudioTextDataset @@ -140,10 +140,14 @@ def _setup_dataloader_from_config(self, config: Optional[Dict]): if config.get("use_lhotse"): return get_lhotse_dataloader_from_config( config, - global_rank=self.global_rank, - world_size=self.world_size, + # During transcription, the model is initially loaded on the CPU. + # To ensure the correct global_rank and world_size are set, + # these values must be passed from the configuration. + global_rank=self.global_rank if not config.get("do_transcribe", False) else config.get("global_rank"), + world_size=self.world_size if not config.get("do_transcribe", False) else config.get("world_size"), dataset=LhotseSpeechToTextBpeDataset( tokenizer=self.tokenizer, + return_cuts=config.get("do_transcribe", False), ), tokenizer=self.tokenizer, ) diff --git a/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py b/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py index 028073d7ca7f..1f63c617cea2 100644 --- a/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py +++ b/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py @@ -19,8 +19,8 @@ from typing import Any, List, Optional, Tuple import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig, OmegaConf, open_dict -from pytorch_lightning import Trainer from tqdm.auto import tqdm from nemo.collections.asr.data.audio_to_text_dali import DALIOutputs @@ -519,8 +519,8 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0): best_hyp_text, all_hyp_text = self.decoding.rnnt_decoder_predictions_tensor( encoder_output=encoded, encoded_lengths=encoded_len, return_hypotheses=False ) - - sample_id = sample_id.cpu().detach().numpy() + if isinstance(sample_id, torch.Tensor): + sample_id = sample_id.cpu().detach().numpy() return list(zip(sample_id, best_hyp_text)) def validation_pass(self, batch, batch_idx, dataloader_idx): diff --git a/nemo/collections/asr/models/k2_sequence_models.py b/nemo/collections/asr/models/k2_sequence_models.py index 087e9e41b85d..b60d08afe635 100644 --- a/nemo/collections/asr/models/k2_sequence_models.py +++ b/nemo/collections/asr/models/k2_sequence_models.py @@ -14,8 +14,8 @@ from typing import List, Optional +from lightning.pytorch import Trainer from omegaconf import DictConfig -from pytorch_lightning import Trainer from nemo.collections.asr.models.ctc_bpe_models import EncDecCTCModelBPE from nemo.collections.asr.models.ctc_models import EncDecCTCModel @@ -76,7 +76,11 @@ def change_vocabulary(self, new_vocabulary: List[str]): @typecheck() def forward( - self, input_signal=None, input_signal_length=None, processed_signal=None, processed_signal_length=None, + self, + input_signal=None, + input_signal_length=None, + processed_signal=None, + processed_signal_length=None, ): """ Forward pass of the model. @@ -159,7 +163,11 @@ def change_vocabulary(self, new_tokenizer_dir: str, new_tokenizer_type: str): @typecheck() def forward( - self, input_signal=None, input_signal_length=None, processed_signal=None, processed_signal_length=None, + self, + input_signal=None, + input_signal_length=None, + processed_signal=None, + processed_signal_length=None, ): """ Forward pass of the model. diff --git a/nemo/collections/asr/models/label_models.py b/nemo/collections/asr/models/label_models.py index 08c304e4c52c..37391879547b 100644 --- a/nemo/collections/asr/models/label_models.py +++ b/nemo/collections/asr/models/label_models.py @@ -24,8 +24,8 @@ import soundfile as sf import torch from hydra.utils import instantiate +from lightning.pytorch import Trainer from omegaconf import DictConfig, OmegaConf, open_dict -from pytorch_lightning import Trainer from sklearn.metrics import roc_curve from torchmetrics import Accuracy from tqdm import tqdm diff --git a/nemo/collections/asr/models/msdd_models.py b/nemo/collections/asr/models/msdd_models.py index c88275dcacd3..5e90f7d62d78 100644 --- a/nemo/collections/asr/models/msdd_models.py +++ b/nemo/collections/asr/models/msdd_models.py @@ -25,11 +25,11 @@ import numpy as np import torch from hydra.utils import instantiate +from lightning.pytorch import LightningModule, Trainer +from lightning.pytorch.utilities import rank_zero_only from omegaconf import DictConfig, open_dict from pyannote.core import Annotation from pyannote.metrics.diarization import DiarizationErrorRate -from pytorch_lightning import LightningModule, Trainer -from pytorch_lightning.utilities import rank_zero_only from tqdm import tqdm from nemo.collections.asr.data.audio_to_diar_label import AudioToSpeechMSDDInferDataset, AudioToSpeechMSDDTrainDataset @@ -70,6 +70,7 @@ @contextmanager def autocast(enabled=None): + """auto-casting context manager""" yield @@ -78,8 +79,8 @@ def autocast(enabled=None): class EncDecDiarLabelModel(ModelPT, ExportableEncDecModel): """ - Encoder decoder class for multiscale diarization decoder (MSDD). Model class creates training, validation methods for setting - up data performing model forward pass. + Encoder decoder class for multiscale diarization decoder (MSDD). Model class creates training, + validation methods for setting up data performing model forward pass. This model class expects config dict for: * preprocessor @@ -99,15 +100,18 @@ def list_available_models(cls) -> List[PretrainedModelInfo]: model = PretrainedModelInfo( pretrained_model_name="diar_msdd_telephonic", - location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/diar_msdd_telephonic/versions/1.0.1/files/diar_msdd_telephonic.nemo", - description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:diar_msdd_telephonic", + location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/" + "diar_msdd_telephonic/versions/1.0.1/files/diar_msdd_telephonic.nemo", + description="For details about this model, please visit " + "https://ngc.nvidia.com/catalog/models/nvidia:nemo:diar_msdd_telephonic", ) result.append(model) return result def __init__(self, cfg: DictConfig, trainer: Trainer = None): """ - Initialize an MSDD model and the specified speaker embedding model. In this init function, training and validation datasets are prepared. + Initialize an MSDD model and the specified speaker embedding model. In this init function, + training and validation datasets are prepared. """ self._trainer = trainer if trainer else None self.cfg_msdd_model = cfg @@ -173,9 +177,9 @@ def _init_segmentation_info(self): def _init_speaker_model(self): """ - Initialize speaker embedding model with model name or path passed through config. Note that speaker embedding model is loaded to - `self.msdd` to enable multi-gpu and multi-node training. In addition, speaker embedding model is also saved with msdd model when - `.ckpt` files are saved. + Initialize speaker embedding model with model name or path passed through config. Note that + speaker embedding model is loaded to `self.msdd` to enable multi-gpu and multi-node training. + In addition, speaker embedding model is also saved with msdd model when `.ckpt` files are saved. """ model_path = self.cfg_msdd_model.diarizer.speaker_embeddings.model_path self._diarizer_params = self.cfg_msdd_model.diarizer @@ -341,15 +345,17 @@ def get_ms_emb_seq( Merged embeddings without zero-padding in the batch. See `ms_seg_counts` for details. Shape: (Total number of segments in the batch, emb_dim) scale_mapping (Tensor): - The element at the m-th row and the n-th column of the scale mapping matrix indicates the (m+1)-th scale - segment index which has the closest center distance with (n+1)-th segment in the base scale. + The element at the m-th row and the n-th column of the scale mapping matrix indicates + the (m+1)-th scale segment index which has the closest center distance with (n+1)-th segment + in the base scale. + Example: scale_mapping_argmat[2][101] = 85 - In the above example, it means that 86-th segment in the 3rd scale (python index is 2) is mapped with - 102-th segment in the base scale. Thus, the longer segments bound to have more repeating numbers since - multiple base scale segments (since the base scale has the shortest length) fall into the range of the - longer segments. At the same time, each row contains N numbers of indices where N is number of - segments in the base-scale (i.e., the finest scale). + In the above example, it means that 86-th segment in the 3rd scale (python index is 2) + is mapped with 102-th segment in the base scale. Thus, the longer segments bound to have more + repeating numbers since multiple base scale segments (since the base scale has the shortest length) + fall into the range of the longer segments. At the same time, each row contains N numbers of + indices where N is number of segments in the base-scale (i.e., the finest scale). Shape: (batch_size, scale_n, self.diar_window_length) ms_seg_counts (Tensor): Cumulative sum of the number of segments in each scale. This information is needed to reconstruct @@ -366,8 +372,8 @@ def get_ms_emb_seq( Returns: ms_emb_seq (Tensor): - Multi-scale embedding sequence that is mapped, matched and repeated. The longer scales are less repeated, - while shorter scales are more frequently repeated following the scale mapping tensor. + Multi-scale embedding sequence that is mapped, matched and repeated. The longer scales are + less repeated, while shorter scales are more frequently repeated following the scale mapping tensor. """ scale_n, batch_size = scale_mapping[0].shape[0], scale_mapping.shape[0] split_emb_tup = torch.split(embs, ms_seg_counts.view(-1).tolist(), dim=0) @@ -388,19 +394,20 @@ def get_cluster_avg_embs_model( self, embs: torch.Tensor, clus_label_index: torch.Tensor, ms_seg_counts: torch.Tensor, scale_mapping ) -> torch.Tensor: """ - Calculate the cluster-average speaker embedding based on the ground-truth speaker labels (i.e., cluster labels). + Calculate the cluster-average speaker embedding based on the ground-truth speaker labels + (i.e., cluster labels). Args: embs (Tensor): Merged embeddings without zero-padding in the batch. See `ms_seg_counts` for details. Shape: (Total number of segments in the batch, emb_dim) clus_label_index (Tensor): - Merged ground-truth cluster labels from all scales with zero-padding. Each scale's index can be - retrieved by using segment index in `ms_seg_counts`. + Merged ground-truth cluster labels from all scales with zero-padding. Each scale's + index can be retrieved by using segment index in `ms_seg_counts`. Shape: (batch_size, maximum total segment count among the samples in the batch) ms_seg_counts (Tensor): - Cumulative sum of the number of segments in each scale. This information is needed to reconstruct - multi-scale input tensors during forward propagating. + Cumulative sum of the number of segments in each scale. This information is needed + to reconstruct multi-scale input tensors during forward propagating. Example: `batch_size=3, scale_n=6, emb_dim=192` .. code:: python @@ -420,8 +427,9 @@ def get_cluster_avg_embs_model( Returns: ms_avg_embs (Tensor): - Multi-scale cluster-average speaker embedding vectors. These embedding vectors are used as reference for - each speaker to predict the speaker label for the given multi-scale embedding sequences. + Multi-scale cluster-average speaker embedding vectors. These embedding vectors are used + as reference for each speaker to predict the speaker label for the given multi-scale + embedding sequences. Shape: (batch_size, scale_n, emb_dim, self.num_spks_per_model) """ scale_n, batch_size = scale_mapping[0].shape[0], scale_mapping.shape[0] @@ -534,7 +542,8 @@ def get_ms_mel_feat( def forward_infer(self, input_signal, input_signal_length, emb_vectors, targets): """ - Wrapper function for inference case. + Wrapper function for inference case. This `forward_infer` is only used during inference, where `forward` + is used for training and validation. """ preds, scale_weights = self.msdd( ms_emb_seq=input_signal, length=input_signal_length, ms_avg_embs=emb_vectors, targets=targets @@ -545,6 +554,7 @@ def forward_infer(self, input_signal, input_signal_length, emb_vectors, targets) def forward( self, features, feature_length, ms_seg_timestamps, ms_seg_counts, clus_label_index, scale_mapping, targets ): + """Function to compute forward pass for training/validation.""" processed_signal, processed_signal_len = self.msdd._speaker_model.preprocessor( input_signal=features, length=feature_length ) @@ -577,6 +587,7 @@ def forward( return preds, scale_weights def training_step(self, batch: list, batch_idx: int): + """Function to compute training step.""" features, feature_length, ms_seg_timestamps, ms_seg_counts, clus_label_index, scale_mapping, targets = batch sequence_lengths = torch.tensor([x[-1] for x in ms_seg_counts.detach()]) preds, _ = self.forward( @@ -588,10 +599,11 @@ def training_step(self, batch: list, batch_idx: int): scale_mapping=scale_mapping, targets=targets, ) - loss = self.loss(probs=preds, labels=targets, signal_lengths=sequence_lengths) + # loss = self.loss(probs=preds, labels=targets, signal_lengths=sequence_lengths) + loss = self.loss(probs=preds, labels=targets, target_lens=sequence_lengths) self._accuracy_train(preds, targets, sequence_lengths) torch.cuda.empty_cache() - f1_acc = self._accuracy_train.compute() + f1_acc, _, _ = self._accuracy_train.compute() self.log('loss', loss, sync_dist=True) self.log('learning_rate', self._optimizer.param_groups[0]['lr'], sync_dist=True) self.log('train_f1_acc', f1_acc, sync_dist=True) @@ -599,6 +611,7 @@ def training_step(self, batch: list, batch_idx: int): return {'loss': loss} def validation_step(self, batch: list, batch_idx: int, dataloader_idx: int = 0): + """Function to compute validation step.""" features, feature_length, ms_seg_timestamps, ms_seg_counts, clus_label_index, scale_mapping, targets = batch sequence_lengths = torch.tensor([x[-1] for x in ms_seg_counts]) preds, _ = self.forward( @@ -610,9 +623,10 @@ def validation_step(self, batch: list, batch_idx: int, dataloader_idx: int = 0): scale_mapping=scale_mapping, targets=targets, ) - loss = self.loss(probs=preds, labels=targets, signal_lengths=sequence_lengths) + # loss = self.loss(probs=preds, labels=targets, signal_lengths=sequence_lengths) + loss = self.loss(probs=preds, labels=targets, target_lens=sequence_lengths) self._accuracy_valid(preds, targets, sequence_lengths) - f1_acc = self._accuracy_valid.compute() + f1_acc, _, _ = self._accuracy_valid.compute() self.log('val_loss', loss, sync_dist=True) self.log('val_f1_acc', f1_acc, sync_dist=True) return { @@ -622,7 +636,7 @@ def validation_step(self, batch: list, batch_idx: int, dataloader_idx: int = 0): def multi_validation_epoch_end(self, outputs: list, dataloader_idx: int = 0): val_loss_mean = torch.stack([x['val_loss'] for x in outputs]).mean() - f1_acc = self._accuracy_valid.compute() + f1_acc, _, _ = self._accuracy_valid.compute() self._accuracy_valid.reset() self.log('val_loss', val_loss_mean, sync_dist=True) @@ -634,7 +648,7 @@ def multi_validation_epoch_end(self, outputs: list, dataloader_idx: int = 0): def multi_test_epoch_end(self, outputs: List[Dict[str, torch.Tensor]], dataloader_idx: int = 0): test_loss_mean = torch.stack([x['test_loss'] for x in outputs]).mean() - f1_acc = self._accuracy_test.compute() + f1_acc, _, _ = self._accuracy_test.compute() self._accuracy_test.reset() self.log('test_f1_acc', f1_acc, sync_dist=True) return { @@ -648,9 +662,10 @@ def compute_accuracies(self): Returns: f1_score (float): F1 score of the estimated diarized speaker label sequences. - simple_acc (float): Accuracy of predicted speaker labels: (total # of correct labels)/(total # of sigmoid values) + simple_acc (float): Accuracy of predicted speaker labels: + (total # of correct labels)/(total # of sigmoid values) """ - f1_score = self._accuracy_test.compute() + f1_score, _, _ = self._accuracy_test.compute() num_correct = torch.sum(self._accuracy_test.true.bool()) total_count = torch.prod(torch.tensor(self._accuracy_test.targets.shape)) simple_acc = num_correct / total_count @@ -659,7 +674,9 @@ def compute_accuracies(self): class ClusterEmbedding(torch.nn.Module): """ - This class is built for calculating cluster-average embeddings, segmentation and load/save of the estimated cluster labels. + This class is built for calculating cluster-average embeddings, segmentation and load/save of + the estimated cluster labels. + The methods in this class is used for the inference of MSDD models. Args: @@ -708,10 +725,10 @@ def prepare_cluster_embs_infer(self): def assign_labels_to_longer_segs(self, base_clus_label_dict: Dict, session_scale_mapping_dict: Dict): """ - In multi-scale speaker diarization system, clustering result is solely based on the base-scale (the shortest scale). - To calculate cluster-average speaker embeddings for each scale that are longer than the base-scale, this function assigns - clustering results for the base-scale to the longer scales by measuring the distance between subsegment timestamps in the - base-scale and non-base-scales. + In multi-scale speaker diarization system, clustering result is solely based on the base-scale + (the shortest scale). To calculate cluster-average speaker embeddings for each scale that are longer + than the base-scale, this function assigns clustering results for the base-scale to the longer scales + by measuring the distance between subsegment timestamps in the base-scale and non-base-scales. Args: base_clus_label_dict (dict): @@ -754,7 +771,8 @@ def get_base_clus_label_dict(self, clus_labels: List[str], emb_scale_seq_dict: D Dictionary containing multiscale embedding input sequences. Returns: base_clus_label_dict (dict): - Dictionary containing start and end of base scale segments and its cluster label. Indexed by `uniq_id`. + Dictionary containing start and end of base scale segments and its cluster label. + Indexed by `uniq_id`. emb_dim (int): Embedding dimension in integer. """ @@ -771,17 +789,18 @@ def get_cluster_avg_embs( self, emb_scale_seq_dict: Dict, clus_labels: List, speaker_mapping_dict: Dict, session_scale_mapping_dict: Dict ): """ - MSDD requires cluster-average speaker embedding vectors for each scale. This function calculates an average embedding vector for each cluster (speaker) - and each scale. + MSDD requires cluster-average speaker embedding vectors for each scale. This function calculates + an average embedding vector for each cluster (speaker) and each scale. Args: emb_scale_seq_dict (dict): Dictionary containing embedding sequence for each scale. Keys are scale index in integer. clus_labels (list): - Clustering results from clustering diarizer including all the sessions provided in input manifest files. + Clustering results from clustering diarizer including all the sessions provided + in input manifest files. speaker_mapping_dict (dict): - Speaker mapping dictionary in case RTTM files are provided. This is mapping between integer based speaker index and - speaker ID tokens in RTTM files. + Speaker mapping dictionary in case RTTM files are provided. This is mapping between + integer based speaker index and speaker ID tokens in RTTM files. Example: {'en_0638': {'speaker_0': 'en_0638_A', 'speaker_1': 'en_0638_B'}, 'en_4065': {'speaker_0': 'en_4065_B', 'speaker_1': 'en_4065_A'}, ...,} @@ -793,7 +812,8 @@ def get_cluster_avg_embs( Dictionary containing speaker mapping information and cluster-average speaker embedding vector. Each session-level dictionary is indexed by scale index in integer. output_clus_label_dict (dict): - Subegmentation timestamps in float type and Clustering result in integer type. Indexed by `uniq_id` keys. + Subegmentation timestamps in float type and Clustering result in integer type. + Indexed by `uniq_id` keys. """ self.scale_n = len(emb_scale_seq_dict.keys()) emb_sess_avg_dict = { @@ -830,9 +850,10 @@ def get_cluster_avg_embs( def run_clustering_diarizer(self, manifest_filepath: str, emb_dir: str): """ - If no pre-existing data is provided, run clustering diarizer from scratch. This will create scale-wise speaker embedding - sequence, cluster-average embeddings, scale mapping and base scale clustering labels. Note that speaker embedding `state_dict` - is loaded from the `state_dict` in the provided MSDD checkpoint. + If no pre-existing data is provided, run clustering diarizer from scratch. This will create + scale-wise speaker embedding sequence, cluster-average embeddings, scale mapping and base scale + clustering labels. Note that speaker embedding `state_dict` is loaded from the `state_dict` + in the provided MSDD checkpoint. Args: manifest_filepath (str): @@ -846,7 +867,8 @@ def run_clustering_diarizer(self, manifest_filepath: str, emb_dir: str): emb_scale_seq_dict (dict): Dictionary containing embedding tensors which are indexed by scale numbers. base_clus_label_dict (dict): - Dictionary containing clustering results. Clustering results are cluster labels for the base scale segments. + Dictionary containing clustering results. Clustering results are cluster labels + for the base scale segments. """ self.cfg_diar_infer.diarizer.manifest_filepath = manifest_filepath self.cfg_diar_infer.diarizer.out_dir = emb_dir @@ -974,9 +996,9 @@ def load_emb_scale_seq_dict(self, out_dir): class NeuralDiarizer(LightningModule): """ - Class for inference based on multiscale diarization decoder (MSDD). MSDD requires initializing clustering results from - clustering diarizer. Overlap-aware diarizer requires separate RTTM generation and evaluation modules to check the effect of - overlap detection in speaker diarization. + Class for inference based on multiscale diarization decoder (MSDD). MSDD requires initializing + clustering results from clustering diarizer. Overlap-aware diarizer requires separate RTTM + generation and evaluation modules to check the effect of overlap detection in speaker diarization. """ def __init__(self, cfg: Union[DictConfig, NeuralDiarizerInferenceConfig]): @@ -1029,7 +1051,8 @@ def save_to(self, save_path: str): You can use "restore_from" method to fully restore instance from .nemo file. .nemo file is an archive (tar.gz) with the following: - model_config.yaml - model configuration in .yaml format. You can deserialize this into cfg argument for model's constructor + model_config.yaml - model configuration in .yaml format. + You can deserialize this into cfg argument for model's constructor model_wights.chpt - model checkpoint Args: @@ -1053,8 +1076,8 @@ def save_to(self, save_path: str): def extract_standalone_speaker_model(self, prefix: str = 'msdd._speaker_model.') -> EncDecSpeakerLabelModel: """ - MSDD model file contains speaker embedding model and MSDD model. This function extracts standalone speaker model and save it to - `self.spk_emb_state_dict` to be loaded separately for clustering diarizer. + MSDD model file contains speaker embedding model and MSDD model. This function extracts standalone + speaker model and save it to `self.spk_emb_state_dict` to be loaded separately for clustering diarizer. Args: ext (str): @@ -1104,20 +1127,22 @@ def _init_msdd_model(self, cfg: Union[DictConfig, NeuralDiarizerInferenceConfig] def get_pred_mat(self, data_list: List[Union[Tuple[int], List[torch.Tensor]]]) -> torch.Tensor: """ - This module puts together the pairwise, two-speaker, predicted results to form a finalized matrix that has dimension of - `(total_len, n_est_spks)`. The pairwise results are evenutally averaged. For example, in 4 speaker case (speaker 1, 2, 3, 4), - the sum of the pairwise results (1, 2), (1, 3), (1, 4) are then divided by 3 to take average of the sigmoid values. + This module puts together the pairwise, two-speaker, predicted results to form a finalized matrix + that has dimension of `(total_len, n_est_spks)`. The pairwise results are evenutally averaged. + For example, in 4 speaker case (speaker 1, 2, 3, 4), the sum of the pairwise results + (1, 2), (1, 3), (1, 4) are then divided by 3 to take average of the sigmoid values. Args: data_list (list): - List containing data points from `test_data_collection` variable. `data_list` has sublists `data` as follows: - data[0]: `target_spks` tuple - Examples: (0, 1, 2) - data[1]: Tensor containing estimaged sigmoid values. - [[0.0264, 0.9995], - [0.0112, 1.0000], - ..., - [1.0000, 0.0512]] + List containing data points from `test_data_collection` variable. `data_list` + has sublists `data` as follows: + data[0]: `target_spks` tuple + Examples: (0, 1, 2) + data[1]: Tensor containing estimaged sigmoid values. + [[0.0264, 0.9995], + [0.0112, 1.0000], + ..., + [1.0000, 0.0512]] Returns: sum_pred (Tensor): @@ -1152,7 +1177,8 @@ def get_integrated_preds_list( uniq_id_list (list): List containing `uniq_id` values. test_data_collection (collections.DiarizationLabelEntity): - Class instance that is containing session information such as targeted speaker indices, audio filepaths and RTTM filepaths. + Class instance that is containing session information such as targeted speaker indices, + audio filepaths and RTTM filepaths. preds_list (list): List containing tensors filled with sigmoid values. @@ -1177,9 +1203,11 @@ def get_emb_clus_infer(self, cluster_embeddings): @torch.no_grad() def diarize(self) -> Optional[List[Optional[List[Tuple[DiarizationErrorRate, Dict]]]]]: """ - Launch diarization pipeline which starts from VAD (or a oracle VAD stamp generation), initialization clustering and multiscale diarization decoder (MSDD). - Note that the result of MSDD can include multiple speakers at the same time. Therefore, RTTM output of MSDD needs to be based on `make_rttm_with_overlap()` - function that can generate overlapping timestamps. `self.run_overlap_aware_eval()` function performs DER evaluation. + Launch diarization pipeline which starts from VAD (or a oracle VAD stamp generation), + initialization clustering and multiscale diarization decoder (MSDD). Note that the result of MSDD + can include multiple speakers at the same time. Therefore, RTTM output of MSDD needs to be based on + `make_rttm_with_overlap()` function that can generate overlapping timestamps. + `self.run_overlap_aware_eval()` function performs DER evaluation. """ self.clustering_embedding.prepare_cluster_embs_infer() self.msdd_model.pairwise_infer = True @@ -1192,10 +1220,11 @@ def get_range_average( self, signals: torch.Tensor, emb_vectors: torch.Tensor, diar_window_index: int, test_data_collection: List[Any] ) -> Tuple[torch.Tensor, torch.Tensor, int]: """ - This function is only used when `split_infer=True`. This module calculates cluster-average embeddings for the given short range. - The range length is set by `self.diar_window_length`, and each cluster-average is only calculated for the specified range. - Note that if the specified range does not contain some speakers (e.g. the range contains speaker 1, 3) compared to the global speaker sets - (e.g. speaker 1, 2, 3, 4) then the missing speakers (e.g. speakers 2, 4) are assigned with zero-filled cluster-average speaker embedding. + This function is only used when `split_infer=True`. This module calculates cluster-average embeddings + for the given short range. The range length is set by `self.diar_window_length`, and each cluster-average + is only calculated for the specified range. Note that if the specified range does not contain some speakers + (e.g. the range contains speaker 1, 3) compared to the global speaker sets (e.g. speaker 1, 2, 3, 4) then + the missing speakers (e.g. speakers 2, 4) are assigned with zero-filled cluster-average speaker embedding. Args: signals (Tensor): @@ -1207,7 +1236,8 @@ def get_range_average( diar_window_index (int): Index of split diarization wondows. test_data_collection (collections.DiarizationLabelEntity) - Class instance that is containing session information such as targeted speaker indices, audio filepath and RTTM filepath. + Class instance that is containing session information such as targeted speaker indices, + audio filepath and RTTM filepath. Returns: return emb_vectors_split (Tensor): @@ -1237,7 +1267,8 @@ def get_range_average( ) target_clus_label_bool = target_clus_label_tensor == test_data_collection.target_spks[spk_idx] - # There are cases where there is no corresponding speaker in split range, so any(target_clus_label_bool) could be False. + # There are cases where there is no corresponding speaker in split range, + # so any(target_clus_label_bool) could be False. if any(target_clus_label_bool): emb_vectors_split[:, :, spk_idx] = torch.mean(emb_seq[target_clus_label_bool], dim=0) @@ -1263,14 +1294,17 @@ def get_range_clus_avg_emb( self, test_batch: List[torch.Tensor], _test_data_collection: List[Any], device: torch.device('cpu') ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ - This function is only used when `get_range_average` function is called. This module calculates cluster-average embeddings for - the given short range. The range length is set by `self.diar_window_length`, and each cluster-average is only calculated for the specified range. + This function is only used when `get_range_average` function is called. This module calculates + cluster-average embeddings for the given short range. The range length is set by `self.diar_window_length`, + and each cluster-average is only calculated for the specified range. Args: test_batch: (list) - List containing embedding sequences, length of embedding sequences, ground truth labels (if exists) and initializing embedding vectors. + List containing embedding sequences, length of embedding sequences, ground truth labels + (if exists) and initializing embedding vectors. test_data_collection: (list) - List containing test-set dataloader contents. test_data_collection includes wav file path, RTTM file path, clustered speaker indices. + List containing test-set dataloader contents. test_data_collection includes wav file path, + RTTM file path, clustered speaker indices. Returns: sess_emb_vectors (Tensor): @@ -1305,16 +1339,18 @@ def diar_infer( self, test_batch: List[torch.Tensor], test_data_collection: List[Any] ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ - Launch forward_infer() function by feeding the session-wise embedding sequences to get pairwise speaker prediction values. - If split_infer is True, the input audio clips are broken into short sequences then cluster average embeddings are calculated - for inference. Split-infer might result in an improved results if calculating clustering average on the shorter tim-espan can - help speaker assignment. + Launch forward_infer() function by feeding the session-wise embedding sequences to get pairwise + speaker prediction values. If split_infer is True, the input audio clips are broken into short + sequences then cluster average embeddings are calculated for inference. Split-infer might result in + an improved results if calculating clustering average on the shorter tim-espan can help speaker assignment. Args: test_batch: (list) - List containing embedding sequences, length of embedding sequences, ground truth labels (if exists) and initializing embedding vectors. + List containing embedding sequences, length of embedding sequences, ground truth labels (if exists) + and initializing embedding vectors. test_data_collection: (list) - List containing test-set dataloader contents. test_data_collection includes wav file path, RTTM file path, clustered speaker indices. + List containing test-set dataloader contents. test_data_collection includes wav file path, + RTTM file path, clustered speaker indices. Returns: preds (Tensor): @@ -1353,8 +1389,9 @@ def diar_infer( @torch.no_grad() def run_pairwise_diarization(self) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]: """ - Setup the parameters needed for batch inference and run batch inference. Note that each sample is pairwise speaker input. - The pairwise inference results are reconstructed to make session-wise prediction results. + Setup the parameters needed for batch inference and run batch inference. Note that each sample is + pairwise speaker input. The pairwise inference results are reconstructed to make session-wise + prediction results. Returns: integrated_preds_list: (list) @@ -1405,7 +1442,8 @@ def run_overlap_aware_eval( - If threshold is 0.0, all speakers are considered active at any time step. """ logging.info( - f" [Threshold: {threshold:.4f}] [use_clus_as_main={self.use_clus_as_main}] [diar_window={self.diar_window_length}]" + f" [Threshold: {threshold:.4f}] [use_clus_as_main={self.use_clus_as_main}] " + f"[diar_window={self.diar_window_length}]" ) outputs = [] manifest_filepath = self.msdd_model.cfg.test_ds.manifest_filepath diff --git a/nemo/collections/asr/models/rnnt_bpe_models.py b/nemo/collections/asr/models/rnnt_bpe_models.py index 25890ec716c8..cd8667f2f0fe 100644 --- a/nemo/collections/asr/models/rnnt_bpe_models.py +++ b/nemo/collections/asr/models/rnnt_bpe_models.py @@ -17,8 +17,8 @@ from typing import Dict, List, Optional, Union import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig, ListConfig, OmegaConf, open_dict -from pytorch_lightning import Trainer from nemo.collections.asr.data import audio_to_text_dataset from nemo.collections.asr.data.audio_to_text import _AudioTextDataset @@ -509,10 +509,14 @@ def _setup_dataloader_from_config(self, config: Optional[Dict]): if config.get("use_lhotse"): return get_lhotse_dataloader_from_config( config, - global_rank=self.global_rank, - world_size=self.world_size, + # During transcription, the model is initially loaded on the CPU. + # To ensure the correct global_rank and world_size are set, + # these values must be passed from the configuration. + global_rank=self.global_rank if not config.get("do_transcribe", False) else config.get("global_rank"), + world_size=self.world_size if not config.get("do_transcribe", False) else config.get("world_size"), dataset=LhotseSpeechToTextBpeDataset( tokenizer=self.tokenizer, + return_cuts=config.get("do_transcribe", False), ), tokenizer=self.tokenizer, ) diff --git a/nemo/collections/asr/models/rnnt_models.py b/nemo/collections/asr/models/rnnt_models.py index ce3b6bc89bce..78038d404107 100644 --- a/nemo/collections/asr/models/rnnt_models.py +++ b/nemo/collections/asr/models/rnnt_models.py @@ -19,8 +19,8 @@ import numpy as np import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig, OmegaConf, open_dict -from pytorch_lightning import Trainer from torch.utils.data import DataLoader from nemo.collections.asr.data import audio_to_text_dataset @@ -285,7 +285,7 @@ def transcribe( * A list of greedy transcript texts / Hypothesis * An optional list of beam search transcript texts / Hypothesis / NBestHypothesis. """ - + timestamps = timestamps or (override_config.timestamps if override_config is not None else None) if timestamps is not None: if timestamps or (override_config is not None and override_config.timestamps): logging.info( @@ -469,8 +469,11 @@ def _setup_dataloader_from_config(self, config: Optional[Dict]): if config.get("use_lhotse"): return get_lhotse_dataloader_from_config( config, - global_rank=self.global_rank, - world_size=self.world_size, + # During transcription, the model is initially loaded on the CPU. + # To ensure the correct global_rank and world_size are set, + # these values must be passed from the configuration. + global_rank=self.global_rank if not config.get("do_transcribe", False) else config.get("global_rank"), + world_size=self.world_size if not config.get("do_transcribe", False) else config.get("world_size"), dataset=LhotseSpeechToTextBpeDataset( tokenizer=make_parser( labels=config.get('labels', None), @@ -479,6 +482,7 @@ def _setup_dataloader_from_config(self, config: Optional[Dict]): blank_id=config.get('blank_index', -1), do_normalize=config.get('normalize_transcripts', False), ), + return_cuts=config.get("do_transcribe", False), ), ) @@ -814,7 +818,8 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0): encoder_output=encoded, encoded_lengths=encoded_len, return_hypotheses=False ) - sample_id = sample_id.cpu().detach().numpy() + if isinstance(sample_id, torch.Tensor): + sample_id = sample_id.cpu().detach().numpy() return list(zip(sample_id, best_hyp_text)) def validation_pass(self, batch, batch_idx, dataloader_idx=0): diff --git a/nemo/collections/asr/models/sortformer_diar_models.py b/nemo/collections/asr/models/sortformer_diar_models.py new file mode 100644 index 000000000000..f6b0eab4c895 --- /dev/null +++ b/nemo/collections/asr/models/sortformer_diar_models.py @@ -0,0 +1,579 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools +import random +from collections import OrderedDict +from typing import Dict, List, Optional, Union + +import torch +from hydra.utils import instantiate +from omegaconf import DictConfig +from pytorch_lightning import Trainer +from tqdm import tqdm + +from nemo.collections.asr.data.audio_to_diar_label import AudioToSpeechE2ESpkDiarDataset +from nemo.collections.asr.data.audio_to_diar_label_lhotse import LhotseAudioToSpeechE2ESpkDiarDataset +from nemo.collections.asr.metrics.multi_binary_acc import MultiBinaryAccuracy +from nemo.collections.asr.models.asr_model import ExportableEncDecModel +from nemo.collections.asr.parts.preprocessing.features import WaveformFeaturizer +from nemo.collections.asr.parts.preprocessing.perturb import process_augmentations +from nemo.collections.asr.parts.utils.asr_multispeaker_utils import get_ats_targets, get_pil_targets +from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config +from nemo.core.classes import ModelPT +from nemo.core.classes.common import PretrainedModelInfo +from nemo.core.neural_types import AudioSignal, LengthsType, NeuralType +from nemo.core.neural_types.elements import ProbsType +from nemo.utils import logging + +__all__ = ['SortformerEncLabelModel'] + + +class SortformerEncLabelModel(ModelPT, ExportableEncDecModel): + """ + Encoder class for Sortformer diarization model. + Model class creates training, validation methods for setting up data performing model forward pass. + + This model class expects config dict for: + * preprocessor + * Transformer Encoder + * FastConformer Encoder + * Sortformer Modules + """ + + @classmethod + def list_available_models(cls) -> List[PretrainedModelInfo]: + """ + This method returns a list of pre-trained model which can be instantiated directly from NVIDIA's NGC cloud. + + Returns: + List of available pre-trained models. + """ + result = [] + return result + + def __init__(self, cfg: DictConfig, trainer: Trainer = None): + """ + Initialize an Sortformer Diarizer model and a pretrained NEST encoder. + In this init function, training and validation datasets are prepared. + """ + random.seed(42) + self._trainer = trainer if trainer else None + self._cfg = cfg + + if self._trainer: + self.world_size = trainer.num_nodes * trainer.num_devices + else: + self.world_size = 1 + + if self._trainer is not None and self._cfg.get('augmentor', None) is not None: + self.augmentor = process_augmentations(self._cfg.augmentor) + else: + self.augmentor = None + super().__init__(cfg=self._cfg, trainer=trainer) + self.preprocessor = SortformerEncLabelModel.from_config_dict(self._cfg.preprocessor) + + if hasattr(self._cfg, 'spec_augment') and self._cfg.spec_augment is not None: + self.spec_augmentation = SortformerEncLabelModel.from_config_dict(self._cfg.spec_augment) + else: + self.spec_augmentation = None + + self.encoder = SortformerEncLabelModel.from_config_dict(self._cfg.encoder).to(self.device) + self.sortformer_modules = SortformerEncLabelModel.from_config_dict(self._cfg.sortformer_modules).to( + self.device + ) + self.transformer_encoder = SortformerEncLabelModel.from_config_dict(self._cfg.transformer_encoder).to( + self.device + ) + if self._cfg.encoder.d_model != self._cfg.model_defaults.tf_d_model: + self.sortformer_modules.encoder_proj = self.sortformer_modules.encoder_proj.to(self.device) + else: + self.sortformer_modules.encoder_proj = None + self._init_loss_weights() + + self.eps = 1e-3 + self.loss = instantiate(self._cfg.loss) + + self.streaming_mode = self._cfg.get("streaming_mode", False) + self.save_hyperparameters("cfg") + self._init_eval_metrics() + + speaker_inds = list(range(self._cfg.max_num_of_spks)) + self.speaker_permutations = torch.tensor(list(itertools.permutations(speaker_inds))) # Get all permutations + + def _init_loss_weights(self): + pil_weight = self._cfg.get("pil_weight", 0.0) + ats_weight = self._cfg.get("ats_weight", 1.0) + if pil_weight + ats_weight == 0: + raise ValueError(f"weights for PIL {pil_weight} and ATS {ats_weight} cannot sum to 0") + self.pil_weight = pil_weight / (pil_weight + ats_weight) + self.ats_weight = ats_weight / (pil_weight + ats_weight) + logging.info(f"Normalized weights for PIL {self.pil_weight} and ATS {self.ats_weight}") + + def _init_eval_metrics(self): + """ + If there is no label, then the evaluation metrics will be based on Permutation Invariant Loss (PIL). + """ + self._accuracy_test = MultiBinaryAccuracy() + self._accuracy_train = MultiBinaryAccuracy() + self._accuracy_valid = MultiBinaryAccuracy() + + self._accuracy_test_ats = MultiBinaryAccuracy() + self._accuracy_train_ats = MultiBinaryAccuracy() + self._accuracy_valid_ats = MultiBinaryAccuracy() + + def _reset_train_metrics(self): + self._accuracy_train.reset() + self._accuracy_train_ats.reset() + + def _reset_valid_metrics(self): + self._accuracy_valid.reset() + self._accuracy_valid_ats.reset() + + def __setup_dataloader_from_config(self, config): + # Switch to lhotse dataloader if specified in the config + if config.get("use_lhotse"): + return get_lhotse_dataloader_from_config( + config, + global_rank=self.global_rank, + world_size=self.world_size, + dataset=LhotseAudioToSpeechE2ESpkDiarDataset(cfg=config), + ) + + featurizer = WaveformFeaturizer( + sample_rate=config['sample_rate'], int_values=config.get('int_values', False), augmentor=self.augmentor + ) + + if 'manifest_filepath' in config and config['manifest_filepath'] is None: + logging.warning(f"Could not load dataset as `manifest_filepath` was None. Provided config : {config}") + return None + + logging.info(f"Loading dataset from {config.manifest_filepath}") + + if self._trainer is not None: + global_rank = self._trainer.global_rank + else: + global_rank = 0 + + dataset = AudioToSpeechE2ESpkDiarDataset( + manifest_filepath=config.manifest_filepath, + soft_label_thres=config.soft_label_thres, + session_len_sec=config.session_len_sec, + num_spks=config.num_spks, + featurizer=featurizer, + window_stride=self._cfg.preprocessor.window_stride, + global_rank=global_rank, + soft_targets=config.soft_targets if 'soft_targets' in config else False, + ) + + self.data_collection = dataset.collection + self.collate_ds = dataset + + dataloader_instance = torch.utils.data.DataLoader( + dataset=dataset, + batch_size=config.batch_size, + collate_fn=self.collate_ds.eesd_train_collate_fn, + drop_last=config.get('drop_last', False), + shuffle=False, + num_workers=config.get('num_workers', 1), + pin_memory=config.get('pin_memory', False), + ) + return dataloader_instance + + def setup_training_data(self, train_data_config: Optional[Union[DictConfig, Dict]]): + self._train_dl = self.__setup_dataloader_from_config( + config=train_data_config, + ) + + def setup_validation_data(self, val_data_layer_config: Optional[Union[DictConfig, Dict]]): + self._validation_dl = self.__setup_dataloader_from_config( + config=val_data_layer_config, + ) + + def setup_test_data(self, test_data_config: Optional[Union[DictConfig, Dict]]): + self._test_dl = self.__setup_dataloader_from_config( + config=test_data_config, + ) + + def test_dataloader(self): + if self._test_dl is not None: + return self._test_dl + return None + + @property + def input_types(self) -> Optional[Dict[str, NeuralType]]: + if hasattr(self.preprocessor, '_sample_rate'): + audio_eltype = AudioSignal(freq=self.preprocessor._sample_rate) + else: + audio_eltype = AudioSignal() + return { + "audio_signal": NeuralType(('B', 'T'), audio_eltype), + "audio_signal_length": NeuralType(('B',), LengthsType()), + } + + @property + def output_types(self) -> Dict[str, NeuralType]: + return OrderedDict( + { + "preds": NeuralType(('B', 'T', 'C'), ProbsType()), + } + ) + + def frontend_encoder(self, processed_signal, processed_signal_length): + """ + Generate encoder outputs from frontend encoder. + + Args: + processed_signal (torch.Tensor): tensor containing audio-feature (mel spectrogram, mfcc, etc.) + processed_signal_length (torch.Tensor): tensor containing lengths of audio signal in integers + + Returns: + emb_seq (torch.Tensor): tensor containing encoder outputs + emb_seq_length (torch.Tensor): tensor containing lengths of encoder outputs + """ + # Spec augment is not applied during evaluation/testing + if self.spec_augmentation is not None and self.training: + processed_signal = self.spec_augmentation(input_spec=processed_signal, length=processed_signal_length) + emb_seq, emb_seq_length = self.encoder(audio_signal=processed_signal, length=processed_signal_length) + emb_seq = emb_seq.transpose(1, 2) + if self.sortformer_modules.encoder_proj is not None: + emb_seq = self.sortformer_modules.encoder_proj(emb_seq) + return emb_seq, emb_seq_length + + def forward_infer(self, emb_seq): + """ + The main forward pass for diarization for offline diarization inference. + + Args: + emb_seq (torch.Tensor): tensor containing FastConformer encoder states (embedding vectors). + Dimension: (batch_size, diar_frame_count, emb_dim) + + Returns: + preds (torch.Tensor): Sorted tensor containing Sigmoid values for predicted speaker labels. + Dimension: (batch_size, diar_frame_count, num_speakers) + """ + encoder_mask = self.sortformer_modules.length_to_mask(emb_seq) + trans_emb_seq = self.transformer_encoder(encoder_states=emb_seq, encoder_mask=encoder_mask) + preds = self.sortformer_modules.forward_speaker_sigmoids(trans_emb_seq) + return preds + + def process_signal(self, audio_signal, audio_signal_length): + """ + Extract audio features from time-series signal for further processing in the model. + + This function performs the following steps: + 1. Moves the audio signal to the correct device. + 2. Normalizes the time-series audio signal. + 3. Extrac audio feature from from the time-series audio signal using the model's preprocessor. + + Args: + audio_signal (torch.Tensor): The input audio signal. + Shape: (batch_size, num_samples) + audio_signal_length (torch.Tensor): The length of each audio signal in the batch. + Shape: (batch_size,) + + Returns: + tuple: A tuple containing: + - processed_signal (torch.Tensor): The preprocessed audio signal. + Shape: (batch_size, num_features, num_frames) + - processed_signal_length (torch.Tensor): The length of each processed signal. + Shape: (batch_size,) + """ + audio_signal = audio_signal.to(self.device) + audio_signal = (1 / (audio_signal.max() + self.eps)) * audio_signal + processed_signal, processed_signal_length = self.preprocessor( + input_signal=audio_signal, length=audio_signal_length + ) + return processed_signal, processed_signal_length + + def forward( + self, + audio_signal, + audio_signal_length, + ): + """ + Forward pass for training and inference. + + Args: + audio_signal (torch.Tensor): tensor containing audio waveform + Dimension: (batch_size, num_samples) + audio_signal_length (torch.Tensor): tensor containing lengths of audio waveforms + Dimension: (batch_size,) + + Returns: + preds (torch.Tensor): Sorted tensor containing predicted speaker labels + Dimension: (batch_size, diar_frame_count, num_speakers) + """ + processed_signal, processed_signal_length = self.process_signal( + audio_signal=audio_signal, audio_signal_length=audio_signal_length + ) + processed_signal = processed_signal[:, :, : processed_signal_length.max()] + if self._cfg.get("streaming_mode", False): + raise NotImplementedError("Streaming mode is not implemented yet.") + else: + emb_seq, _ = self.frontend_encoder( + processed_signal=processed_signal, processed_signal_length=processed_signal_length + ) + preds = self.forward_infer(emb_seq) + return preds + + def _get_aux_train_evaluations(self, preds, targets, target_lens) -> dict: + """ + Compute auxiliary training evaluations including losses and metrics. + + This function calculates various losses and metrics for the training process, + including Arrival Time Sort (ATS) Loss and Permutation Invariant Loss (PIL) + based evaluations. + + Args: + preds (torch.Tensor): Predicted speaker labels. + Shape: (batch_size, diar_frame_count, num_speakers) + targets (torch.Tensor): Ground truth speaker labels. + Shape: (batch_size, diar_frame_count, num_speakers) + target_lens (torch.Tensor): Lengths of target sequences. + Shape: (batch_size,) + + Returns: + (dict): A dictionary containing the following training metrics. + """ + targets_ats = get_ats_targets(targets.clone(), preds, speaker_permutations=self.speaker_permutations) + targets_pil = get_pil_targets(targets.clone(), preds, speaker_permutations=self.speaker_permutations) + ats_loss = self.loss(probs=preds, labels=targets_ats, target_lens=target_lens) + pil_loss = self.loss(probs=preds, labels=targets_pil, target_lens=target_lens) + loss = self.ats_weight * ats_loss + self.pil_weight * pil_loss + + self._accuracy_train(preds, targets_pil, target_lens) + train_f1_acc, train_precision, train_recall = self._accuracy_train.compute() + + self._accuracy_train_ats(preds, targets_ats, target_lens) + train_f1_acc_ats, _, _ = self._accuracy_train_ats.compute() + + train_metrics = { + 'loss': loss, + 'ats_loss': ats_loss, + 'pil_loss': pil_loss, + 'learning_rate': self._optimizer.param_groups[0]['lr'], + 'train_f1_acc': train_f1_acc, + 'train_precision': train_precision, + 'train_recall': train_recall, + 'train_f1_acc_ats': train_f1_acc_ats, + } + return train_metrics + + def training_step(self, batch: list) -> dict: + """ + Performs a single training step. + + Args: + batch (list): A list containing the following elements: + - audio_signal (torch.Tensor): The input audio signal in time-series format. + - audio_signal_length (torch.Tensor): The length of each audio signal in the batch. + - targets (torch.Tensor): The target labels for the batch. + - target_lens (torch.Tensor): The length of each target sequence in the batch. + + Returns: + (dict): A dictionary containing the 'loss' key with the calculated loss value. + """ + audio_signal, audio_signal_length, targets, target_lens = batch + preds = self.forward(audio_signal=audio_signal, audio_signal_length=audio_signal_length) + train_metrics = self._get_aux_train_evaluations(preds, targets, target_lens) + self._reset_train_metrics() + self.log_dict(train_metrics, sync_dist=True, on_step=True, on_epoch=False, logger=True) + return {'loss': train_metrics['loss']} + + def _get_aux_validation_evaluations(self, preds, targets, target_lens) -> dict: + """ + Compute auxiliary validation evaluations including losses and metrics. + + This function calculates various losses and metrics for the training process, + including Arrival Time Sort (ATS) Loss and Permutation Invariant Loss (PIL) + based evaluations. + + Args: + preds (torch.Tensor): Predicted speaker labels. + Shape: (batch_size, diar_frame_count, num_speakers) + targets (torch.Tensor): Ground truth speaker labels. + Shape: (batch_size, diar_frame_count, num_speakers) + target_lens (torch.Tensor): Lengths of target sequences. + Shape: (batch_size,) + + Returns: + val_metrics (dict): A dictionary containing the following validation metrics + """ + targets_ats = get_ats_targets(targets.clone(), preds, speaker_permutations=self.speaker_permutations) + targets_pil = get_pil_targets(targets.clone(), preds, speaker_permutations=self.speaker_permutations) + + val_ats_loss = self.loss(probs=preds, labels=targets_ats, target_lens=target_lens) + val_pil_loss = self.loss(probs=preds, labels=targets_pil, target_lens=target_lens) + val_loss = self.ats_weight * val_ats_loss + self.pil_weight * val_pil_loss + + self._accuracy_valid(preds, targets_pil, target_lens) + val_f1_acc, val_precision, val_recall = self._accuracy_valid.compute() + + self._accuracy_valid_ats(preds, targets_ats, target_lens) + valid_f1_acc_ats, _, _ = self._accuracy_valid_ats.compute() + + self._accuracy_valid.reset() + self._accuracy_valid_ats.reset() + + val_metrics = { + 'val_loss': val_loss, + 'val_ats_loss': val_ats_loss, + 'val_pil_loss': val_pil_loss, + 'val_f1_acc': val_f1_acc, + 'val_precision': val_precision, + 'val_recall': val_recall, + 'val_f1_acc_ats': valid_f1_acc_ats, + } + return val_metrics + + def validation_step(self, batch: list, dataloader_idx: int = 0): + """ + Performs a single validation step. + + This method processes a batch of data during the validation phase. It forward passes + the audio signal through the model, computes various validation metrics, and stores + these metrics for later aggregation. + + Args: + batch (list): A list containing the following elements: + - audio_signal (torch.Tensor): The input audio signal. + - audio_signal_length (torch.Tensor): The length of each audio signal in the batch. + - targets (torch.Tensor): The target labels for the batch. + - target_lens (torch.Tensor): The length of each target sequence in the batch. + batch_idx (int): The index of the current batch. + dataloader_idx (int, optional): The index of the dataloader in case of multiple + validation dataloaders. Defaults to 0. + + Returns: + dict: A dictionary containing various validation metrics for this batch. + """ + audio_signal, audio_signal_length, targets, target_lens = batch + preds = self.forward( + audio_signal=audio_signal, + audio_signal_length=audio_signal_length, + ) + val_metrics = self._get_aux_validation_evaluations(preds, targets, target_lens) + if isinstance(self.trainer.val_dataloaders, list) and len(self.trainer.val_dataloaders) > 1: + self.validation_step_outputs[dataloader_idx].append(val_metrics) + else: + self.validation_step_outputs.append(val_metrics) + return val_metrics + + def multi_validation_epoch_end(self, outputs: list, dataloader_idx: int = 0): + if not outputs: + logging.warning(f"`outputs` is None; empty outputs for dataloader={dataloader_idx}") + return None + val_loss_mean = torch.stack([x['val_loss'] for x in outputs]).mean() + val_ats_loss_mean = torch.stack([x['val_ats_loss'] for x in outputs]).mean() + val_pil_loss_mean = torch.stack([x['val_pil_loss'] for x in outputs]).mean() + val_f1_acc_mean = torch.stack([x['val_f1_acc'] for x in outputs]).mean() + val_precision_mean = torch.stack([x['val_precision'] for x in outputs]).mean() + val_recall_mean = torch.stack([x['val_recall'] for x in outputs]).mean() + val_f1_acc_ats_mean = torch.stack([x['val_f1_acc_ats'] for x in outputs]).mean() + + self._reset_valid_metrics() + + multi_val_metrics = { + 'val_loss': val_loss_mean, + 'val_ats_loss': val_ats_loss_mean, + 'val_pil_loss': val_pil_loss_mean, + 'val_f1_acc': val_f1_acc_mean, + 'val_precision': val_precision_mean, + 'val_recall': val_recall_mean, + 'val_f1_acc_ats': val_f1_acc_ats_mean, + } + return {'log': multi_val_metrics} + + def _get_aux_test_batch_evaluations(self, batch_idx: int, preds, targets, target_lens): + """ + Compute auxiliary validation evaluations including losses and metrics. + + This function calculates various losses and metrics for the training process, + including Arrival Time Sort (ATS) Loss and Permutation Invariant Loss (PIL) + based evaluations. + + Args: + preds (torch.Tensor): Predicted speaker labels. + Shape: (batch_size, diar_frame_count, num_speakers) + targets (torch.Tensor): Ground truth speaker labels. + Shape: (batch_size, diar_frame_count, num_speakers) + target_lens (torch.Tensor): Lengths of target sequences. + Shape: (batch_size,) + """ + targets_ats = get_ats_targets(targets.clone(), preds, speaker_permutations=self.speaker_permutations) + targets_pil = get_pil_targets(targets.clone(), preds, speaker_permutations=self.speaker_permutations) + self._accuracy_test(preds, targets_pil, target_lens) + f1_acc, precision, recall = self._accuracy_test.compute() + self.batch_f1_accs_list.append(f1_acc) + self.batch_precision_list.append(precision) + self.batch_recall_list.append(recall) + logging.info(f"batch {batch_idx}: f1_acc={f1_acc}, precision={precision}, recall={recall}") + + self._accuracy_test_ats(preds, targets_ats, target_lens) + f1_acc_ats, precision_ats, recall_ats = self._accuracy_test_ats.compute() + self.batch_f1_accs_ats_list.append(f1_acc_ats) + logging.info( + f"batch {batch_idx}: f1_acc_ats={f1_acc_ats}, precision_ats={precision_ats}, recall_ats={recall_ats}" + ) + + self._accuracy_test.reset() + self._accuracy_test_ats.reset() + + def test_batch( + self, + ): + """ + Perform batch testing on the model. + + This method iterates through the test data loader, making predictions for each batch, + and calculates various evaluation metrics. It handles both single and multi-sample batches. + """ + ( + self.preds_total_list, + self.batch_f1_accs_list, + self.batch_precision_list, + self.batch_recall_list, + self.batch_f1_accs_ats_list, + ) = ([], [], [], [], []) + + with torch.no_grad(): + for batch_idx, batch in enumerate(tqdm(self._test_dl)): + audio_signal, audio_signal_length, targets, target_lens = batch + audio_signal = audio_signal.to(self.device) + audio_signal_length = audio_signal_length.to(self.device) + preds = self.forward( + audio_signal=audio_signal, + audio_signal_length=audio_signal_length, + ) + preds = preds.detach().to('cpu') + if preds.shape[0] == 1: # batch size = 1 + self.preds_total_list.append(preds) + else: + self.preds_total_list.extend(torch.split(preds, [1] * preds.shape[0])) + torch.cuda.empty_cache() + self._get_aux_test_batch_evaluations(batch_idx, preds, targets, target_lens) + + logging.info(f"Batch F1Acc. MEAN: {torch.mean(torch.tensor(self.batch_f1_accs_list))}") + logging.info(f"Batch Precision MEAN: {torch.mean(torch.tensor(self.batch_precision_list))}") + logging.info(f"Batch Recall MEAN: {torch.mean(torch.tensor(self.batch_recall_list))}") + logging.info(f"Batch ATS F1Acc. MEAN: {torch.mean(torch.tensor(self.batch_f1_accs_ats_list))}") + + def diarize( + self, + ): + """One-clieck runner function for diarization.""" + # TODO: A direct one-click runner function that generates + # speaker labels from audio file path lists. + raise NotImplementedError diff --git a/nemo/collections/asr/models/ssl_models.py b/nemo/collections/asr/models/ssl_models.py index 633a00d73f5e..9150da7bf7c2 100644 --- a/nemo/collections/asr/models/ssl_models.py +++ b/nemo/collections/asr/models/ssl_models.py @@ -17,8 +17,8 @@ import torch import torch.nn as nn +from lightning.pytorch import Trainer from omegaconf import DictConfig -from pytorch_lightning import Trainer from nemo.collections.asr.data import audio_to_text_dataset, ssl_dataset from nemo.collections.asr.data.audio_to_text_dali import DALIOutputs diff --git a/nemo/collections/asr/models/transformer_bpe_models.py b/nemo/collections/asr/models/transformer_bpe_models.py index 089186e142bf..4692cb662b4b 100644 --- a/nemo/collections/asr/models/transformer_bpe_models.py +++ b/nemo/collections/asr/models/transformer_bpe_models.py @@ -22,8 +22,8 @@ import editdistance import torch import torch.distributed as dist +from lightning.pytorch import Trainer from omegaconf import DictConfig, OmegaConf, open_dict -from pytorch_lightning import Trainer from torch.utils.data import DataLoader from torchmetrics.text import SacreBLEUScore from tqdm.auto import tqdm @@ -225,10 +225,14 @@ def _setup_dataloader_from_config(self, config: Optional[Dict]): config = self._update_default_values(config) return get_lhotse_dataloader_from_config( config, - global_rank=self.global_rank, - world_size=self.world_size, + # During transcription, the model is initially loaded on the CPU. + # To ensure the correct global_rank and world_size are set, + # these values must be passed from the configuration. + global_rank=self.global_rank if not config.get("do_transcribe", False) else config.get("global_rank"), + world_size=self.world_size if not config.get("do_transcribe", False) else config.get("world_size"), dataset=LhotseSpeechToTextBpeDataset( tokenizer=self.tokenizer, + return_cuts=config.get("do_transcribe", False), ), tokenizer=self.tokenizer, ) diff --git a/nemo/collections/asr/modules/sortformer_modules.py b/nemo/collections/asr/modules/sortformer_modules.py new file mode 100644 index 000000000000..d99bf3b93e38 --- /dev/null +++ b/nemo/collections/asr/modules/sortformer_modules.py @@ -0,0 +1,108 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from nemo.core.classes.exportable import Exportable +from nemo.core.classes.module import NeuralModule + +__all__ = ['SortformerModules'] + + +class SortformerModules(NeuralModule, Exportable): + """ + A class including auxiliary functions for Sortformer models. + This class contains and will contain the following functions that performs streaming features, + and any neural layers that are not included in the NeMo neural modules (e.g. Transformer, Fast-Conformer). + """ + + def init_weights(self, m): + """Init weights for linear layers.""" + if type(m) == nn.Linear: + torch.nn.init.xavier_uniform_(m.weight) + m.bias.data.fill_(0.01) + + def __init__( + self, + num_spks: int = 4, + hidden_size: int = 192, + dropout_rate: float = 0.5, + fc_d_model: int = 512, + tf_d_model: int = 192, + ): + """ + Args: + num_spks (int): + Max number of speakers that are processed by the model. + hidden_size (int): + Number of hidden units in sequence models and intermediate layers. + dropout_rate (float): + Dropout rate for linear layers, CNN and LSTM. + fc_d_model (int): + Dimension of the embedding vectors. + tf_d_model (int): + Dimension of the embedding vectors. + """ + super().__init__() + self.fc_d_model = fc_d_model + self.tf_d_model = tf_d_model + self.hidden_size = tf_d_model + self.unit_n_spks: int = num_spks + self.hidden_to_spks = nn.Linear(2 * self.hidden_size, self.unit_n_spks) + self.first_hidden_to_hidden = nn.Linear(self.hidden_size, self.hidden_size) + self.single_hidden_to_spks = nn.Linear(self.hidden_size, self.unit_n_spks) + self.dropout = nn.Dropout(dropout_rate) + self.encoder_proj = nn.Linear(self.fc_d_model, self.tf_d_model) + + def length_to_mask(self, context_embs): + """ + Convert length values to encoder mask input tensor. + + Args: + lengths (torch.Tensor): tensor containing lengths of sequences + max_len (int): maximum sequence length + + Returns: + mask (torch.Tensor): tensor of shape (batch_size, max_len) containing 0's + in the padded region and 1's elsewhere + """ + lengths = torch.tensor([context_embs.shape[1]] * context_embs.shape[0]) + batch_size = context_embs.shape[0] + max_len = context_embs.shape[1] + # create a tensor with the shape (batch_size, 1) filled with ones + row_vector = torch.arange(max_len).unsqueeze(0).expand(batch_size, -1).to(lengths.device) + # create a tensor with the shape (batch_size, max_len) filled with lengths + length_matrix = lengths.unsqueeze(1).expand(-1, max_len).to(lengths.device) + # create a mask by comparing the row vector and length matrix + mask = row_vector < length_matrix + return mask.float().to(context_embs.device) + + def forward_speaker_sigmoids(self, hidden_out): + """ + A set of layers for predicting speaker probabilities with a sigmoid activation function. + + Args: + hidden_out (torch.Tensor): tensor of shape (batch_size, seq_len, hidden_size) + + Returns: + preds (torch.Tensor): tensor of shape (batch_size, seq_len, num_spks) containing speaker probabilities + """ + hidden_out = self.dropout(F.relu(hidden_out)) + hidden_out = self.first_hidden_to_hidden(hidden_out) + hidden_out = self.dropout(F.relu(hidden_out)) + spk_preds = self.single_hidden_to_spks(hidden_out) + preds = nn.Sigmoid()(spk_preds) + return preds diff --git a/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt_kernel.py b/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt_kernel.py index 4153af060941..219e9d0453b2 100644 --- a/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt_kernel.py +++ b/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt_kernel.py @@ -961,7 +961,9 @@ def compute_tdt_alphas_kernel( if t > 0 and t < T: alphas[offset + t * maxU + u] = -INF - for i in range(1, num_durations): # skip 0 since blank emission has to advance by at least one + for i in range(num_durations): + if durations[i] == 0: # skip 0 since blank emission has to advance by at least one + continue if t >= durations[i]: alphas[offset + t * maxU + u] = rnnt_helper.log_sum_exp( alphas[offset + t * maxU + u], # the current alpha value @@ -981,21 +983,26 @@ def compute_tdt_alphas_kernel( elif u < U: # when t == 0, we only consider the non-blank emission. if t == 0: - alphas[offset + u] = ( - alphas[offset + u - 1] # alpha(t, u - 1) - + logp( - denom, acts, maxT, maxU, alphabet_size, b, t, u - 1, labels[u - 1] - ) # logp of token emission - - sigma # logit under-normalization - + logp_duration( - duration_acts, maxT, maxU, num_durations, b, t, u - 1, 0 - ) # t = 0, so it must be duration = 0. Therefore the last argument passed to logp_duration() is 0. - ) + if durations[0] == 0: + alphas[offset + u] = ( + alphas[offset + u - 1] # alpha(t, u - 1) + + logp( + denom, acts, maxT, maxU, alphabet_size, b, t, u - 1, labels[u - 1] + ) # logp of token emission + - sigma # logit under-normalization + + logp_duration( + duration_acts, maxT, maxU, num_durations, b, t, u - 1, 0 + ) # t = 0, so it must be duration = 0. Therefore the last argument passed to logp_duration() is 0. + ) + else: + alphas[offset + u] = -INF # now we have t != 0 and u != 0, and we need to consider both non-blank and blank emissions. elif t > 0 and t < T: no_emit = -INF # no_emit stores the score for all blank emissions. - for i in range(1, num_durations): + for i in range(num_durations): + if durations[i] == 0: + continue if t >= durations[i]: no_emit = rnnt_helper.log_sum_exp( no_emit, # current score @@ -1012,7 +1019,7 @@ def compute_tdt_alphas_kernel( break # we can exit the loop early here, same as the case for u == 0 above. emit = -INF # emit stores the score for non-blank emissions. - for i in range(0, num_durations): + for i in range(num_durations): if t >= durations[i]: emit = rnnt_helper.log_sum_exp( emit, # current score @@ -1037,16 +1044,21 @@ def compute_tdt_alphas_kernel( # After final sync, the forward log-likelihood can be computed as the summataion of # alpha(T - duration, U - 1) + logp(blank, duration | t - duration, U - 1), over different durations. if u == 0: - # first we consider duration = 1 - loglike = ( - alphas[offset + (T - 1) * maxU + U - 1] - + logp(denom, acts, maxT, maxU, alphabet_size, b, T - 1, U - 1, blank_) - - sigma - + logp_duration(duration_acts, maxT, maxU, num_durations, b, T - 1, U - 1, 1) - ) + # initialize with negative infinite and start add terms later + loglike = -INF # then we add the scores for duration > 1, if such durations are possible given the audio lengths. - for i in range(2, num_durations): + for i in range(num_durations): + if durations[i] == 0: + continue + if durations[i] == 1: + loglike = ( + alphas[offset + (T - 1) * maxU + U - 1] + + logp(denom, acts, maxT, maxU, alphabet_size, b, T - 1, U - 1, blank_) + - sigma + + logp_duration(duration_acts, maxT, maxU, num_durations, b, T - 1, U - 1, i) + ) + continue if T >= durations[i]: big_blank_loglike = ( alphas[offset + (T - durations[i]) * maxU + U - 1] @@ -1122,11 +1134,18 @@ def compute_tdt_betas_kernel( # Initilize beta[b, t=T-1, u=U-1] for all b in B with log_probs[b, t=T-1, u=U-1, blank] if u == 0: - betas[offset + (T - 1) * maxU + U - 1] = ( - logp(denom, acts, maxT, maxU, alphabet_size, b, T - 1, U - 1, blank_) - - sigma - + logp_duration(duration_acts, maxT, maxU, num_durations, b, T - 1, U - 1, 1) - ) + if durations[0] == 1: + betas[offset + (T - 1) * maxU + U - 1] = ( + logp(denom, acts, maxT, maxU, alphabet_size, b, T - 1, U - 1, blank_) + - sigma + + logp_duration(duration_acts, maxT, maxU, num_durations, b, T - 1, U - 1, 0) + ) + elif durations[1] == 1: + betas[offset + (T - 1) * maxU + U - 1] = ( + logp(denom, acts, maxT, maxU, alphabet_size, b, T - 1, U - 1, blank_) + - sigma + + logp_duration(duration_acts, maxT, maxU, num_durations, b, T - 1, U - 1, 1) + ) # sync until all betas are initialized cuda.syncthreads() @@ -1140,11 +1159,12 @@ def compute_tdt_betas_kernel( # u == U - 1, we only consider blank emissions. if t >= 0 and t + 1 < T: betas[offset + t * maxU + U - 1] = -INF - for i in range(1, num_durations): + for i in range(num_durations): # although similar, the computation for beta's is slightly more complex for boundary cases. # the following two cases correspond to whether t is exactly certain duration away from T. # and they have slightly different update rules. - + if durations[i] == 0: + continue if t + durations[i] < T: betas[offset + t * maxU + U - 1] = rnnt_helper.log_sum_exp( betas[offset + t * maxU + U - 1], @@ -1172,17 +1192,24 @@ def compute_tdt_betas_kernel( elif u < U - 1: if t == T - 1: # t == T - 1, so we only consider non-blank with duration 0. (Note, we can't have blank emissions with duration = 0) - betas[offset + (T - 1) * maxU + u] = ( - betas[offset + (T - 1) * maxU + u + 1] - + logp(denom, acts, maxT, maxU, alphabet_size, b, T - 1, u, labels[u]) # non-blank log prob - + logp_duration(duration_acts, maxT, maxU, num_durations, b, T - 1, u, 0) # log prob of duration 0 - - sigma - ) + if durations[0] == 0: + betas[offset + (T - 1) * maxU + u] = ( + betas[offset + (T - 1) * maxU + u + 1] + + logp(denom, acts, maxT, maxU, alphabet_size, b, T - 1, u, labels[u]) # non-blank log prob + + logp_duration( + duration_acts, maxT, maxU, num_durations, b, T - 1, u, 0 + ) # log prob of duration 0 + - sigma + ) + else: + betas[offset + (T - 1) * maxU + u] = -INF elif t >= 0 and t < T - 1: # now we need to consider both blank andnon-blanks. Similar to alphas, we first compute them separately with no_emit and emit. no_emit = -INF - for i in range(1, num_durations): + for i in range(num_durations): + if durations[i] == 0: + continue if t + durations[i] < T: no_emit = rnnt_helper.log_sum_exp( no_emit, @@ -1193,7 +1220,7 @@ def compute_tdt_betas_kernel( ) emit = -INF - for i in range(0, num_durations): + for i in range(num_durations): if t + durations[i] < T: emit = rnnt_helper.log_sum_exp( emit, @@ -1304,10 +1331,10 @@ def compute_tdt_grad_kernel( logpk_label = denom[col] + acts[col * alphabet_size + labels[u]] - sigma grad -= math.exp(alphas[col] + betas[col + 1 + durations[idx] * maxU] + logpk_label - logll[mb]) - if t + durations[idx] < T and idx > 0: # for blank in the middle + if t + durations[idx] < T and durations[idx] > 0: # for blank in the middle grad -= math.exp(alphas[col] + betas[col + durations[idx] * maxU] + logpk_blank - logll[mb]) - if t + durations[idx] == T and idx >= 1 and u == U - 1: # for blank as the last symbol + if t + durations[idx] == T and u == U - 1 and durations[idx] > 0: # for blank as the last symbol grad -= math.exp(alphas[col] + logpk_blank - logll[mb]) grad = grad * math.exp(duration_acts[col * num_durations + idx]) @@ -1335,7 +1362,7 @@ def compute_tdt_grad_kernel( if fastemit_lambda > 0.0 and u < U - 1: fastemit_grad = 0.0 - for i in range(0, num_durations): + for i in range(num_durations): if t + durations[i] < T: fastemit_grad += fastemit_lambda * math.exp( alphas[col] # alphas(t, u) @@ -1355,7 +1382,9 @@ def compute_tdt_grad_kernel( # grad to last blank transition # grad[b, T-1, U-1, v=blank] -= exp(alphas[b, t, u] + logpk - sigma - logll[b] + logp(duration) for all possible non-zero durations. if idx == blank_ and u == U - 1: - for i in range(1, num_durations): + for i in range(num_durations): + if durations[i] == 0: + continue if t == T - durations[i]: grad -= math.exp( alphas[col] + logpk - sigma - logll[mb] + duration_acts[col * num_durations + i] @@ -1364,7 +1393,9 @@ def compute_tdt_grad_kernel( # grad of blank across t < T; # grad[b, t torch.Tensor: + """ + Finds the first nonzero value in the matrix, discretizing it to the specified maximum capacity. + + Args: + mat (Tensor): A torch tensor representing the matrix. + max_cap_val (int): The maximum capacity to which the matrix values will be discretized. + thres (float): The threshold value for discretizing the matrix values. + + Returns: + mask_max_indices (Tensor): A torch tensor representing the discretized matrix with the first + nonzero value in each row. + """ + # Discretize the matrix to the specified maximum capacity + labels_discrete = mat.clone() + labels_discrete[labels_discrete < thres] = 0 + labels_discrete[labels_discrete >= thres] = 1 + + # non zero values mask + non_zero_mask = labels_discrete != 0 + # operations on the mask to find first nonzero values in the rows + mask_max_values, mask_max_indices = torch.max(non_zero_mask, dim=1) + # if the max-mask is zero, there is no nonzero value in the row + mask_max_indices[mask_max_values == 0] = max_cap_val + return mask_max_indices + + +def find_best_permutation(match_score: torch.Tensor, speaker_permutations: torch.Tensor) -> torch.Tensor: + """ + Finds the best permutation indices based on the match score. + + Args: + match_score (torch.Tensor): A tensor containing the match scores for each permutation. + Shape: (batch_size, num_permutations) + speaker_permutations (torch.Tensor): A tensor containing all possible speaker permutations. + Shape: (num_permutations, num_speakers) + + Returns: + torch.Tensor: A tensor containing the best permutation indices for each batch. + Shape: (batch_size, num_speakers) + """ + batch_best_perm = torch.argmax(match_score, axis=1) + rep_speaker_permutations = speaker_permutations.repeat(batch_best_perm.shape[0], 1).to(match_score.device) + perm_size = speaker_permutations.shape[0] + global_inds_vec = ( + torch.arange(0, perm_size * batch_best_perm.shape[0], perm_size).to(batch_best_perm.device) + batch_best_perm + ) + return rep_speaker_permutations[global_inds_vec.to(rep_speaker_permutations.device), :] + + +def reconstruct_labels(labels: torch.Tensor, batch_perm_inds: torch.Tensor) -> torch.Tensor: + """ + Reconstructs the labels using the best permutation indices with matrix operations. + + Args: + labels (torch.Tensor): A tensor containing the original labels. + Shape: (batch_size, num_frames, num_speakers) + batch_perm_inds (torch.Tensor): A tensor containing the best permutation indices for each batch. + Shape: (batch_size, num_speakers) + + Returns: + torch.Tensor: A tensor containing the reconstructed labels using the best permutation indices. + Shape: (batch_size, num_frames, num_speakers) + """ + # Expanding batch_perm_inds to align with labels dimensions + batch_size, num_frames, num_speakers = labels.shape + batch_perm_inds_exp = batch_perm_inds.unsqueeze(1).expand(-1, num_frames, -1) + + # Reconstructing the labels using advanced indexing + reconstructed_labels = torch.gather(labels, 2, batch_perm_inds_exp) + return reconstructed_labels + + +def get_ats_targets( + labels: torch.Tensor, + preds: torch.Tensor, + speaker_permutations: torch.Tensor, + thres: float = 0.5, + tolerance: float = 0, +) -> torch.Tensor: + """ + Sorts labels and predictions to get the optimal of all arrival-time ordered permutations. + + Args: + labels (torch.Tensor): A tensor containing the original labels. + Shape: (batch_size, num_frames, num_speakers) + preds (torch.Tensor): A tensor containing the predictions. + Shape: (batch_size, num_frames, num_speakers) + speaker_permutations (torch.Tensor): A tensor containing all possible speaker permutations. + Shape: (num_permutations, num_speakers) + thres (float): The threshold value for discretizing the matrix values. Default is 0.5. + tolerance (float): The tolerance for comparing the first speech frame indices. Default is 0. + + Returns: + torch.Tensor: A tensor containing the reconstructed labels using the best permutation indices. + Shape: (batch_size, num_frames, num_speakers) + """ + # Find the first nonzero frame index for each speaker in each batch + nonzero_ind = find_first_nonzero( + mat=labels, max_cap_val=labels.shape[1], thres=thres + ) # (batch_size, num_speakers) + + # Sort the first nonzero frame indices for arrival-time ordering + sorted_values = torch.sort(nonzero_ind)[0] # (batch_size, num_speakers) + perm_size = speaker_permutations.shape[0] # Scalar value (num_permutations) + permed_labels = labels[:, :, speaker_permutations] # (batch_size, num_frames, num_permutations, num_speakers) + permed_nonzero_ind = find_first_nonzero( + mat=permed_labels, max_cap_val=labels.shape[1] + ) # (batch_size, num_permutations, num_speakers) + + # Compare the first frame indices of sorted labels with those of the permuted labels using tolerance + perm_compare = ( + torch.abs(sorted_values.unsqueeze(1) - permed_nonzero_ind) <= tolerance + ) # (batch_size, num_permutations, num_speakers) + perm_mask = torch.all(perm_compare, dim=2).float() # (batch_size, num_permutations) + preds_rep = torch.unsqueeze(preds, 2).repeat( + 1, 1, perm_size, 1 + ) # Exapnd the preds: (batch_size, num_frames, num_permutations, num_speakers) + + # Compute the match score for each permutation by comparing permuted labels with preds + match_score = ( + torch.sum(permed_labels * preds_rep, axis=1).sum(axis=2) * perm_mask + ) # (batch_size, num_permutations) + batch_perm_inds = find_best_permutation(match_score, speaker_permutations) # (batch_size, num_speakers) + max_score_permed_labels = reconstruct_labels(labels, batch_perm_inds) # (batch_size, num_frames, num_speakers) + return max_score_permed_labels # (batch_size, num_frames, num_speakers) + + +def get_pil_targets(labels: torch.Tensor, preds: torch.Tensor, speaker_permutations: torch.Tensor) -> torch.Tensor: + """ + Sorts labels and predictions to get the optimal permutation based on the match score. + + Args: + labels (torch.Tensor): A tensor containing the ground truth labels. + Shape: (batch_size, num_speakers, num_classes) + preds (torch.Tensor): A tensor containing the predicted values. + Shape: (batch_size, num_speakers, num_classes) + speaker_permutations (torch.Tensor): A tensor containing all possible speaker permutations. + Shape: (num_permutations, num_speakers) + + Returns: + torch.Tensor: A tensor of permuted labels that best match the predictions. + Shape: (batch_size, num_speakers, num_classes) + """ + permed_labels = labels[:, :, speaker_permutations] # (batch_size, num_classes, num_permutations, num_speakers) + # Repeat preds to match permutations for comparison + preds_rep = torch.unsqueeze(preds, 2).repeat( + 1, 1, speaker_permutations.shape[0], 1 + ) # (batch_size, num_speakers, num_permutations, num_classes) + match_score = torch.sum(permed_labels * preds_rep, axis=1).sum(axis=2) # (batch_size, num_permutations) + batch_perm_inds = find_best_permutation(match_score, speaker_permutations) # (batch_size, num_speakers) + # Reconstruct labels based on the best permutation for each batch + max_score_permed_labels = reconstruct_labels(labels, batch_perm_inds) # (batch_size, num_speakers, num_classes) + return max_score_permed_labels # (batch_size, num_speakers, num_classes) + + +def find_segments_from_rttm( + recording_id: str, + rttms: SupervisionSet, + start_after: float, + end_before: float, + adjust_offset: bool = True, + tolerance: float = 0.001, +): + """ + Finds segments from the given rttm file. + This function is designed to replace rttm + + Args: + recording_id (str): The recording ID in string format. + rttms (SupervisionSet): The SupervisionSet instance. + start_after (float): The start time after which segments are selected. + end_before (float): The end time before which segments are selected. + adjust_offset (bool): Whether to adjust the offset of the segments. + tolerance (float): The tolerance for time matching. 0.001 by default. + + Returns: + segments (List[SupervisionSegment]): A list of SupervisionSegment instances. + """ + segment_by_recording_id = rttms._segments_by_recording_id + if segment_by_recording_id is None: + from cytoolz import groupby + + segment_by_recording_id = groupby(lambda seg: seg.recording_id, rttms) + + return [ + # We only modify the offset - the duration remains the same, as we're only shifting the segment + # relative to the Cut's start, and not truncating anything. + segment.with_offset(-start_after) if adjust_offset else segment + for segment in segment_by_recording_id.get(recording_id, []) + if segment.start < end_before + tolerance and segment.end > start_after + tolerance + ] + + +def get_mask_from_segments( + segments: list, + a_cut: Optional[Union[MonoCut, MixedCut]], + speaker_to_idx_map: torch.Tensor, + num_speakers: int = 4, + feat_per_sec: int = 100, + ignore_num_spk_mismatch: bool = False, +): + """ + Generate mask matrix from segments list. + This function is needed for speaker diarization with ASR model trainings. + + Args: + segments: A list of Lhotse Supervision segments iterator. + cut (MonoCut, MixedCut): Lhotse MonoCut or MixedCut instance. + speaker_to_idx_map (dict): A dictionary mapping speaker names to indices. + num_speakers (int): max number of speakers for all cuts ("mask" dim0), 4 by default + feat_per_sec (int): number of frames per second, 100 by default, 0.01s frame rate + ignore_num_spk_mismatch (bool): This is a temporary solution to handle speaker mismatch. + Will be removed in the future. + + Returns: + mask (Tensor): A numpy array of shape (num_speakers, encoder_hidden_len). + Dimension: (num_speakers, num_frames) + """ + # get targets with 0.01s frame rate + num_samples = round(a_cut.duration * feat_per_sec) + mask = torch.zeros((num_samples, num_speakers)) + for rttm_sup in segments: + speaker_idx = speaker_to_idx_map[rttm_sup.speaker] + if speaker_idx >= num_speakers: + if ignore_num_spk_mismatch: + continue + else: + raise ValueError(f"Speaker Index {speaker_idx} exceeds the max index: {num_speakers-1}") + stt = max(rttm_sup.start, 0) + ent = min(rttm_sup.end, a_cut.duration) + stf = int(stt * feat_per_sec) + enf = int(ent * feat_per_sec) + mask[stf:enf, speaker_idx] = 1.0 + return mask + + +def get_soft_mask(feat_level_target, num_frames, stride): + """ + Get soft mask from feat_level_target with stride. + This function is needed for speaker diarization with ASR model trainings. + + Args: + feat_level_target (Tensor): A numpy array of shape (num_frames, num_speakers). + Dimension: (num_frames, num_speakers) + num_sample (int): The total number of samples. + stride (int): The stride for the mask. + + Returns: + mask: The soft mask of shape (num_frames, num_speakers). + Dimension: (num_frames, num_speakers) + """ + + num_speakers = feat_level_target.shape[1] + mask = torch.zeros(num_frames, num_speakers) + + for index in range(num_frames): + if index == 0: + seg_stt_feat = 0 + else: + seg_stt_feat = stride * index - 1 - int(stride / 2) + if index == num_frames - 1: + seg_end_feat = feat_level_target.shape[0] + else: + seg_end_feat = stride * index - 1 + int(stride / 2) + mask[index] = torch.mean(feat_level_target[seg_stt_feat : seg_end_feat + 1, :], axis=0) + return mask + + +def get_hidden_length_from_sample_length( + num_samples: int, num_sample_per_mel_frame: int = 160, num_mel_frame_per_asr_frame: int = 8 +) -> int: + """ + Calculate the hidden length from the given number of samples. + This function is needed for speaker diarization with ASR model trainings. + + This function computes the number of frames required for a given number of audio samples, + considering the number of samples per mel frame and the number of mel frames per ASR frame. + + Parameters: + num_samples (int): The total number of audio samples. + num_sample_per_mel_frame (int, optional): The number of samples per mel frame. Default is 160. + num_mel_frame_per_asr_frame (int, optional): The number of mel frames per ASR frame. Default is 8. + + Returns: + hidden_length (int): The calculated hidden length in terms of the number of frames. + """ + mel_frame_count = math.ceil((num_samples + 1) / num_sample_per_mel_frame) + hidden_length = math.ceil(mel_frame_count / num_mel_frame_per_asr_frame) + return int(hidden_length) + + +def speaker_to_target( + a_cut, + num_speakers: int = 4, + num_sample_per_mel_frame: int = 160, + num_mel_frame_per_asr_frame: int = 8, + spk_tar_all_zero: bool = False, + boundary_segments: bool = False, + soft_label: bool = False, + ignore_num_spk_mismatch: bool = True, + soft_thres: float = 0.5, +): + """ + Get rttm samples corresponding to one cut, generate speaker mask numpy.ndarray with shape + (num_speaker, hidden_length). This function is needed for speaker diarization with ASR model trainings. + + Args: + a_cut (MonoCut, MixedCut): + Lhotse Cut instance which is MonoCut or MixedCut instance. + num_speakers (int): + Max number of speakers for all cuts ("mask" dim0), 4 by default + num_sample_per_mel_frame (int): + Number of sample per mel frame, sample_rate / 1000 * window_stride, 160 by default (10ms window stride) + num_mel_frame_per_asr_frame (int): + Encoder subsampling_factor, 8 by default + spk_tar_all_zero (Tensor): + Set to True gives all zero "mask" + boundary_segments (bool): + Set to True to include segments containing the boundary of the cut, + False by default for multi-speaker ASR training + soft_label (bool): + Set to True to use soft label that enables values in [0, 1] range, + False by default and leads to binary labels. + ignore_num_spk_mismatch (bool): + This is a temporary solution to handle speaker mismatch. Will be removed in the future. + + Returns: + mask (Tensor): Speaker mask with shape (num_speaker, hidden_lenght) + """ + # get cut-related segments from rttms + if isinstance(a_cut, MixedCut): + cut_list = [track.cut for track in a_cut.tracks if isinstance(track.cut, MonoCut)] + offsets = [track.offset for track in a_cut.tracks if isinstance(track.cut, MonoCut)] + elif isinstance(a_cut, MonoCut): + cut_list = [a_cut] + offsets = [0] + else: + raise ValueError(f"Unsupported cut type type{a_cut}: only MixedCut and MonoCut are supported") + + segments_total = [] + for i, cut in enumerate(cut_list): + rttms = SupervisionSet.from_rttm(cut.rttm_filepath) + if boundary_segments: # segments with seg_start < total_end and seg_end > total_start are included + segments_iterator = find_segments_from_rttm( + recording_id=cut.recording_id, rttms=rttms, start_after=cut.start, end_before=cut.end, tolerance=0.0 + ) + else: # segments with seg_start > total_start and seg_end < total_end are included + segments_iterator = rttms.find( + recording_id=cut.recording_id, start_after=cut.start, end_before=cut.end, adjust_offset=True + ) + + for seg in segments_iterator: + if seg.start < 0: + seg.duration += seg.start + seg.start = 0 + if seg.end > cut.duration: + seg.duration -= seg.end - cut.duration + seg.start += offsets[i] + segments_total.append(seg) + + # apply arrival time sorting to the existing segments + segments_total.sort(key=lambda rttm_sup: rttm_sup.start) + + seen = set() + seen_add = seen.add + speaker_ats = [s.speaker for s in segments_total if not (s.speaker in seen or seen_add(s.speaker))] + + speaker_to_idx_map = {spk: idx for idx, spk in enumerate(speaker_ats)} + if len(speaker_to_idx_map) > num_speakers and not ignore_num_spk_mismatch: # raise error if number of speakers + raise ValueError( + f"Number of speakers {len(speaker_to_idx_map)} is larger than " + f"the maximum number of speakers {num_speakers}" + ) + + # initialize mask matrices (num_speaker, encoder_hidden_len) + feat_per_sec = int(a_cut.sampling_rate / num_sample_per_mel_frame) # 100 by default + num_samples = get_hidden_length_from_sample_length( + a_cut.num_samples, num_sample_per_mel_frame, num_mel_frame_per_asr_frame + ) + if spk_tar_all_zero: + frame_mask = torch.zeros((num_samples, num_speakers)) + else: + frame_mask = get_mask_from_segments( + segments_total, a_cut, speaker_to_idx_map, num_speakers, feat_per_sec, ignore_num_spk_mismatch + ) + soft_mask = get_soft_mask(frame_mask, num_samples, num_mel_frame_per_asr_frame) + + if soft_label: + mask = soft_mask + else: + mask = (soft_mask > soft_thres).float() + + return mask diff --git a/nemo/collections/asr/parts/utils/manifest_utils.py b/nemo/collections/asr/parts/utils/manifest_utils.py index e9f91045c9a2..418f95832f48 100644 --- a/nemo/collections/asr/parts/utils/manifest_utils.py +++ b/nemo/collections/asr/parts/utils/manifest_utils.py @@ -24,7 +24,7 @@ from nemo.collections.asr.parts.utils.speaker_utils import ( audio_rttm_map, - get_subsegments, + get_subsegments_scriptable, get_uniqname_from_filepath, rttm_to_labels, segments_manifest_to_subsegments_manifest, @@ -66,13 +66,15 @@ def get_ctm_line( output_precision: int = 2, ) -> str: """ - Get a line in Conversation Time Mark (CTM) format. Following CTM format appeared in `Rich Transcription Meeting Eval Plan: RT09` document. - - CTM Format: + Get a line in Conversation Time Mark (CTM) format. Following CTM format appeared in + `Rich Transcription Meeting Eval Plan: RT09` document. + + CTM Format: - - Reference: - https://web.archive.org/web/20170119114252/http://www.itl.nist.gov/iad/mig/tests/rt/2009/docs/rt09-meeting-eval-plan-v2.pdf + + Reference: + https://web.archive.org/web/20170119114252/ + http://www.itl.nist.gov/iad/mig/tests/rt/2009/docs/rt09-meeting-eval-plan-v2.pdf Args: source (str): is name of the source file, session name or utterance ID @@ -80,11 +82,14 @@ def get_ctm_line( start_time (float): is the begin time of the word, which we refer to as `start_time` in NeMo. duration (float): is duration of the word token (str): Token or word for the current entry - conf (float): is a floating point number between 0 (no confidence) and 1 (certainty). A value of “NA” is used (in CTM format data) - when no confidence is computed and in the reference data. - type_of_token (str): is the token type. The legal values of are “lex”, “frag”, “fp”, “un-lex”, “for-lex”, “non-lex”, “misc”, or “noscore” - speaker (str): is a string identifier for the speaker who uttered the token. This should be “null” for non-speech tokens and “unknown” when - the speaker has not been determined. + conf (float): is a floating point number between 0 (no confidence) and 1 (certainty). + A value of “NA” is used (in CTM format data) + when no confidence is computed and in the reference data. + type_of_token (str): is the token type. The legal values of are + “lex”, “frag”, “fp”, “un-lex”, “for-lex”, “non-lex”, “misc”, or “noscore” + speaker (str): is a string identifier for the speaker who uttered the token. + This should be “null” for non-speech tokens and “unknown” when + the speaker has not been determined. NA_token (str, optional): A token for . Defaults to ''. output_precision (int, optional): The precision of the output floating point number. Defaults to 3. @@ -179,7 +184,7 @@ def get_subsegment_dict(subsegments_manifest_file: str, window: float, shift: fl segment = segment.strip() dic = json.loads(segment) audio, offset, duration, label = dic['audio_filepath'], dic['offset'], dic['duration'], dic['label'] - subsegments = get_subsegments(offset=offset, window=window, shift=shift, duration=duration) + subsegments = get_subsegments_scriptable(offset=offset, window=window, shift=shift, duration=duration) if dic['uniq_id'] is not None: uniq_id = dic['uniq_id'] else: @@ -368,7 +373,11 @@ def create_segment_manifest( segments_manifest_file = write_rttm2manifest(AUDIO_RTTM_MAP, segment_manifest_path, deci) subsegments_manifest_file = subsegment_manifest_path segments_manifest_to_subsegments_manifest( - segments_manifest_file, subsegments_manifest_file, window, shift, min_subsegment_duration, + segments_manifest_file, + subsegments_manifest_file, + window, + shift, + min_subsegment_duration, ) subsegments_dict = get_subsegment_dict(subsegments_manifest_file, window, shift, deci) write_truncated_subsegments(input_manifest_dict, subsegments_dict, output_manifest_path, step_count, deci) @@ -505,7 +514,9 @@ def write_manifest(output_path: Union[Path, str], target_manifest: List[dict], e Args: output_path (str or Path): Path to output manifest file target_manifest (list): List of manifest file entries - ensure_ascii (bool): default is True, meaning the output is guaranteed to have all incoming non-ASCII characters escaped. If ensure_ascii is false, these characters will be output as-is. + ensure_ascii (bool): default is True, meaning the output is guaranteed to have all incoming + non-ASCII characters escaped. If ensure_ascii is false, these characters + will be output as-is. """ with open(output_path, "w", encoding="utf-8") as outfile: for tgt in target_manifest: diff --git a/nemo/collections/asr/parts/utils/speaker_utils.py b/nemo/collections/asr/parts/utils/speaker_utils.py index 5d3a0bf4274e..223916e60a76 100644 --- a/nemo/collections/asr/parts/utils/speaker_utils.py +++ b/nemo/collections/asr/parts/utils/speaker_utils.py @@ -21,21 +21,17 @@ from typing import Dict, List, Tuple, Union import numpy as np -import omegaconf import soundfile as sf import torch -from pyannote.core import Annotation, Segment +from omegaconf.listconfig import ListConfig +from pyannote.core import Annotation, Segment, Timeline from tqdm import tqdm from nemo.collections.asr.data.audio_to_label import repeat_signal from nemo.collections.asr.parts.utils.longform_clustering import LongFormSpeakerClustering -from nemo.collections.asr.parts.utils.offline_clustering import SpeakerClustering, get_argmin_mat, split_input_data +from nemo.collections.asr.parts.utils.offline_clustering import get_argmin_mat, split_input_data from nemo.utils import logging -""" -This file contains all the utility functions required for speaker embeddings part in diarization scripts -""" - def get_uniqname_from_filepath(filepath): """ @@ -81,10 +77,13 @@ def audio_rttm_map(manifest, attach_dur=False): """ This function creates AUDIO_RTTM_MAP which is used by all diarization components to extract embeddings, cluster and unify time stamps - Args: manifest file that contains keys audio_filepath, rttm_filepath if exists, text, num_speakers if known and uem_filepath if exists - returns: - AUDIO_RTTM_MAP (dict) : A dictionary with keys of uniq id, which is being used to map audio files and corresponding rttm files + Args: + manifest (str): Path to the manifest file + attach_dur (bool, optional): If True, attach duration information to the unique name. Defaults to False. + + Returns: + AUDIO_RTTM_MAP (dict) : Dictionary with unique names as keys and corresponding metadata as values. """ AUDIO_RTTM_MAP = {} @@ -108,15 +107,17 @@ def audio_rttm_map(manifest, attach_dur=False): if attach_dur: uniqname = get_uniq_id_with_dur(meta) else: - uniqname = get_uniqname_from_filepath(filepath=meta['audio_filepath']) + if "uniq_id" in dic.keys(): + uniqname = dic['uniq_id'] + else: + uniqname = get_uniqname_from_filepath(filepath=meta['audio_filepath']) if uniqname not in AUDIO_RTTM_MAP: AUDIO_RTTM_MAP[uniqname] = meta else: raise KeyError( - "file {} is already part of AUDIO_RTTM_MAP, it might be duplicated, Note: file basename must be unique".format( - meta['audio_filepath'] - ) + f"file {meta['audio_filepath']} is already part of AUDIO_RTTM_MAP, it might be duplicated, " + "Note: file basename must be unique" ) return AUDIO_RTTM_MAP @@ -144,7 +145,7 @@ def parse_scale_configs(window_lengths_in_sec, shift_lengths_in_sec, multiscale_ """ check_float_config = [isinstance(var, float) for var in (window_lengths_in_sec, shift_lengths_in_sec)] check_list_config = [ - isinstance(var, (omegaconf.listconfig.ListConfig, list, tuple)) + isinstance(var, (ListConfig, list, tuple)) for var in (window_lengths_in_sec, shift_lengths_in_sec, multiscale_weights) ] if all(check_list_config) or all(check_float_config): @@ -247,7 +248,8 @@ def get_embs_and_timestamps(multiscale_embeddings_and_timestamps, multiscale_arg def get_timestamps(multiscale_timestamps, multiscale_args_dict): """ The timestamps in `multiscale_timestamps` dictionary are indexed by scale index. - This function rearranges the extracted speaker embedding and timestamps by unique ID to make the further processing more convenient. + This function rearranges the extracted speaker embedding and timestamps by unique ID + to make the further processing more convenient. Args: multiscale_timestamps (dict): @@ -441,13 +443,20 @@ def perform_clustering( 'embeddings' : Tensor containing embeddings. Dimensions:(# of embs) x (emb. dimension) 'timestamps' : Tensor containing ime stamps list for each audio recording 'multiscale_segment_counts' : Tensor containing the number of segments for each scale - AUDIO_RTTM_MAP (dict): AUDIO_RTTM_MAP for mapping unique id with audio file path and rttm path - out_rttm_dir (str): Path to write predicted rttms - clustering_params (dict): clustering parameters provided through config that contains max_num_speakers (int), - oracle_num_speakers (bool), max_rp_threshold(float), sparse_search_volume(int) and enhance_count_threshold (int) - use_torch_script (bool): Boolean that determines whether to use torch.jit.script for speaker clustering - device (torch.device): Device we are running on ('cpu', 'cuda'). - verbose (bool): Enable TQDM progress bar. + AUDIO_RTTM_MAP (dict): + AUDIO_RTTM_MAP for mapping unique id with audio file path and rttm path + out_rttm_dir (str): + Path to write predicted rttms + clustering_params (dict): + Clustering parameters provided through config that contains max_num_speakers (int), + oracle_num_speakers (bool), max_rp_threshold(float), sparse_search_volume(int) + and enhance_count_threshold (int). + use_torch_script (bool): + Boolean that determines whether to use torch.jit.script for speaker clustering + device (torch.device): + Device we are running on ('cpu', 'cuda'). + verbose (bool): + Enable TQDM progress bar. Returns: all_reference (list[uniq_name,Annotation]): reference annotations for score calculation @@ -585,7 +594,7 @@ def write_overlap_segments(outfile, AUDIO_RTTM_MAP, uniq_id, overlap_range_list, Number of decimals to round the offset and duration values. """ audio_path = AUDIO_RTTM_MAP[uniq_id]['audio_filepath'] - for (stt, end) in overlap_range_list: + for stt, end in overlap_range_list: meta = { "audio_filepath": audio_path, "offset": round(stt, decimals), @@ -614,9 +623,8 @@ def read_rttm_lines(rttm_file_path): lines = f.readlines() else: raise FileNotFoundError( - "Requested to construct manifest from rttm with oracle VAD option or from NeMo VAD but received filename as {}".format( - rttm_file_path - ) + "Requested to construct manifest from rttm with oracle VAD option " + f"or from NeMo VAD but received filename as {rttm_file_path}" ) return lines @@ -745,14 +753,14 @@ def fl2int(x: float, decimals: int = 3) -> int: """ Convert floating point number to integer. """ - return torch.round(torch.tensor([x * (10 ** decimals)]), decimals=0).int().item() + return torch.round(torch.tensor([x * (10**decimals)]), decimals=0).int().item() def int2fl(x: int, decimals: int = 3) -> float: """ Convert integer to floating point number. """ - return torch.round(torch.tensor([x / (10 ** decimals)]), decimals=decimals).item() + return torch.round(torch.tensor([x / (10**decimals)]), decimals=decimals).item() def merge_float_intervals(ranges: List[List[float]], decimals: int = 5, margin: int = 2) -> List[List[float]]: @@ -886,7 +894,8 @@ def segments_manifest_to_subsegments_manifest( Generate subsegments manifest from segments manifest file Args: segments_manifest file (str): path to segments manifest file, typically from VAD output - subsegments_manifest_file (str): path to output subsegments manifest file (default (None) : writes to current working directory) + subsegments_manifest_file (str): path to output subsegments manifest file + (default (None) : writes to current working directory) window (float): window length for segments to subsegments length shift (float): hop length for subsegments shift min_subsegments_duration (float): exclude subsegments smaller than this duration value @@ -898,15 +907,16 @@ def segments_manifest_to_subsegments_manifest( pwd = os.getcwd() subsegments_manifest_file = os.path.join(pwd, 'subsegments.json') - with open(segments_manifest_file, 'r') as segments_manifest, open( - subsegments_manifest_file, 'w' - ) as subsegments_manifest: + with ( + open(segments_manifest_file, 'r') as segments_manifest, + open(subsegments_manifest_file, 'w') as subsegments_manifest, + ): segments = segments_manifest.readlines() for segment in segments: segment = segment.strip() dic = json.loads(segment) audio, offset, duration, label = dic['audio_filepath'], dic['offset'], dic['duration'], dic['label'] - subsegments = get_subsegments(offset=offset, window=window, shift=shift, duration=duration) + subsegments = get_subsegments_scriptable(offset=offset, window=window, shift=shift, duration=duration) if include_uniq_id and 'uniq_id' in dic: uniq_id = dic['uniq_id'] else: @@ -928,16 +938,82 @@ def segments_manifest_to_subsegments_manifest( return subsegments_manifest_file -def get_subsegments(offset: float, window: float, shift: float, duration: float) -> List[List[float]]: +def get_subsegments( + offset: float, + window: float, + shift: float, + duration: float, + min_subsegment_duration: float = 0.01, + decimals: int = 2, + use_asr_style_frame_count: bool = False, + sample_rate: int = 16000, + feat_per_sec: int = 100, +) -> List[List[float]]: + """ + Return subsegments from a segment of audio file. + + Example: + (window, shift) = 1.5, 0.75 + Segment: [12.05, 14.45] + Subsegments: [[12.05, 13.55], [12.8, 14.3], [13.55, 14.45], [14.3, 14.45]] + + Args: + offset (float): Start time of audio segment + window (float): Window length for segments to subsegments length + shift (float): Hop length for subsegments shift + duration (float): Duration of segment + min_subsegment_duration (float): Exclude subsegments smaller than this duration value + decimals (int): Number of decimal places to round to + use_asr_style_frame_count (bool): If True, use asr style frame count to generate subsegments. + For example, if duration is 10 secs and frame_shift is 0.08 secs, + it results in (10/0.08)+1 = 125 + 1 frames. + + Returns: + subsegments (List[tuple[float, float]]): subsegments generated for the segments as + list of tuple of start and duration of each subsegment + """ + subsegments: List[List[float]] = [] + start = offset + slice_end = start + duration + if min_subsegment_duration <= duration <= shift: + slices = 1 + elif use_asr_style_frame_count is True: + num_feat_frames = np.ceil((1 + duration * sample_rate) / int(sample_rate / feat_per_sec)).astype(int) + slices = np.ceil(num_feat_frames / int(feat_per_sec * shift)).astype(int) + slice_end = start + shift * slices + else: + slices = np.ceil(1 + (duration - window) / shift).astype(int) + if slices == 1: + if min(duration, window) >= min_subsegment_duration: + subsegments.append([start, min(duration, window)]) + elif slices > 0: # What if slcies = 0 ? + start_col = torch.arange(offset, slice_end, shift)[:slices] + dur_col_raw = torch.min( + slice_end * torch.ones_like(start_col) - start_col, window * torch.ones_like(start_col) + ) + dur_col = torch.round(dur_col_raw, decimals=decimals) + valid_mask = dur_col >= min_subsegment_duration + valid_subsegments = torch.stack([start_col[valid_mask], dur_col[valid_mask]], dim=1) + subsegments = valid_subsegments.tolist() + return subsegments + + +def get_subsegments_scriptable(offset: float, window: float, shift: float, duration: float) -> List[List[float]]: """ - Return subsegments from a segment of audio file + This function returns subsegments from a segment of an audio file. + Although this implementation is inefficient due to the use of a for-loop for segmentation, + it is designed to be torch-jit-scriptable. + Use `get_subsegments` for a more efficient implementation. + Args: offset (float): start time of audio segment window (float): window length for segments to subsegments length shift (float): hop length for subsegments shift duration (float): duration of segment Returns: - subsegments (List[tuple[float, float]]): subsegments generated for the segments as list of tuple of start and duration of each subsegment + subsegments (List[tuple[float, float]]): subsegments generated for the segments + as list of tuple of start and duration of + each subsegment """ subsegments: List[List[float]] = [] start = offset @@ -953,7 +1029,13 @@ def get_subsegments(offset: float, window: float, shift: float, duration: float) return subsegments -def get_target_sig(sig, start_sec: float, end_sec: float, slice_length: int, sample_rate: int,) -> torch.Tensor: +def get_target_sig( + sig, + start_sec: float, + end_sec: float, + slice_length: int, + sample_rate: int, +) -> torch.Tensor: """ Extract time-series signal from the given audio buffer based on the start and end timestamps. @@ -1000,6 +1082,34 @@ def tensor_to_list(range_tensor: torch.Tensor) -> List[List[float]]: return [[float(range_tensor[k][0]), float(range_tensor[k][1])] for k in range(range_tensor.shape[0])] +def generate_diarization_output_lines(speaker_timestamps: List[List[float]], model_spk_num: int) -> List[str]: + """ + Generate diarization output lines list from the speaker timestamps list by merging overlapping intervals. + + Args: + speaker_timestamps (list): + List containing the start and end time of the speech intervals for each speaker. + Example: + >>> speaker_timestamps = [[0.5, 3.12], [3.51, 7.26],... ] + model_spk_num (int): + Number of speakers in the model. + + Returns: + speaker_lines_total (list): + List containing the diarization output lines in the format: + "start_time end_time speaker_id" + Example: + >>> speaker_lines_total = ["0.5 3.12 speaker_0", "3.51 7.26 speaker_1",...] + """ + speaker_lines_total = [] + for spk_idx in range(model_spk_num): + ts_invervals = speaker_timestamps[spk_idx] + merged_ts_intervals = merge_float_intervals(ts_invervals) + for ts_interval in merged_ts_intervals: + speaker_lines_total.extend([f"{ts_interval[0]:.3f} {ts_interval[1]:.3f} speaker_{int(spk_idx)}"]) + return speaker_lines_total + + def get_speech_labels_for_update( frame_start: float, buffer_end: float, @@ -1067,9 +1177,12 @@ def get_speech_labels_for_update( return speech_label_for_new_segments, cumulative_speech_labels -def get_new_cursor_for_update(frame_start: float, segment_range_ts: List[List[float]],) -> Tuple[float, int]: +def get_new_cursor_for_update( + frame_start: float, + segment_range_ts: List[List[float]], +) -> Tuple[float, int]: """ - Function for updating a cursor online speaker diarization. + Function for updating a cursor online speaker diarization. Remove the old segments that overlap with the new frame (self.frame_start) cursor_for_old_segments is set to the onset of the t_range popped lastly. @@ -1226,8 +1339,11 @@ def get_online_subsegments_from_buffer( range_offs = [float(range_spl[0].item() - buffer_start), float(range_spl[1].item() - buffer_start)] range_t = [max(0, range_offs[0]), range_offs[1]] - subsegments = get_subsegments( - offset=range_t[0], window=window, shift=shift, duration=(range_t[1] - range_t[0]), + subsegments = get_subsegments_scriptable( + offset=range_t[0], + window=window, + shift=shift, + duration=(range_t[1] - range_t[0]), ) ind_offset, sigs, ranges, inds = get_online_segments_from_slices( sig=audio_buffer, @@ -1277,20 +1393,22 @@ def get_scale_mapping_argmat(uniq_embs_and_timestamps: Dict[str, dict]) -> Dict[ def get_overlap_stamps(cont_stamps: List[str], ovl_spk_idx: List[str]): """ - Generate timestamps that include overlap speech. Overlap-including timestamps are created based on the segments that are - created for clustering diarizer. Overlap speech is assigned to the existing speech segments in `cont_stamps`. + Generate timestamps that include overlap speech. Overlap-including timestamps are created based on + the segments that are created for clustering diarizer. Overlap speech is assigned to the existing + speech segments in `cont_stamps`. Args: cont_stamps (list): - Non-overlapping (single speaker per segment) diarization output in string format. - Each line contains the start and end time of segments and corresponding speaker labels. + Non-overlapping (single speaker per segment) diarization output in string format. Each line + contains the start and end time of segments and corresponding speaker labels. ovl_spk_idx (list): - List containing segment index of the estimated overlapped speech. The start and end of segments are based on the - single-speaker (i.e., non-overlap-aware) RTTM generation. + List containing segment index of the estimated overlapped speech. The start and end of + segments are based on the single-speaker (i.e., non-overlap-aware) RTTM generation. + Returns: total_ovl_cont_list (list): - Rendered diarization output in string format. Each line contains the start and end time of segments and - corresponding speaker labels. This format is identical to `cont_stamps`. + Rendered diarization output in string format. Each line contains the start and end time of + segments and corresponding speaker labels. This format is identical to `cont_stamps`. """ ovl_spk_cont_list = [[] for _ in range(len(ovl_spk_idx))] for spk_idx in range(len(ovl_spk_idx)): @@ -1307,18 +1425,21 @@ def get_overlap_stamps(cont_stamps: List[str], ovl_spk_idx: List[str]): def get_adaptive_threshold(estimated_num_of_spks: int, min_threshold: float, overlap_infer_spk_limit: int): """ - This function controls the magnitude of the sigmoid threshold based on the estimated number of speakers. As the number of - speakers becomes larger, diarization error rate is very sensitive on overlap speech detection. This function linearly increases - the threshold in proportion to the estimated number of speakers so more confident overlap speech results are reflected when - the number of estimated speakers are relatively high. + This function controls the magnitude of the sigmoid threshold based on the estimated number of + speakers. As the number of speakers becomes larger, diarization error rate is very sensitive + to overlap speech detection. This function linearly increases the threshold in proportion to + the estimated number of speakers so more confident overlap speech results are reflected when + the number of estimated speakers is relatively high. Args: estimated_num_of_spks (int): Estimated number of speakers from the clustering result. min_threshold (float): - Sigmoid threshold value from the config file. This threshold value is minimum threshold value when `estimated_num_of_spks=2` + Sigmoid threshold value from the config file. This threshold value is the minimum + threshold when `estimated_num_of_spks=2`. overlap_infer_spk_limit (int): - If the `estimated_num_of_spks` is less then `overlap_infer_spk_limit`, overlap speech estimation is skipped. + If the `estimated_num_of_spks` is less than `overlap_infer_spk_limit`, overlap speech + estimation is skipped. Returns: adaptive_threshold (float): @@ -1333,37 +1454,41 @@ def get_adaptive_threshold(estimated_num_of_spks: int, min_threshold: float, ove def generate_speaker_timestamps( clus_labels: List[Union[float, int]], msdd_preds: List[torch.Tensor], **params ) -> Tuple[List[str], List[str]]: - ''' - Generate speaker timestamps from the segmentation information. If `use_clus_as_main=True`, use clustering result for main speaker - labels and use timestamps from the predicted sigmoid values. In this function, the main speaker labels in `maj_labels` exist for - every subsegment steps while overlap speaker labels in `ovl_labels` only exist for segments where overlap-speech is occuring. + """ + Generate speaker timestamps from the segmentation information. If `use_clus_as_main=True`, use + clustering result for main speaker labels and use timestamps from the predicted sigmoid values. + In this function, the main speaker labels in `maj_labels` exist for every subsegment step, while + overlap speaker labels in `ovl_labels` only exist for segments where overlap speech occurs. Args: clus_labels (list): List containing integer-valued speaker clustering results. msdd_preds (list): - List containing tensors of the predicted sigmoid values. - Each tensor has shape of: (Session length, estimated number of speakers). + List containing tensors of the predicted sigmoid values. Each tensor has shape of: + (Session length, estimated number of speakers). params: Parameters for generating RTTM output and evaluation. Parameters include: - infer_overlap (bool): If False, overlap-speech will not be detected. - use_clus_as_main (bool): Add overlap-speech detection from MSDD to clustering results. If False, only MSDD output - is used for constructing output RTTM files. + infer_overlap (bool): If False, overlap speech will not be detected. + use_clus_as_main (bool): Add overlap-speech detection from MSDD to clustering results. + If False, only MSDD output is used for constructing output + RTTM files. overlap_infer_spk_limit (int): Above this limit, overlap-speech detection is bypassed. - use_adaptive_thres (bool): Boolean that determines whehther to use adaptive_threshold depending on the estimated - number of speakers. + use_adaptive_thres (bool): Boolean that determines whether to use adaptive thresholds + depending on the estimated number of speakers. max_overlap_spks (int): Maximum number of overlap speakers detected. Default is 2. threshold (float): Sigmoid threshold for MSDD output. Returns: maj_labels (list): - List containing string-formated single-speaker speech segment timestamps and corresponding speaker labels. + List containing string-formatted single-speaker speech segment timestamps and corresponding + speaker labels. Example: [..., '551.685 552.77 speaker_1', '552.99 554.43 speaker_0', '554.97 558.19 speaker_0', ...] ovl_labels (list): - List containing string-formated additional overlapping speech segment timestamps and corresponding speaker labels. - Note that `ovl_labels` includes only overlapping speech that is not included in `maj_labels`. + List containing string-formatted additional overlapping speech segment timestamps and + corresponding speaker labels. Note that `ovl_labels` includes only overlapping speech that + is not included in `maj_labels`. Example: [..., '152.495 152.745 speaker_1', '372.71 373.085 speaker_0', '554.97 555.885 speaker_1', ...] - ''' + """ msdd_preds.squeeze(0) estimated_num_of_spks = msdd_preds.shape[-1] overlap_speaker_list = [[] for _ in range(estimated_num_of_spks)] @@ -1398,8 +1523,7 @@ def generate_speaker_timestamps( def get_uniq_id_list_from_manifest(manifest_file: str): - """Retrieve `uniq_id` values from the given manifest_file and save the IDs to a list. - """ + """Retrieve `uniq_id` values from the given manifest_file and save the IDs to a list.""" uniq_id_list = [] with open(manifest_file, 'r', encoding='utf-8') as manifest: for i, line in enumerate(manifest.readlines()): @@ -1418,7 +1542,8 @@ def get_id_tup_dict(uniq_id_list: List[str], test_data_collection, preds_list: L uniq_id_list (list): List containing the `uniq_id` values. test_data_collection (collections.DiarizationLabelEntity): - Class instance that is containing session information such as targeted speaker indices, audio filepath and RTTM filepath. + Class instance that is containing session information such as targeted speaker indices, + audio filepath and RTTM filepath. preds_list (list): List containing tensors of predicted sigmoid values. @@ -1447,11 +1572,14 @@ def prepare_split_data(manifest_filepath, _out_dir, multiscale_args_dict, global Returns: multiscale_args_dict (dict): - - Dictionary containing two types of arguments: multi-scale weights and subsegment timestamps for each data sample. + - Dictionary containing two types of arguments: multi-scale weights and subsegment timestamps + for each data sample. - Each data sample has two keys: `multiscale_weights` and `scale_dict`. - `multiscale_weights` key contains a list containing multiscale weights. - `scale_dict` is indexed by integer keys which are scale index. - - Each data sample is indexed by using the following naming convention: `__` + - Each data sample is indexed by using the following naming convention: + `__` + Example: `fe_03_00106_mixed_626310_642300` """ speaker_dir = os.path.join(_out_dir, 'speaker_outputs') @@ -1580,6 +1708,86 @@ def make_rttm_with_overlap( return all_reference, all_hypothesis +def timestamps_to_pyannote_object( + speaker_timestamps: List[Tuple[float, float]], + uniq_id: str, + audio_rttm_values: Dict[str, str], + all_hypothesis: List[Tuple[str, Timeline]], + all_reference: List[Tuple[str, Timeline]], + all_uems: List[Tuple[str, Timeline]], + out_rttm_dir: str | None, +): + """ + Convert speaker timestamps to pyannote.core.Timeline object. + + Args: + speaker_timestamps (List[Tuple[float, float]]): + Timestamps of each speaker: start time and end time of each speaker. + uniq_id (str): + Unique ID of each speaker. + audio_rttm_values (Dict[str, str]): + Dictionary of manifest values. + all_hypothesis (List[Tuple[str, pyannote.core.Timeline]]): + List of hypothesis in pyannote.core.Timeline object. + all_reference (List[Tuple[str, pyannote.core.Timeline]]): + List of reference in pyannote.core.Timeline object. + all_uems (List[Tuple[str, pyannote.core.Timeline]]): + List of uems in pyannote.core.Timeline object. + out_rttm_dir (str | None): + Directory to save RTTMs + + Returns: + all_hypothesis (List[Tuple[str, pyannote.core.Timeline]]): + List of hypothesis in pyannote.core.Timeline object with an added Timeline object. + all_reference (List[Tuple[str, pyannote.core.Timeline]]): + List of reference in pyannote.core.Timeline object with an added Timeline object. + all_uems (List[Tuple[str, pyannote.core.Timeline]]): + List of uems in pyannote.core.Timeline object with an added Timeline object. + """ + offset, dur = float(audio_rttm_values.get('offset', None)), float(audio_rttm_values.get('duration', None)) + hyp_labels = generate_diarization_output_lines( + speaker_timestamps=speaker_timestamps, model_spk_num=len(speaker_timestamps) + ) + hypothesis = labels_to_pyannote_object(hyp_labels, uniq_name=uniq_id) + if out_rttm_dir is not None and os.path.exists(out_rttm_dir): + with open(f'{out_rttm_dir}/{uniq_id}.rttm', 'w') as f: + hypothesis.write_rttm(f) + all_hypothesis.append([uniq_id, hypothesis]) + rttm_file = audio_rttm_values.get('rttm_filepath', None) + if rttm_file is not None and os.path.exists(rttm_file): + uem_lines = [[offset, dur + offset]] + org_ref_labels = rttm_to_labels(rttm_file) + ref_labels = org_ref_labels + reference = labels_to_pyannote_object(ref_labels, uniq_name=uniq_id) + uem_obj = get_uem_object(uem_lines, uniq_id=uniq_id) + all_uems.append(uem_obj) + all_reference.append([uniq_id, reference]) + return all_hypothesis, all_reference, all_uems + + +def get_uem_object(uem_lines: List[List[float]], uniq_id: str): + """ + Generate pyannote timeline segments for uem file. + + file format + UNIQ_SPEAKER_ID CHANNEL START_TIME END_TIME + + Args: + uem_lines (list): list of session ID and start, end times. + Example: + [[0.0, 30.41], [60.04, 165.83]] + uniq_id (str): Unique session ID. + + Returns: + timeline (pyannote.core.Timeline): pyannote timeline object. + """ + timeline = Timeline(uri=uniq_id) + for uem_stt_end in uem_lines: + start_time, end_time = uem_stt_end + timeline.add(Segment(float(start_time), float(end_time))) + return timeline + + def embedding_normalize(embs, use_std=False, eps=1e-10): """ Mean and l2 length normalize the input speaker embeddings @@ -1635,7 +1843,7 @@ def run_online_segmentation( segment_indexes: List[int], window: float, shift: float, - ): + ) -> Tuple[List[torch.Tensor], List[List[float]], List[int]]: """ Remove the old segments that overlap with the new frame (self.frame_start) cursor_for_old_segments is pointing at the onset of the t_range popped most recently. diff --git a/nemo/collections/asr/parts/utils/vad_utils.py b/nemo/collections/asr/parts/utils/vad_utils.py index aea04b8cafcf..83a811ee4adb 100644 --- a/nemo/collections/asr/parts/utils/vad_utils.py +++ b/nemo/collections/asr/parts/utils/vad_utils.py @@ -23,31 +23,22 @@ from pathlib import Path from typing import Dict, List, Optional, Tuple, Union +import IPython.display as ipd import librosa import matplotlib.pyplot as plt import numpy as np import pandas as pd import torch -from omegaconf import DictConfig +from omegaconf import DictConfig, OmegaConf from pyannote.core import Annotation, Segment from pyannote.metrics import detection from sklearn.metrics import roc_auc_score from sklearn.model_selection import ParameterGrid from tqdm import tqdm - from nemo.collections.asr.models import EncDecClassificationModel, EncDecFrameClassificationModel from nemo.collections.common.parts.preprocessing.manifest import get_full_path from nemo.utils import logging -HAVE_IPYTHON = False -try: - import IPython.display as ipd - - HAVE_IPYTHON = True -except: - HAVE_IPYTHON = False - - """ This file contains all the utility functions required for voice activity detection. """ @@ -74,8 +65,8 @@ def prepare_manifest(config: dict) -> str: input_list = config['input'] else: raise ValueError( - "The input for manifest preparation would either be a string of the filepath to \ - manifest or a list of {'audio_filepath': i, 'offset': 0, 'duration': null} " + "The input for manifest preparation would either be a string of the filepath to manifest " + "or a list of {'audio_filepath': i, 'offset': 0, 'duration': null}." ) args_func = { @@ -204,8 +195,7 @@ def write_vad_infer_manifest(file: dict, args_func: dict) -> list: def get_vad_stream_status(data: list) -> list: """ - Generate a list of status for each snippet in manifest. - A snippet should be in single, start, next or end status. + Generate a list of status for each snippet in manifest. A snippet should be in single, start, next or end status. Used for concatenating to full audio file. Args: data (list): list of filepath of audio snippet @@ -321,9 +311,8 @@ def generate_overlap_vad_seq_per_tensor( frame: torch.Tensor, per_args: Dict[str, float], smoothing_method: str ) -> torch.Tensor: """ - Use generated frame prediction (generated by shifting window of shift_length_in_sec (10ms)) - to generate prediction with overlapping input window/segments - See description in generate_overlap_vad_seq. + Use generated frame prediction (generated by shifting window of shift_length_in_sec (10ms)) to generate + prediction with overlapping input window/segments. See description in generate_overlap_vad_seq. Use this for single instance pipeline. """ # This function will be refactor for vectorization but this is okay for now @@ -484,8 +473,8 @@ def binarization(sequence: torch.Tensor, per_args: Dict[str, float]) -> torch.Te Binarize predictions to speech and non-speech Reference - Paper: Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of RNN-based Voice Activity Detection", \ - InterSpeech 2015. + Paper: Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of RNN-based Voice + Activity Detection", InterSpeech 2015. Implementation: https://github.com/pyannote/pyannote-audio/blob/master/pyannote/audio/utils/signal.py Args: @@ -498,8 +487,8 @@ def binarization(sequence: torch.Tensor, per_args: Dict[str, float]) -> torch.Te frame_length_in_sec (float): length of frame. Returns: - speech_segments(torch.Tensor): A tensor of speech segment in torch.Tensor([[start1, end1], [start2, end2]]) \ - format. + speech_segments(torch.Tensor): A tensor of speech segment in the form of: + `torch.Tensor([[start1, end1], [start2, end2]])`. """ frame_length_in_sec = per_args.get('frame_length_in_sec', 0.01) @@ -549,11 +538,10 @@ def binarization(sequence: torch.Tensor, per_args: Dict[str, float]) -> torch.Te def remove_segments(original_segments: torch.Tensor, to_be_removed_segments: torch.Tensor) -> torch.Tensor: """ Remove speech segments list in to_be_removed_segments from original_segments. - For example, - remove torch.Tensor([[start2, end2],[start4, end4]]) from torch.Tensor([[start1, end1],[start2, end2],\ - [start3, end3], [start4, end4]]), - -> - torch.Tensor([[start1, end1],[start3, end3]]) + (Example) Remove torch.Tensor([[start2, end2],[start4, end4]]) + from torch.Tensor([[start1, end1],[start2, end2],[start3, end3], [start4, end4]]), + -> + torch.Tensor([[start1, end1],[start3, end3]]) """ for y in to_be_removed_segments: original_segments = original_segments[original_segments.eq(y).all(dim=1).logical_not()] @@ -574,24 +562,30 @@ def get_gap_segments(segments: torch.Tensor) -> torch.Tensor: @torch.jit.script def filtering(speech_segments: torch.Tensor, per_args: Dict[str, float]) -> torch.Tensor: """ - Filter out short non_speech and speech segments. + Filter out short non-speech and speech segments. + + Reference: + Paper: Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of RNN-based Voice + Activity Detection", InterSpeech 2015. + Implementation: + https://github.com/pyannote/pyannote-audio/blob/master/pyannote/audio/utils/signal.py - Reference - Paper: Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of RNN-based Voice Activity Detection", \ - InterSpeech 2015. - Implementation: https://github.com/pyannote/pyannote-audio/blob/master/pyannote/audio/utils/signal.py Args: - speech_segments (torch.Tensor): A tensor of speech segment in torch.Tensor([[start1, end1], \ - [start2, end2]]) format. + speech_segments (torch.Tensor): + A tensor of speech segments in the format + torch.Tensor([[start1, end1], [start2, end2]]). per_args: - min_duration_on (float): threshold for small non_speech deletion - min_duration_off (float): threshold for short speech segment deletion - filter_speech_first (float): Whether to perform short speech segment deletion first. \ - Use 1.0 to represent True. + min_duration_on (float): + Threshold for small non-speech deletion. + min_duration_off (float): + Threshold for short speech segment deletion. + filter_speech_first (float): + Whether to perform short speech segment deletion first. Use 1.0 to represent True. Returns: - speech_segments(torch.Tensor): A tensor of filtered speech segment in \ - torch.Tensor([[start1, end1], [start2, end2]]) format. + speech_segments (torch.Tensor): + A tensor of filtered speech segments in the format + torch.Tensor([[start1, end1], [start2, end2]]). """ if speech_segments.shape == torch.Size([0]): return speech_segments @@ -840,18 +834,19 @@ def vad_tune_threshold_on_dev( num_workers: int = 20, ) -> Tuple[dict, dict]: """ - Tune thresholds on dev set. Return best thresholds which gives the lowest detection error rate - (DetER) in thresholds. + Tune thresholds on dev set. Return best thresholds which gives the lowest + detection error rate (DetER) in thresholds. + Args: params (dict): dictionary of parameters to be tuned on. vad_pred_method (str): suffix of prediction file. Use to locate file. - Should be either in "frame", "mean" or "median". - groundtruth_RTTM_dir (str): directory of ground-truth rttm files or a file contains the paths of them. - focus_metric (str): metrics we care most when tuning threshold. Should be either in "DetER", "FA", "MISS" - frame_length_in_sec (float): frame length. - num_workers (int): number of workers. + Should be either in "frame", "mean" or "median". + groundtruth_RTTM_dir (str): Directory of ground-truth rttm files or a file contains the paths of them. + focus_metric (str): Metrics we care most when tuning threshold. Should be either in "DetER", "FA", "MISS" + frame_length_in_sec (float): Frame length. + num_workers (int): Number of workers. Returns: - best_threshold (float): threshold that gives lowest DetER. + best_threshold (float): Threshold that gives lowest DetER. """ min_score = 100 all_perf = {} @@ -936,8 +931,7 @@ def check_if_param_valid(params: dict) -> bool: for j in params[i]: if not j >= 0: raise ValueError( - "Invalid inputs! All float parameters except pad_onset and pad_offset should be \ - larger than 0!" + "Invalid inputs! All float parameters except pad_onset and pad_offset should be larger than 0!" ) if not (all(i <= 1 for i in params['onset']) and all(i <= 1 for i in params['offset'])): @@ -995,7 +989,7 @@ def plot( unit_frame_len: float = 0.01, label_repeat: int = 1, xticks_step: int = 5, -) -> "ipd.Audio": +) -> ipd.Audio: """ Plot Audio and/or VAD output and/or groundtruth labels for visualization Args: @@ -1009,13 +1003,10 @@ def plot( threshold (float): threshold for prediction score (from 0 to 1). per_args(dict): a dict that stores the thresholds for postprocessing. unit_frame_len (float): unit frame length in seconds for VAD predictions. - label_repeat (int): repeat the label for this number of times to match different \ - frame lengths in preds and labels. + label_repeat (int): repeat the label for this number of times to match different + frame lengths in preds and labels. xticks_step (int): step size for xticks. """ - if HAVE_IPYTHON is False: - raise ImportError("IPython is not installed. Please install IPython to use this function.") - plt.figure(figsize=[20, 2]) audio, sample_rate = librosa.load( @@ -1281,8 +1272,8 @@ def stitch_segmented_asr_output( fout.flush() logging.info( - f"Finish stitch segmented ASR output to {stitched_output_manifest}, \ - the speech segments info has been stored in directory {speech_segments_tensor_dir}" + f"Finish stitch segmented ASR output to {stitched_output_manifest}, " + f"the speech segments info has been stored in directory {speech_segments_tensor_dir}" ) return stitched_output_manifest @@ -1462,13 +1453,10 @@ def plot_sample_from_rttm( show: bool = True, offset: float = 0.0, unit_frame_len: float = 0.01, -) -> "ipd.Audio": +): """ Plot audio signal and frame-level labels from RTTM file """ - if HAVE_IPYTHON is False: - raise ImportError("IPython is not installed. Please install IPython to use this function.") - plt.figure(figsize=[20, 2]) audio, sample_rate = librosa.load(path=audio_file, sr=16000, mono=True, offset=offset, duration=max_duration) @@ -1502,17 +1490,22 @@ def plot_sample_from_rttm( def align_labels_to_frames(probs, labels, threshold=0.2): """ - Aligns labels to frames when the frame length (e.g., 10ms) is different from the label length (e.g., 20ms). - The threshold 0.2 is not important, since the actual ratio will always be close to an integer - unless using frame/label. lengths that are not multiples of each other - (e.g., 15ms frame length and 20ms label length), which is not valid. - The value 0.2 here is just for easier unit testing. + Aligns labels to frames when the frame length (e.g., 10ms) is different from the label length + (e.g., 20ms). The threshold 0.2 is not critical, as the actual ratio will always be close to an + integer unless using frame/label lengths that are not multiples of each other (e.g., 15ms frame + length and 20ms label length), which is not valid. The value 0.2 is chosen for easier unit testing. + Args: - probs (List[float]): list of probabilities - labels (List[int]): list of labels - threshold (float): threshold for rounding ratio to integer + probs (List[float]): + List of probabilities. + labels (List[int]): + List of labels. + threshold (float): + Threshold for rounding the ratio to an integer. + Returns: - labels (List[int]): list of labels aligned to frames + labels (List[int]): + List of labels aligned to frames. """ frames_len = len(probs) labels_len = len(labels) @@ -1543,13 +1536,13 @@ def align_labels_to_frames(probs, labels, threshold=0.2): ratio = frames_len / labels_len res = frames_len % labels_len if ceil(ratio) - ratio < threshold: - # e.g., ratio is 1.83, ceil(ratio) = 2, then we repeat labels to make it a - # multiple of 2, and discard the redundant labels + # e.g., ratio is 1.83, ceil(ratio) = 2, then we repeat labels + # to make it a multiple of 2, and discard the redundant labels labels = labels.repeat_interleave(ceil(ratio), dim=0).long().tolist() labels = labels[:frames_len] else: - # e.g., ratio is 2.02, floor(ratio) = 2, then we repeat labels to make it a multiple of - # 2 and add additional labels + # e.g., ratio is 2.02, floor(ratio) = 2, then we repeat labels + # to make it a multiple of 2 and add additional labels labels = labels.repeat_interleave(floor(ratio), dim=0).long().tolist() if res > 0: labels += labels[-res:] @@ -1743,3 +1736,52 @@ def frame_vad_eval_detection_error( auroc = roc_auc_score(y_true=all_labels, y_score=all_probs) report = metric.report(display=False) return auroc, report + + +def ts_vad_post_processing( + ts_vad_binary_vec: torch.Tensor, + cfg_vad_params: OmegaConf, + unit_10ms_frame_count: int = 8, + bypass_postprocessing: bool = False, +): + """ + Post-processing on diarization results using VAD style post-processing methods. + These post-processing methods are inspired by the following paper: + Medennikov, Ivan, et al. "Target-Speaker Voice Activity Detection: + a Novel Approach for Multi-Speaker Diarization in a Dinner Party Scenario." (2020). + + Args: + ts_vad_binary_vec (Tensor): + Sigmoid values of each frame and each speaker. + Dimension: (num_frames,) + cfg_vad_params (OmegaConf): + Configuration (omega config) of VAD parameters. + unit_10ms_frame_count (int, optional): + an integer indicating the number of 10ms frames in a unit. + For example, if unit_10ms_frame_count is 8, then each frame is 0.08 seconds. + bypass_postprocessing (bool, optional): + If True, diarization post-processing will be bypassed. + + Returns: + speech_segments (Tensor): + start and end of each speech segment. + Dimension: (num_segments, 2) + + Example: + tensor([[ 0.0000, 3.0400], + [ 6.0000, 6.0800], + ... + [587.3600, 591.0400], + [591.1200, 597.7600]]) + """ + ts_vad_binary_frames = torch.repeat_interleave(ts_vad_binary_vec, unit_10ms_frame_count) + if not bypass_postprocessing: + speech_segments = binarization(ts_vad_binary_frames, cfg_vad_params) + speech_segments = filtering(speech_segments, cfg_vad_params) + else: + cfg_vad_params.onset = 0.5 + cfg_vad_params.offset = 0.5 + cfg_vad_params.pad_onset = 0.0 + cfg_vad_params.pad_offset = 0.0 + speech_segments = binarization(ts_vad_binary_frames, cfg_vad_params) + return speech_segments diff --git a/nemo/collections/asr/parts/utils/wfst_utils.py b/nemo/collections/asr/parts/utils/wfst_utils.py index 31f394fb60ac..9dbb9fc751b2 100644 --- a/nemo/collections/asr/parts/utils/wfst_utils.py +++ b/nemo/collections/asr/parts/utils/wfst_utils.py @@ -32,7 +32,7 @@ import kaldifst # check that kaldifst package is not empty - # Note: pytorch_lightning.utilities.imports.package_available may not help here + # Note: lightning.pytorch.utilities.imports.package_available may not help here kaldifst.StdVectorFst() _KALDIFST_AVAILABLE = True except (ImportError, ModuleNotFoundError, AttributeError): diff --git a/nemo/collections/audio/models/audio_to_audio.py b/nemo/collections/audio/models/audio_to_audio.py index e1732c1658b7..60c16f756f58 100644 --- a/nemo/collections/audio/models/audio_to_audio.py +++ b/nemo/collections/audio/models/audio_to_audio.py @@ -22,8 +22,8 @@ import librosa import soundfile as sf import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig, OmegaConf -from pytorch_lightning import Trainer from tqdm import tqdm from nemo.collections.asr.data.audio_to_text_dataset import inject_dataloader_value_from_model_config diff --git a/nemo/collections/audio/models/enhancement.py b/nemo/collections/audio/models/enhancement.py index cd9f47b98096..8e2206afcef1 100644 --- a/nemo/collections/audio/models/enhancement.py +++ b/nemo/collections/audio/models/enhancement.py @@ -17,8 +17,8 @@ import einops import hydra import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig -from pytorch_lightning import Trainer from nemo.collections.audio.models.audio_to_audio import AudioToAudioModel from nemo.core.classes.common import PretrainedModelInfo, typecheck diff --git a/nemo/collections/audio/parts/utils/callbacks.py b/nemo/collections/audio/parts/utils/callbacks.py index 093d5a11f419..ff975c93ecc7 100644 --- a/nemo/collections/audio/parts/utils/callbacks.py +++ b/nemo/collections/audio/parts/utils/callbacks.py @@ -16,10 +16,10 @@ import einops import torch -from pytorch_lightning import Callback, LightningModule, Trainer -from pytorch_lightning.loggers import TensorBoardLogger -from pytorch_lightning.loggers.logger import Logger -from pytorch_lightning.loggers.wandb import WandbLogger +from lightning.pytorch import Callback, LightningModule, Trainer +from lightning.pytorch.loggers import TensorBoardLogger +from lightning.pytorch.loggers.logger import Logger +from lightning.pytorch.loggers.wandb import WandbLogger from nemo.utils import logging from nemo.utils.decorators import experimental diff --git a/nemo/collections/common/callbacks/callbacks.py b/nemo/collections/common/callbacks/callbacks.py index 1a6c011c38df..754b33726faf 100644 --- a/nemo/collections/common/callbacks/callbacks.py +++ b/nemo/collections/common/callbacks/callbacks.py @@ -13,15 +13,14 @@ # limitations under the License. import time -from pytorch_lightning.callbacks import Callback -from pytorch_lightning.utilities import rank_zero_only +from lightning.pytorch.callbacks import Callback +from lightning.pytorch.utilities import rank_zero_only # from sacrebleu import corpus_bleu class LogEpochTimeCallback(Callback): - """Simple callback that logs how long each epoch takes, in seconds, to a pytorch lightning log - """ + """Simple callback that logs how long each epoch takes, in seconds, to a pytorch lightning log""" @rank_zero_only def on_train_epoch_start(self, trainer, pl_module): diff --git a/nemo/collections/common/callbacks/ema.py b/nemo/collections/common/callbacks/ema.py index 2f295bf67354..f866a2639d63 100644 --- a/nemo/collections/common/callbacks/ema.py +++ b/nemo/collections/common/callbacks/ema.py @@ -17,11 +17,11 @@ import threading from typing import Any, Dict, Iterable -import pytorch_lightning as pl +import lightning.pytorch as pl import torch -from pytorch_lightning import Callback -from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.rank_zero import rank_zero_info +from lightning.pytorch import Callback +from lightning.pytorch.utilities.exceptions import MisconfigurationException +from lightning.pytorch.utilities.rank_zero import rank_zero_info class EMA(Callback): @@ -40,7 +40,11 @@ class EMA(Callback): """ def __init__( - self, decay: float, validate_original_weights: bool = False, every_n_steps: int = 1, cpu_offload: bool = False, + self, + decay: float, + validate_original_weights: bool = False, + every_n_steps: int = 1, + cpu_offload: bool = False, ): if not (0 <= decay <= 1): raise MisconfigurationException("EMA decay value must be between 0 and 1") @@ -149,7 +153,9 @@ def on_load_checkpoint( def ema_update(ema_model_tuple, current_model_tuple, decay): torch._foreach_mul_(ema_model_tuple, decay) torch._foreach_add_( - ema_model_tuple, current_model_tuple, alpha=(1.0 - decay), + ema_model_tuple, + current_model_tuple, + alpha=(1.0 - decay), ) @@ -272,7 +278,13 @@ def update(self): if self.device.type == 'cpu': self.thread = threading.Thread( - target=run_ema_update_cpu, args=(self.ema_params, current_model_state, self.decay, self.stream,), + target=run_ema_update_cpu, + args=( + self.ema_params, + current_model_state, + self.decay, + self.stream, + ), ) self.thread.start() diff --git a/nemo/collections/common/data/lhotse/dataloader.py b/nemo/collections/common/data/lhotse/dataloader.py index 98b63a07fa9d..bf6b77ad907e 100644 --- a/nemo/collections/common/data/lhotse/dataloader.py +++ b/nemo/collections/common/data/lhotse/dataloader.py @@ -147,6 +147,28 @@ class LhotseDataLoadingConfig: # In most cases (such as regular multi-GPU training) it will result in a deadlock due to # a different number of steps on different DDP ranks. force_finite: bool = False + # The following two options may be used to override auto-detection of appropriate PyTorch dataset flavor + # for your data types. PyTorch DataLoader uses two objects to yield data: dataset and sampler. + # *Map-dataset flavor.* There is one sampler per GPU that lives in the training loop process; + # it selects the examples to be prepared by map-dataset class. Each batch selection determined by the sampler + # is then passed by the dataloader to one of its worker processes to be processed by the dataset class. + # *Iterable-dataset flavor.* Each dataloading worker has its own sampler replica instead; + # the sampler must have the logic for either data deduplication or unique order shuffling to avoid + # duplicated data across workers and GPUs. Lhotse relies on unique order shuffling. + # The default settings are: + # * use iterable dataset for tarred audio data. + # * use iterable dataset for any text data. + # * use map dataset for non-tarred audio data (we might change this in the future) + force_map_dataset: bool = False + force_iterable_dataset: bool = False + + +def determine_use_iterable_dataset(use_iterable_dataset: bool, config: DictConfig) -> bool: + assert not ( + config.force_map_dataset and config.force_iterable_dataset + ), "Conflicting options: force_map_dataset=True and force_iterable_dataset=True" + use_iterable_dataset = (use_iterable_dataset or config.force_iterable_dataset) and not config.force_map_dataset + return use_iterable_dataset def get_lhotse_dataloader_from_config( @@ -176,7 +198,6 @@ def get_lhotse_dataloader_from_config( Note that ``tokenizer`` can be any tokenizer type (e.g. both SentencePiece and Aggregate tokenizers work). """ logging.info("We will be using a Lhotse DataLoader.") - config = make_structured_with_schema_warnings(config) maybe_set_cuda_expandable_segments(enabled=config.cuda_expandable_segments) @@ -186,8 +207,8 @@ def get_lhotse_dataloader_from_config( fix_random_seed(seed) # 1. Load a manifest as a Lhotse CutSet. - cuts, is_tarred = read_cutset_from_config(config) - + cuts, use_iterable_dataset = read_cutset_from_config(config) + use_iterable_dataset = determine_use_iterable_dataset(use_iterable_dataset, config) # Apply channel selector if config.channel_selector is not None: logging.info('Using channel selector %s.', config.channel_selector) @@ -202,7 +223,7 @@ def get_lhotse_dataloader_from_config( if tokenizer is not None and config.pretokenize: from nemo.collections.asr.data.audio_to_text_lhotse import TokenizerWrapper - if not is_tarred: + if not use_iterable_dataset: logging.warning( "You are using a non-tarred dataset and requested tokenization during data sampling (pretokenize=True). " "This will cause the tokenization to happen in the main (GPU) process, possibly impacting the training speed " @@ -317,8 +338,8 @@ def get_lhotse_dataloader_from_config( duration_bins=determine_bucket_duration_bins(config), num_cuts_for_bins_estimate=config.num_cuts_for_bins_estimate, buffer_size=config.bucket_buffer_size, - rank=0 if is_tarred else global_rank, - world_size=1 if is_tarred else world_size, + rank=0 if use_iterable_dataset else global_rank, + world_size=1 if use_iterable_dataset else world_size, ) else: # Non-bucketing sampler, similar to original NeMo dataloading without bucketing, @@ -335,8 +356,8 @@ def get_lhotse_dataloader_from_config( drop_last=config.drop_last, shuffle_buffer_size=config.shuffle_buffer_size, seed=config.shard_seed, - rank=0 if is_tarred else global_rank, - world_size=1 if is_tarred else world_size, + rank=0 if use_iterable_dataset else global_rank, + world_size=1 if use_iterable_dataset else world_size, ) if config.concatenate_samples: @@ -368,7 +389,7 @@ def get_lhotse_dataloader_from_config( ) # 4. Creating dataloader. - if is_tarred and not config.tarred_random_access: + if use_iterable_dataset and not config.tarred_random_access: # Wrapper here is necessary when using NeMo tarred data or Lhotse Shar data, # because then I/O happens upon sampler iteration. Normally, the sampler resides # in the training loop process, but when we use iterable dataset, we can move it to @@ -601,8 +622,8 @@ class DurationFilter: """Callable, returns ``True`` if a cut's duration is in range [d_min, d_max] and ``False`` otherwise.""" def __init__(self, d_min: float, d_max: float) -> None: - self.d_min = d_min - self.d_max = d_max + self.d_min = d_min if d_min is not None else -1.0 + self.d_max = d_max if d_max is not None else float("inf") def __call__(self, example) -> bool: if isinstance(example, Cut): diff --git a/nemo/collections/common/data/lhotse/nemo_adapters.py b/nemo/collections/common/data/lhotse/nemo_adapters.py index ee623f617e26..a34a2c074a11 100644 --- a/nemo/collections/common/data/lhotse/nemo_adapters.py +++ b/nemo/collections/common/data/lhotse/nemo_adapters.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging import random import re @@ -398,40 +397,43 @@ def basename(d: dict) -> str: shard_manifest: dict[str, list[dict]] = groupby(basename, self.shard_id_to_manifest[sid]) tar_path = self.shard_id_to_tar_path[sid] - for data, raw_audio, tar_info in iter_fn(tar_path, shard_manifest, manifest_path): - meta = soundfile.info(BytesIO(raw_audio)) - recording = Recording( - id=tar_info.path, - sources=[AudioSource(type="memory", channels=list(range(meta.channels)), source=raw_audio)], - sampling_rate=int(meta.samplerate), - num_samples=meta.frames, - duration=meta.duration, - ) - cuts_for_recording = [] - for data in sorted(shard_manifest[tar_info.name], key=lambda d: d["audio_filepath"]): - # Cut the recording into corresponding segment and discard audio data outside the segment. - cut = make_cut_with_subset_inmemory_recording( - recording, offset=data.get("offset", 0.0), duration=data.get("duration") + try: + for data, raw_audio, tar_info in iter_fn(tar_path, shard_manifest, manifest_path): + meta = soundfile.info(BytesIO(raw_audio)) + recording = Recording( + id=tar_info.path, + sources=[AudioSource(type="memory", channels=list(range(meta.channels)), source=raw_audio)], + sampling_rate=int(meta.samplerate), + num_samples=meta.frames, + duration=meta.duration, ) - cut.supervisions.append( - SupervisionSegment( - id=cut.id, - recording_id=cut.recording_id, - start=0, - duration=cut.duration, - text=data.get(self.text_field), - language=data.get(self.lang_field), + cuts_for_recording = [] + for data in sorted(shard_manifest[tar_info.name], key=lambda d: d["audio_filepath"]): + # Cut the recording into corresponding segment and discard audio data outside the segment. + cut = make_cut_with_subset_inmemory_recording( + recording, offset=data.get("offset", 0.0), duration=data.get("duration") ) - ) - cut.custom = _to_custom_attr_dict(data) - cut.manifest_origin = manifest_path - cut.tar_origin = tar_path - for extra_field in extra_fields: - extra_field.attach_to(cut) - cuts_for_recording.append(cut) - del recording # free the memory - helps with very large audio files - del raw_audio - yield from cuts_for_recording + cut.supervisions.append( + SupervisionSegment( + id=cut.id, + recording_id=cut.recording_id, + start=0, + duration=cut.duration, + text=data.get(self.text_field), + language=data.get(self.lang_field), + ) + ) + cut.custom = _to_custom_attr_dict(data) + cut.manifest_origin = manifest_path + cut.tar_origin = tar_path + for extra_field in extra_fields: + extra_field.attach_to(cut) + cuts_for_recording.append(cut) + del recording # free the memory - helps with very large audio files + del raw_audio + yield from cuts_for_recording + except tarfile.ReadError: + logging.warning(f"Skipping tar file due to read errors (unstable storage or bad file?): {tar_path=}") def __len__(self) -> int: return len(self.source) diff --git a/nemo/collections/common/metrics/perf_metrics.py b/nemo/collections/common/metrics/perf_metrics.py index a6b001c884e4..285bf7b4f9d6 100644 --- a/nemo/collections/common/metrics/perf_metrics.py +++ b/nemo/collections/common/metrics/perf_metrics.py @@ -15,7 +15,7 @@ from typing import Any, Dict, List, Optional import numpy as np -from pytorch_lightning.callbacks import Callback +from lightning.pytorch.callbacks import Callback from nemo.collections.common.parts.perf_metrics_utils import LLM_VOCAB_SIZE_MAP, read_tb_log from nemo.utils import logging diff --git a/nemo/collections/common/parts/preprocessing/collections.py b/nemo/collections/common/parts/preprocessing/collections.py index b16ac50e4d56..d54c807f2637 100644 --- a/nemo/collections/common/parts/preprocessing/collections.py +++ b/nemo/collections/common/parts/preprocessing/collections.py @@ -308,6 +308,135 @@ def __init__( super().__init__(data) +class InstructionTuningAudioText(_Collection): + """`AudioText` collector from asr structured json files.""" + + OUTPUT_TYPE = collections.namedtuple( + typename='InstructionTuningText', + field_names=( + 'id context context_type context_duration question ' + 'question_type answer answer_type answer_duration speaker' + ), + ) + + def __init__( + self, + manifests_files: Union[str, List[str]], + min_duration: Optional[float] = None, + max_duration: Optional[float] = None, + max_seq_length: Optional[float] = None, + max_number: Optional[int] = None, + do_sort_by_duration: bool = False, + index_by_file_id: bool = False, + decoder_only_model: bool = False, + use_phoneme_tokenizer: bool = False, + ): + """Parse lists of audio files, durations and transcripts texts. + Args: + manifests_files: Either single string file or list of such - + manifests to yield items from. + *args: Args to pass to `AudioText` constructor. + **kwargs: Kwargs to pass to `AudioText` constructor. + """ + + output_type = self.OUTPUT_TYPE + self.use_phoneme_tokenizer = use_phoneme_tokenizer + data, duration_filtered, num_filtered, total_duration = [], 0.0, 0, 0.0 + if index_by_file_id: + self.mapping = {} + + for item in manifest.item_iter(manifests_files): + + id = item['id'] + context = item['context'] + context_duration = item['context_duration'] + context_type = item['context_type'] + question = item['question'] + question_type = item['question_type'] + speaker = item['speaker'] + answer = item['answer'] + answer_duration = item['answer_duration'] + answer_type = item['answer_type'] + task = item['task'] + + task = 'tts' if task is None else task + duration = answer_duration if task == 'tts' else context_duration + if min_duration is not None and duration < min_duration: + duration_filtered += duration + num_filtered += 1 + continue + + if max_duration is not None and duration > max_duration: + duration_filtered += duration + num_filtered += 1 + continue + + # Check segment length + approx_context_len = min(self._get_len(context_type, context, context_duration) * 0.3, 400) + approx_question_len = self._get_len(question_type, question, None) + approx_answer_len = self._get_len(answer_type, answer, answer_duration) + + if ( + decoder_only_model and approx_context_len + approx_question_len + approx_answer_len >= max_seq_length + ) or (approx_context_len + approx_question_len >= max_seq_length or approx_answer_len >= max_seq_length): + duration_filtered += duration + num_filtered += 1 + continue + + total_duration += duration + data.append( + output_type( + id, + context, + context_type, + context_duration, + question, + question_type, + answer, + answer_type, + answer_duration, + speaker, + ) + ) + + if index_by_file_id: + file_id, _ = os.path.splitext(os.path.basename(context)) + if ".context" in file_id: + file_id = file_id[:-8] + if file_id not in self.mapping: + self.mapping[file_id] = [] + self.mapping[file_id].append(len(data) - 1) + + # Max number of entities filter. + if len(data) == max_number: + break + + if do_sort_by_duration: + if index_by_file_id: + logging.warning("Tried to sort dataset by duration, but cannot since index_by_file_id is set.") + else: + data.sort(key=lambda entity: entity.duration) + + logging.info("Dataset loaded with %d files totalling %.2f hours", len(data), total_duration / 3600) + logging.info("%d files were filtered totalling %.2f hours", num_filtered, duration_filtered / 3600) + + super().__init__(data) + + def _get_len(self, field_type, data, duration_data): + if field_type == "SPEECH": + return duration_data * 76 # TODO: add explanation for the hardcoded value. + elif field_type == "TEXT": + if self.use_phoneme_tokenizer: + # Approx len is number of characters + return len(data) + else: + return len(data.split(' ')) + 3 # # TODO: add explanation for the hardcoded value. + elif field_type == "TOKENS": + return len(data) + 3 + else: + raise ValueError(f"Unknown field type {field_type}.") + + class ASRAudioText(AudioText): """`AudioText` collector from asr structured json files.""" @@ -352,7 +481,10 @@ def __init__(self, manifests_files: Union[str, List[str]], parse_func: Optional[ class SpeechLLMAudioTextEntity(object): + """Class for SpeechLLM dataloader instance.""" + def __init__(self, sid, audio_file, duration, context, answer, offset, speaker, orig_sr, lang) -> None: + """Initialize the AudioTextEntity for a SpeechLLM dataloader instance.""" self.id = sid self.audio_file = audio_file self.duration = duration @@ -433,7 +565,6 @@ def __init__( ): """Instantiates audio-context-answer manifest with filters and preprocessing. - Args: ids: List of examples positions. audio_files: List of audio files. @@ -644,7 +775,8 @@ def __parse_item(self, line: str, manifest_file: str) -> Dict[str, Any]: elif 'question' in item: # compatability with old manifests that uses 'question' as context key logging.warning( - f"Neither `{self.context_key}` is found nor `context_file` is set, but found `question` in item: {item}", + f"Neither `{self.context_key}` is found nor" + f"`context_file` is set, but found `question` in item: {item}", mode=logging_mode.ONCE, ) item['context'] = item.pop('question') @@ -741,7 +873,8 @@ def __init__( else: logging.info(f"Filtered duration for loading collection is {duration_filtered / 3600: .2f} hours.") logging.info( - f"Dataset successfully loaded with {len(data)} items and total duration provided from manifest is {total_duration / 3600: .2f} hours." + f"Dataset successfully loaded with {len(data)} items " + f"and total duration provided from manifest is {total_duration / 3600: .2f} hours." ) self.uniq_labels = sorted(set(map(lambda x: x.label, data))) @@ -882,13 +1015,15 @@ def __init__( if len(data) == max_number: break - logging.info("# {} files loaded including # {} unique labels".format(len(data), len(self.uniq_labels))) + logging.info(f"# {len(data)} files loaded including # {len(self.uniq_labels)} unique labels") super().__init__(data) def relative_speaker_parser(self, seq_label): """Convert sequence of speaker labels to relative labels. Convert sequence of absolute speaker to sequence of relative speaker [E A C A E E C] -> [0 1 2 1 0 0 2] - In this seq of label , if label do not appear before, assign new relative labels len(pos); else reuse previous assigned relative labels. + In this seq of label , if label do not appear before, assign new relative labels len(pos); + else reuse previous assigned relative labels. + Args: seq_label (str): A string of a sequence of labels. @@ -925,10 +1060,13 @@ def __init__( """Parse lists of feature files and sequences of labels. Args: - manifests_files: Either single string file or list of such - - manifests to yield items from. - max_number: Maximum number of samples to collect; pass to `FeatureSequenceLabel` constructor. - index_by_file_id: If True, saves a mapping from filename base (ID) to index in data; pass to `FeatureSequenceLabel` constructor. + manifests_files: + Either single string file or list of such manifests to yield items from. + max_number: + Maximum number of samples to collect; pass to `FeatureSequenceLabel` constructor. + index_by_file_id: + If True, saves a mapping from filename base (ID) to index in data; + pass to `FeatureSequenceLabel` constructor. """ feature_files, seq_labels = [], [] @@ -1083,35 +1221,37 @@ def __init__( manifests_files: Union[str, List[str]], emb_dict: Dict, clus_label_dict: Dict, - round_digit=2, + round_digits: int = 2, seq_eval_mode=False, pairwise_infer=False, *args, **kwargs, ): """ - Parse lists of audio files, durations, RTTM (Diarization annotation) files. Since diarization model infers only - two speakers, speaker pairs are generated from the total number of speakers in the session. + Parse lists of audio files, durations, RTTM (Diarization annotation) files. Since the diarization + model infers only two speakers, speaker pairs are generated from the total number of speakers in + the session. Args: manifest_filepath (str): - Path to input manifest json files. + Path to input manifest JSON files. emb_dict (Dict): Dictionary containing cluster-average embeddings and speaker mapping information. clus_label_dict (Dict): Segment-level speaker labels from clustering results. round_digit (int): - Number of digits to be rounded. + Number of digits to round. seq_eval_mode (bool): If True, F1 score will be calculated for each speaker pair during inference mode. pairwise_infer (bool): - If True, this dataset class operates in inference mode. In inference mode, a set of speakers in the input audio - is split into multiple pairs of speakers and speaker tuples (e.g. 3 speakers: [(0,1), (1,2), (0,2)]) and then - fed into the diarization system to merge the individual results. + If True, this dataset class operates in inference mode. In inference mode, a set of + speakers in the input audio is split into multiple pairs of speakers and speaker tuples + (e.g., 3 speakers: [(0,1), (1,2), (0,2)]) and then fed into the diarization system to + merge the individual results. *args: Args to pass to `SpeechLabel` constructor. **kwargs: Kwargs to pass to `SpeechLabel` constructor. """ - self.round_digit = round_digit + self.round_digits = round_digits self.emb_dict = emb_dict self.clus_label_dict = clus_label_dict self.seq_eval_mode = seq_eval_mode @@ -1245,6 +1385,188 @@ def __parse_item_rttm(self, line: str, manifest_file: str) -> Dict[str, Any]: return item +class EndtoEndDiarizationLabel(_Collection): + """List of end-to-end diarization audio-label correspondence with preprocessing.""" + + OUTPUT_TYPE = collections.namedtuple( + typename='DiarizationLabelEntity', + field_names='audio_file uniq_id duration rttm_file offset', + ) + + def __init__( + self, + audio_files: List[str], + uniq_ids: List[str], + durations: List[float], + rttm_files: List[str], + offsets: List[float], + max_number: Optional[int] = None, + do_sort_by_duration: bool = False, + index_by_file_id: bool = False, + ): + """ + Instantiates audio-label manifest with filters and preprocessing. + + This method initializes the EndtoEndDiarizationLabel object by processing the input data + and applying optional filters and sorting. + + Args: + audio_files (List[str]): List of audio file paths. + uniq_ids (List[str]): List of unique identifiers for each audio file. + durations (List[float]): List of float durations for each audio file. + rttm_files (List[str]): List of RTTM path strings (Groundtruth diarization annotation file). + offsets (List[float]): List of offsets or None for each audio file. + max_number (Optional[int]): Maximum number of samples to collect. Defaults to None. + do_sort_by_duration (bool): If True, sort samples list by duration. Defaults to False. + index_by_file_id (bool): If True, saves a mapping from filename base (ID) to index in data. + Defaults to False. + + """ + if index_by_file_id: + self.mapping = {} + output_type = self.OUTPUT_TYPE + data, duration_filtered = [], 0.0 + + zipped_items = zip(audio_files, uniq_ids, durations, rttm_files, offsets) + for ( + audio_file, + uniq_id, + duration, + rttm_file, + offset, + ) in zipped_items: + + if duration is None: + duration = 0 + + data.append( + output_type( + audio_file, + uniq_id, + duration, + rttm_file, + offset, + ) + ) + + if index_by_file_id: + if isinstance(audio_file, list): + if len(audio_file) == 0: + raise ValueError(f"Empty audio file list: {audio_file}") + file_id, _ = os.path.splitext(os.path.basename(audio_file)) + self.mapping[file_id] = len(data) - 1 + + # Max number of entities filter. + if len(data) == max_number: + break + + if do_sort_by_duration: + if index_by_file_id: + logging.warning("Tried to sort dataset by duration, but cannot since index_by_file_id is set.") + else: + data.sort(key=lambda entity: entity.duration) + + logging.info( + "Filtered duration for loading collection is %f.", + duration_filtered, + ) + logging.info(f"Total {len(data)} session files loaded accounting to # {len(audio_files)} audio clips") + + super().__init__(data) + + +class EndtoEndDiarizationSpeechLabel(EndtoEndDiarizationLabel): + """End-to-end speaker diarization data sample collector from structured json files.""" + + def __init__( + self, + manifests_files: Union[str, List[str]], + round_digits: int = 2, + *args, + **kwargs, + ): + """ + Parse lists of audio files, durations, RTTM (Diarization annotation) files. + Since diarization model infers only two speakers, speaker pairs are generated + from the total number of speakers in the session. + + Args: + manifest_filepath (str): + Path to input manifest json files. + round_digit (int): + Number of digits to be rounded. + *args: Args to pass to `SpeechLabel` constructor. + **kwargs: Kwargs to pass to `SpeechLabel` constructor. + """ + self.round_digits = round_digits + audio_files, uniq_ids, durations, rttm_files, offsets = ( + [], + [], + [], + [], + [], + ) + + for item in manifest.item_iter(manifests_files, parse_func=self.__parse_item_rttm): + # Training mode + rttm_labels = [] + with open(item['rttm_file'], 'r') as f: + for index, rttm_line in enumerate(f.readlines()): + rttm = rttm_line.strip().split() + start = round(float(rttm[3]), round_digits) + end = round(float(rttm[4]), round_digits) + round(float(rttm[3]), round_digits) + speaker = rttm[7] + rttm_labels.append('{} {} {}'.format(start, end, speaker)) + audio_files.append(item['audio_file']) + uniq_ids.append(item['uniq_id']) + durations.append(item['duration']) + rttm_files.append(item['rttm_file']) + offsets.append(item['offset']) + + super().__init__( + audio_files, + uniq_ids, + durations, + rttm_files, + offsets, + *args, + **kwargs, + ) + + def __parse_item_rttm(self, line: str, manifest_file: str) -> Dict[str, Any]: + """Parse each rttm file and save it to in Dict format""" + item = json.loads(line) + if 'audio_filename' in item: + item['audio_file'] = item.pop('audio_filename') + elif 'audio_filepath' in item: + item['audio_file'] = item.pop('audio_filepath') + else: + raise ValueError( + f"Manifest file has invalid json line " f"structure: {line} without proper audio file key." + ) + if isinstance(item['audio_file'], list): + item['audio_file'] = [os.path.expanduser(audio_file_path) for audio_file_path in item['audio_file']] + else: + item['audio_file'] = os.path.expanduser(item['audio_file']) + + if not isinstance(item['audio_file'], list): + if 'uniq_id' not in item: + item['uniq_id'] = os.path.splitext(os.path.basename(item['audio_file']))[0] + elif 'uniq_id' not in item: + raise ValueError(f"Manifest file has invalid json line " f"structure: {line} without proper uniq_id key.") + + if 'duration' not in item: + raise ValueError(f"Manifest file has invalid json line " f"structure: {line} without proper duration key.") + item = dict( + audio_file=item['audio_file'], + uniq_id=item['uniq_id'], + duration=item['duration'], + rttm_file=item['rttm_filepath'], + offset=item.get('offset', None), + ) + return item + + class Audio(_Collection): """Prepare a list of all audio items, filtered by duration.""" @@ -1515,7 +1837,8 @@ def __init__( manifests_files: Either single string file or list of such - manifests to yield items from. max_number: Maximum number of samples to collect; pass to `FeatureSequenceLabel` constructor. - index_by_file_id: If True, saves a mapping from filename base (ID) to index in data; pass to `FeatureSequenceLabel` constructor. + index_by_file_id: If True, saves a mapping from filename base (ID) to index in data; + pass to `FeatureSequenceLabel` constructor. """ feature_files, labels, durations = [], [], [] diff --git a/nemo/collections/common/parts/preprocessing/manifest.py b/nemo/collections/common/parts/preprocessing/manifest.py index 1d49bd7c7019..e2ad08bd04c2 100644 --- a/nemo/collections/common/parts/preprocessing/manifest.py +++ b/nemo/collections/common/parts/preprocessing/manifest.py @@ -110,6 +110,8 @@ def __parse_item(line: str, manifest_file: str) -> Dict[str, Any]: item['audio_file'] = item.pop('audio_filename') elif 'audio_filepath' in item: item['audio_file'] = item.pop('audio_filepath') + elif 'context' in item: + item['audio_file'] = item['context'] # Video File if 'video_filename' in item: @@ -132,7 +134,9 @@ def __parse_item(line: str, manifest_file: str) -> Dict[str, Any]: item['video_file'] = get_full_path(audio_file=item['video_file'], manifest_file=manifest_file) # Duration. - if 'duration' not in item: + if 'context_duration' in item and 'duration' not in item: + item['duration'] = item['context_duration'] + elif 'duration' not in item: raise ValueError( f"Manifest file {manifest_file} has invalid json line structure: {line} without proper duration key." ) @@ -184,6 +188,15 @@ def __parse_item(line: str, manifest_file: str) -> Dict[str, Any]: orig_sr=item.get('orig_sample_rate', None), token_labels=item.get('token_labels', None), lang=item.get('lang', None), + context=item.get('context', None), + context_type=item.get('context_type', None), + context_duration=item.get('context_duration', None), + answer=item.get('answer', None), + answer_type=item.get('answer_type', None), + answer_duration=item.get('answer_duration', None), + question=item.get('question', None), + question_type=item.get('question_type', None), + task=item.get('task', None), ) return item @@ -247,7 +260,7 @@ def get_full_path( if ( (len(audio_file) < audio_file_len_limit) and not os.path.isabs(audio_file) - and not os.path.isfile(audio_file) + # and not os.path.isfile(audio_file) # Commented out because it slows down dataloading ): # If audio_file is not available and the path is not absolute, the full path is assumed # to be relative to the manifest file parent directory or data directory. diff --git a/nemo/collections/common/parts/ptl_overrides.py b/nemo/collections/common/parts/ptl_overrides.py index 0225ecd50fee..263c865f8270 100644 --- a/nemo/collections/common/parts/ptl_overrides.py +++ b/nemo/collections/common/parts/ptl_overrides.py @@ -13,11 +13,11 @@ # limitations under the License. import torch -from pytorch_lightning.plugins.precision import MixedPrecisionPlugin +from lightning.pytorch.plugins.precision import MixedPrecisionPlugin class NeMoMixedPrecisionPlugin(MixedPrecisionPlugin): - def __init__(self, init_scale: float = 2 ** 32, growth_interval: int = 1000) -> None: + def __init__(self, init_scale: float = 2**32, growth_interval: int = 1000) -> None: super().__init__(precision=16) self.scaler = torch.cuda.amp.GradScaler(init_scale=init_scale, growth_interval=growth_interval) diff --git a/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py b/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py index a8ea949019c1..56a4b04dfe0f 100644 --- a/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py +++ b/nemo/collections/common/tokenizers/sentencepiece_tokenizer.py @@ -25,7 +25,7 @@ from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec from nemo.utils import logging -__all__ = ['SentencePieceTokenizer', 'create_spt_model'] +__all__ = ['SentencePieceTokenizer', 'SentencePieceSpeechLLMTTSTokenizer', 'create_spt_model'] class SentencePieceTokenizer(TokenizerSpec, ChatTemplateMixin): @@ -315,6 +315,14 @@ def vocab(self): return main_vocab + special_tokens +class SentencePieceSpeechLLMTTSTokenizer(SentencePieceTokenizer): + def add_phone_tokens_to_special_tokens(self): + for i, word in enumerate(self.vocab): + if word.startswith("p{"): + self.special_token_to_id[word] = i + self.id_to_special_token[i] = word + + def create_spt_model( data_file: str, vocab_size: int, diff --git a/nemo/collections/diffusion/data/diffusion_energon_datamodule.py b/nemo/collections/diffusion/data/diffusion_energon_datamodule.py index 67a26609dd51..07747528363a 100644 --- a/nemo/collections/diffusion/data/diffusion_energon_datamodule.py +++ b/nemo/collections/diffusion/data/diffusion_energon_datamodule.py @@ -15,9 +15,9 @@ import logging from typing import Any, Dict, Literal +from lightning.pytorch.utilities.types import EVAL_DATALOADERS from megatron.core import parallel_state from megatron.energon import DefaultTaskEncoder, WorkerConfig, get_savable_loader, get_train_dataset -from pytorch_lightning.utilities.types import EVAL_DATALOADERS from nemo.collections.multimodal.data.energon.base import SimpleMultiModalDataModule diff --git a/nemo/collections/diffusion/data/diffusion_fake_datamodule.py b/nemo/collections/diffusion/data/diffusion_fake_datamodule.py index 6cb686c1c305..a9fc7ad5b484 100644 --- a/nemo/collections/diffusion/data/diffusion_fake_datamodule.py +++ b/nemo/collections/diffusion/data/diffusion_fake_datamodule.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl import torch -from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS +from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS from torch.utils.data import DataLoader from nemo.collections.diffusion.models.model import DiTConfig diff --git a/nemo/collections/diffusion/train.py b/nemo/collections/diffusion/train.py index 5428e0eeefa2..404602084b85 100644 --- a/nemo/collections/diffusion/train.py +++ b/nemo/collections/diffusion/train.py @@ -14,13 +14,13 @@ import os +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch +from lightning.pytorch.loggers import WandbLogger from megatron.core.distributed import DistributedDataParallelConfig from megatron.core.optimizer import OptimizerConfig from megatron.core.transformer.enums import AttnMaskType -from pytorch_lightning.loggers import WandbLogger from nemo import lightning as nl from nemo.collections import llm diff --git a/nemo/collections/llm/__init__.py b/nemo/collections/llm/__init__.py index 2051f844d888..f17128cdb36d 100644 --- a/nemo/collections/llm/__init__.py +++ b/nemo/collections/llm/__init__.py @@ -22,7 +22,7 @@ AlpacaDataModule, DollyDataModule, FineTuningDataModule, - HfDatasetDataModule, + HFDatasetDataModule, MockDataModule, PreTrainingDataModule, SquadDataModule, @@ -64,7 +64,7 @@ GPTConfig126M, GPTConfig175B, GPTModel, - HfAutoModelForCausalLM, + HFAutoModelForCausalLM, Llama2Config7B, Llama2Config13B, Llama2Config70B, @@ -73,6 +73,8 @@ Llama31Config8B, Llama31Config70B, Llama31Config405B, + Llama32Config1B, + Llama32Config3B, LlamaConfig, LlamaModel, MaskedTokenLossReduction, @@ -93,6 +95,9 @@ NemotronModel, NVIDIAMambaConfig8B, NVIDIAMambaHybridConfig8B, + Phi3Config, + Phi3ConfigMini, + Phi3Model, Qwen2Config, Qwen2Config1P5B, Qwen2Config7B, @@ -112,10 +117,15 @@ gpt_forward_step, ) from nemo.collections.llm.quantization import Quantizer, get_calib_data_iter +from nemo.collections.llm.t5.data import FineTuningDataModule as T5FineTuningDataModule +from nemo.collections.llm.t5.data import MockDataModule as T5MockDataModule +from nemo.collections.llm.t5.data import PreTrainingDataModule as T5PreTrainingDataModule +from nemo.collections.llm.t5.data import SquadDataModule as T5SquadDataModule from nemo.collections.llm.t5.model import T5Config, T5Model, t5_data_step, t5_forward_step __all__ = [ "MockDataModule", + "T5MockDataModule", "GPTModel", "GPTConfig", "gpt_data_step", @@ -143,6 +153,9 @@ "Nemotron4Config15B", "Nemotron4Config340B", "NemotronConfig", + "Phi3Config", + "Phi3ConfigMini", + "Phi3Model", "SSMConfig", "BaseMambaConfig130M", "BaseMambaConfig370M", @@ -160,6 +173,8 @@ "Llama31Config8B", "Llama31Config70B", "Llama31Config405B", + "Llama32Config1B", + "Llama32Config3B", "CodeLlamaConfig7B", "CodeLlamaConfig13B", "CodeLlamaConfig34B", @@ -192,6 +207,10 @@ "PreTrainingDataModule", "FineTuningDataModule", "SquadDataModule", + "T5PreTrainingDataModule", + "T5FineTuningDataModule", + "T5SquadDataModule", + "T5MockDataModule", "DollyDataModule", "tokenizer", "mock", @@ -199,7 +218,7 @@ "dolly", "peft", "hf_dataset", - "HfAutoModelForCausalLM", + "HFAutoModelForCausalLM", ] @@ -208,7 +227,7 @@ try: import nemo_run as run - from nemo.collections.llm.api import export_ckpt, finetune, generate, import_ckpt, pretrain, train, validate + from nemo.collections.llm.api import export_ckpt, finetune, generate, import_ckpt, pretrain, ptq, train, validate from nemo.collections.llm.recipes import * # noqa __all__.extend( @@ -220,6 +239,7 @@ "validate", "finetune", "generate", + "ptq", ] ) except ImportError as error: @@ -231,3 +251,10 @@ __all__.append("deploy") except ImportError as error: logging.warning(f"The deploy module could not be imported: {error}") + +try: + from nemo.collections.llm.api import evaluate + + __all__.append("evaluate") +except ImportError as error: + logging.warning(f"The evaluate module could not be imported: {error}") diff --git a/nemo/collections/llm/api.py b/nemo/collections/llm/api.py index fdceff5d959e..adf98747059c 100644 --- a/nemo/collections/llm/api.py +++ b/nemo/collections/llm/api.py @@ -11,20 +11,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import json import os from copy import deepcopy from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Optional, Union +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch +from megatron.core import parallel_state from rich.console import Console +from torch.distributed import all_gather_object from typing_extensions import Annotated import nemo.lightning as nl +from nemo.collections.llm.quantization import ExportConfig, QuantizationConfig from nemo.lightning import ( AutoResume, NeMoLogger, @@ -36,6 +38,8 @@ from nemo.lightning.base import NEMO_MODELS_CACHE from nemo.lightning.pytorch.callbacks import PEFT, ModelTransform from nemo.utils import logging +from nemo.utils.get_rank import is_global_rank_zero + if TYPE_CHECKING: from megatron.core.inference.common_inference_params import CommonInferenceParams @@ -68,7 +72,8 @@ def train( resume (Optional[Union[AutoResume, Resume]]): Resume training from a checkpoint. optim (Optional[OptimizerModule]): The optimizer module to be used. If not provided, the default optimizer from the model will be used. - tokenizer (Optional[TokenizerType]): Tokenizer setting to be applied. Can be 'data' or 'model' or an instance of TokenizerSpec. + tokenizer (Optional[TokenizerType]): Tokenizer setting to be applied. Can be 'data' or 'model' + or an instance of TokenizerSpec. export (Optional[str]): Filename to save the exported checkpoint after training. model_transform (Optional[Union[Callable[[nn.Module], nn.Module], PEFT]]): A model transform to be applied. @@ -84,7 +89,7 @@ def train( >>> data = llm.SquadDataModule(seq_length=4096, global_batch_size=16, micro_batch_size=2) >>> precision = nl.MegatronMixedPrecision(precision="bf16-mixed") >>> trainer = nl.Trainer(strategy=nl.MegatronStrategy(tensor_model_parallel_size=2), plugins=precision) - >>> train(model, data, trainer, tokenizer="data") + >>> llm.train(model, data, trainer, tokenizer="data") PosixPath('/path/to/log_dir') """ app_state = _setup( @@ -186,7 +191,7 @@ def finetune( >>> data = llm.SquadDataModule(seq_length=4096, global_batch_size=16, micro_batch_size=2) >>> precision = nl.MegatronMixedPrecision(precision="bf16-mixed") >>> trainer = nl.Trainer(strategy=nl.MegatronStrategy(tensor_model_parallel_size=2), plugins=precision) - >>> finetune(model, data, trainer, peft=llm.peft.LoRA()]) + >>> llm.finetune(model, data, trainer, peft=llm.peft.LoRA()]) PosixPath('/path/to/log_dir') """ @@ -224,7 +229,8 @@ def validate( resume (Optional[AutoResume]): Resume from a checkpoint for validation. optim (Optional[OptimizerModule]): The optimizer module to be used. If not provided, the default optimizer from the model will be used. - tokenizer (Optional[TokenizerType]): Tokenizer setting to be applied. Can be 'data' or 'model' or an instance of TokenizerSpec. + tokenizer (Optional[TokenizerType]): Tokenizer setting to be applied. Can be 'data' or 'model' + or an instance of TokenizerSpec. model_transform (Optional[Union[Callable[[nn.Module], nn.Module], PEFT]]): A model transform to be applied. Returns: @@ -237,7 +243,7 @@ def validate( >>> data = llm.SquadDataModule(seq_length=4096, global_batch_size=16, micro_batch_size=2) >>> precision = nl.MegatronMixedPrecision(precision="bf16-mixed") >>> trainer = nl.Trainer(strategy=nl.MegatronStrategy(tensor_model_parallel_size=2), plugins=precision) - >>> validate(model, data, trainer, tokenizer="data") + >>> llm.validate(model, data, trainer, tokenizer="data") PosixPath('/path/to/log_dir') """ app_state = _setup( @@ -256,84 +262,69 @@ def validate( return app_state.exp_dir -def get_trtllm_deployable( - nemo_checkpoint, - model_type, - triton_model_repository, - num_gpus, - tensor_parallelism_size, - pipeline_parallelism_size, - max_input_len, - max_output_len, - max_batch_size, - dtype, -): - from nemo.export.tensorrt_llm import TensorRTLLM +@run.cli.entrypoint(name="ptq", namespace="llm") +def ptq( + nemo_checkpoint: str, + export_config: ExportConfig, + calib_tp: int = 1, + calib_pp: int = 1, + quantization_config: Annotated[Optional[QuantizationConfig], run.Config[QuantizationConfig]] = None, +) -> Path: + """ + Applies Post-Training Quantization (PTQ) for a model using the specified quantization and export configs. It runs + calibration for a small dataset to collect scaling factors low-precision GEMMs used by desired quantization method. + This function produces TensorRT-LLM checkpoint ready for deployment using nemo.export and nemo.deploy modules + or direcly using TensorRT-LLM library. + The function can be used through the NeMo CLI in the following way: + ```bash + # Run calibration using tensor parallel set to 8 and export quantized checkpoint with tensor parallel equal 2 + nemo llm ptq nemo_checkpoint=/models/Llama-3-70B \ + export_config.path=/models/Llama-3-70B-FP8 \ + calib_tp=8 \ + export_config.inference_tensor_parallel=2 + # Choose different quantization method, for example, INT8 SmoothQuant + nemo llm ptq nemo_checkpoint=/models/Llama-3-8B \ + export_config.path=/models/Llama-3-8B-INT8_SQ \ + quantization_config.algorithm=int8_sq + ``` + Args: + nemo_checkpoint (str): The path to model to be quantized. + calib_tp (int): Calibration tensor parallelism. + calib_pp (int): Calibration pipeline parallelism. + quantization_config (QuantizationConfig): Configuration for quantization algorithm. + export_config (ExportConfig): Export configuration for TensorRT-LLM checkpoint. + Returns: + Path: The path where the quantized checkpoint has been saved after calibration. + """ + if not quantization_config: + quantization_config = QuantizationConfig() - if triton_model_repository is None: - trt_llm_path = "/tmp/trt_llm_model_dir/" - Path(trt_llm_path).mkdir(parents=True, exist_ok=True) - else: - trt_llm_path = triton_model_repository + if export_config.path is None: + raise ValueError("The export_config.path needs to be specified, got None.") - if nemo_checkpoint is None and triton_model_repository is None: - raise ValueError( - "The provided model repository is not a valid TensorRT-LLM model " - "directory. Please provide a --nemo_checkpoint or a TensorRT-LLM engine." - ) + from nemo.collections.llm import quantization - if nemo_checkpoint is None and not os.path.isdir(triton_model_repository): - raise ValueError( - "The provided model repository is not a valid TensorRT-LLM model " - "directory. Please provide a --nemo_checkpoint or a valid TensorRT-LLM engine." - ) + quantizer = quantization.Quantizer(quantization_config, export_config) - if nemo_checkpoint is not None and model_type is None: - raise ValueError("Model type is required to be defined if a nemo checkpoint is provided.") + model = quantization.load_with_modelopt_layer_spec(nemo_checkpoint, calib_tp, calib_pp) - trt_llm_exporter = TensorRTLLM( - model_dir=trt_llm_path, - load_model=(nemo_checkpoint is None), - ) + model = quantizer.quantize(model) - if nemo_checkpoint is not None: - try: - logging.info("Export operation will be started to export the nemo checkpoint to TensorRT-LLM.") - trt_llm_exporter.export( - nemo_checkpoint_path=nemo_checkpoint, - model_type=model_type, - n_gpus=num_gpus, - tensor_parallelism_size=tensor_parallelism_size, - pipeline_parallelism_size=pipeline_parallelism_size, - max_input_len=max_input_len, - max_output_len=max_output_len, - max_batch_size=max_batch_size, - dtype=dtype, - ) - except Exception as error: - raise RuntimeError("An error has occurred during the model export. Error message: " + str(error)) - - return trt_llm_exporter - - -def store_args_to_json(triton_http_address, triton_port, triton_request_timeout, openai_format_response): - args_dict = { - "triton_service_ip": triton_http_address, - "triton_service_port": triton_port, - "triton_request_timeout": triton_request_timeout, - "openai_format_response": openai_format_response, - } - with open("nemo/deploy/service/config.json", "w") as f: - json.dump(args_dict, f) + quantizer.export(model, nemo_checkpoint) + + console = Console() + console.print(f"[green]✓ PTQ succeded, quantized checkpoint exported to {export_config.path}[/green]") + + return export_config.path @run.cli.entrypoint(namespace="llm") def deploy( nemo_checkpoint: Path = None, model_type: str = "llama", - triton_model_name: str = "xxx", + triton_model_name: str = 'triton_model', triton_model_version: Optional[int] = 1, - triton_port: int = 8080, + triton_port: int = 8000, triton_http_address: str = "0.0.0.0", triton_request_timeout: int = 60, triton_model_repository: Path = None, @@ -344,21 +335,61 @@ def deploy( max_input_len: int = 256, max_output_len: int = 256, max_batch_size: int = 8, - start_rest_service: bool = False, + start_rest_service: bool = True, rest_service_http_address: str = "0.0.0.0", - rest_service_port: int = 8000, - openai_format_response: bool = False, + rest_service_port: int = 8080, + openai_format_response: bool = True, + output_generation_logits: bool = True, ): + """ + Deploys nemo model on a PyTriton server by converting the nemo ckpt to trtllm. + Also starts rest service that is used to send OpenAI API compatible input request + to the PyTiton server. + + Args: + nemo_checkpoint (Path): Path for nemo checkpoint. + model_type (str): Type of the model. Choices: gpt, llama, falcon, starcoder. Default: llama. + triton_model_name (str): Name for the model that gets deployed on PyTriton. Please ensure that the same model + name is passed to the evalute method for the model to be accessible while sending evalution requests. + Default: 'triton_model'. + triton_model_version (Optional[int]): Version for the triton model. Default: 1. + triton_port (int): Port for the PyTriton server. Default: 8000. + triton_http_address (str): HTTP address for the PyTriton server. Default: "0.0.0.0". + triton_request_timeout (int): Timeout in seconds for Triton server. Default: 60. + triton_model_repository (Path): Folder for the trt-llm conversion, trt-llm engine gets saved in this specified + path. If None, saves it in /tmp dir. Default: None. + num_gpus (int): Number of GPUs for export to trtllm and deploy. Default: 1. + tensor_parallelism_size (int): Tensor parallelism size. Default: 1. + pipeline_parallelism_size (int): Pipeline parallelism size. Default: 1. + dtype (str): dtype of the TensorRT-LLM model. Default: "bfloat16". + max_input_len (int): Max input length of the model. Default: 256. + max_output_len (int): Max output length of the model. Default: 256. + max_batch_size (int): Max batch size of the model. Default: 8. + start_rest_service (bool): Start rest service that is used to send evaluation requests to the PyTriton server. + Needs to be True to be able to run evaluation. Default: True. + rest_service_http_address (str): HTTP address for the rest service. Default: "0.0.0.0". + rest_service_port (int): Port for the rest service. Default: 8080. + openai_format_response (bool): Return the response from PyTriton server in OpenAI compatible format. Needs to + be True while running evaluation. Default: True. + output_generation_logits (bool): If True builds trtllm engine with gather_generation_logits set to True. + generation_logits are used to compute the logProb of the output token. Default: True. + """ + from nemo.collections.llm import deploy from nemo.deploy import DeployPyTriton + deploy.unset_environment_variables() if start_rest_service: if triton_port == rest_service_port: logging.error("REST service port and Triton server port cannot use the same port.") return - # Store triton ip, port and other args relevant for REST API in config.json to be accessible by rest_model_api.py - store_args_to_json(triton_http_address, triton_port, triton_request_timeout, openai_format_response) - - triton_deployable = get_trtllm_deployable( + # Store triton ip, port and other args relevant for REST API as env vars to be accessible by rest_model_api.py + os.environ['TRITON_HTTP_ADDRESS'] = triton_http_address + os.environ['TRITON_PORT'] = str(triton_port) + os.environ['TRITON_REQUEST_TIMEOUT'] = str(triton_request_timeout) + os.environ['OPENAI_FORMAT_RESPONSE'] = str(openai_format_response) + os.environ['OUTPUT_GENERATION_LOGITS'] = str(output_generation_logits) + + triton_deployable = deploy.get_trtllm_deployable( nemo_checkpoint, model_type, triton_model_repository, @@ -369,6 +400,7 @@ def deploy( max_output_len, max_batch_size, dtype, + output_generation_logits, ) try: @@ -383,6 +415,7 @@ def deploy( logging.info("Triton deploy function will be called.") nm.deploy() + nm.run() except Exception as error: logging.error("Error message has occurred during deploy function. Error message: " + str(error)) return @@ -416,6 +449,81 @@ def deploy( nm.stop() +def evaluate( + nemo_checkpoint_path: Path, + url: str = "http://0.0.0.0:8080/v1", + model_name: str = "triton_model", + eval_task: str = "gsm8k", + num_fewshot: Optional[int] = None, + limit: Optional[Union[int, float]] = None, + bootstrap_iters: int = 100000, + # inference params + max_tokens_to_generate: Optional[int] = 256, + temperature: Optional[float] = 0.000000001, + top_p: Optional[float] = 0.0, + top_k: Optional[int] = 1, + add_bos: Optional[bool] = False, +): + """ + Evaluates nemo model deployed on PyTriton server (via trtllm) using lm-evaluation-harness + (https://github.com/EleutherAI/lm-evaluation-harness/tree/main). + + Args: + nemo_checkpoint_path (Path): Path for nemo 2.0 checkpoint. This is used to get the tokenizer from the ckpt + which is required to tokenize the evaluation input and output prompts. + url (str): rest service url and port that were used in the deploy method above in the format: + http://{rest_service_http}:{rest_service_port}. Post requests with evaluation input prompts + (from lm-eval-harness) are sent to this url which is then passed to the model deployed on PyTriton server. + The rest service url and port serve as the entry point to evaluate model deployed on PyTriton server. + model_name (str): Name of the model that is deployed on PyTriton server. It should be the same as + triton_model_name passed to the deploy method above to be able to launch evaluation. Deafult: "triton_model". + eval_task (str): task to be evaluated on. For ex: "gsm8k", "gsm8k_cot", "mmlu", "lambada". Default: "gsm8k". + These are the tasks that are supported currently. Any other task of type generate_until or loglikelihood from + lm-evaluation-harness can be run, but only the above mentioned ones are tested. Tasks of type + loglikelihood_rolling are not supported yet. + num_fewshot (int): number of examples in few-shot context. Default: None. + limit (Union[int, float]): Limit the number of examples per task. If <1 (i.e float val between 0 and 1), limit + is a percentage of the total number of examples. If int say x, then run evaluation only on x number of samples + from the eval dataset. Default: None, which means eval is run the entire dataset. + bootstrap_iters (int): Number of iterations for bootstrap statistics, used when calculating stderrs. Set to 0 + for no stderr calculations to be performed. Default: 100000. + # inference params + max_tokens_to_generate (int): max tokens to generate. Default: 256. + temperature: Optional[float]: float value between 0 and 1. temp of 0 indicates greedy decoding, where the token + with highest prob is chosen. Temperature can't be set to 0.0 currently, due to a bug with TRTLLM + (# TODO to be investigated). Hence using a very samll value as the default. Default: 0.000000001. + top_p: Optional[float]: float value between 0 and 1. limits to the top tokens within a certain probability. + top_p=0 means the model will only consider the single most likely token for the next prediction. Default: 0.0. + top_k: Optional[int]: limits to a certain number (K) of the top tokens to consider. top_k=1 means the model + will only consider the single most likely token for the next prediction. Default: 1 + add_bos: Optional[bool]: whether a special token representing the beginning of a sequence should be added when + encoding a string. Default: False since typically for CausalLM its set to False. If needed set add_bos to True. + """ + try: + # lm-evaluation-harness import + from lm_eval import evaluator + except ImportError: + raise ImportError( + "Please ensure that lm-evaluation-harness is installed in your env as it is required " "to run evaluations" + ) + + from nemo.collections.llm import evaluation + + # Get tokenizer from nemo ckpt. This works only with NeMo 2.0 ckpt. + tokenizer = io.load_context(nemo_checkpoint_path + '/context', subpath="model").tokenizer + # Wait for rest service to be ready before starting evaluation + evaluation.wait_for_rest_service(rest_url=f"{url}/v1/health") + # Create an object of the NeMoFWLM which is passed as a model to evaluator.simple_evaluate + model = evaluation.NeMoFWLMEval( + model_name, url, tokenizer, max_tokens_to_generate, temperature, top_p, top_k, add_bos + ) + results = evaluator.simple_evaluate( + model=model, tasks=eval_task, limit=limit, num_fewshot=num_fewshot, bootstrap_iters=bootstrap_iters + ) + + print("score", results['results'][eval_task]) + + @run.cli.entrypoint(name="import", namespace="llm") def import_ckpt( model: pl.LightningModule, @@ -560,9 +668,10 @@ def export_ckpt( @run.cli.entrypoint(name="generate", namespace="llm") def generate( path: Union[Path, str], - prompts: list[str], trainer: nl.Trainer, + prompts: Optional[list[str]] = None, encoder_prompts: Optional[list[str]] = None, + input_dataset: Optional[Union[pl.LightningDataModule, str]] = None, params_dtype: torch.dtype = torch.bfloat16, add_BOS: bool = False, max_batch_size: int = 4, @@ -570,6 +679,7 @@ def generate( inference_batch_times_seqlen_threshold: int = 1000, inference_params: Optional["CommonInferenceParams"] = None, text_only: bool = False, + output_path: Optional[Union[Path, str]] = None, ) -> list[Union["InferenceRequest", str]]: """ Generates text using a NeMo LLM model. @@ -623,6 +733,8 @@ def generate( prompts (list[str]): The list of prompts to generate text for. trainer (nl.Trainer): The trainer object. encoder_prompts (Optional[list[str]], optional): The list of encoder prompts. Defaults to None. + input_dataset (Optional[Union[pl.LightningDataModule, str]], optional): The input data module or jsonl file. + Test set will be used for generation for data modules. Defaults to None. params_dtype (torch.dtype, optional): The data type of the model parameters. Defaults to torch.bfloat16. add_BOS (bool, optional): Whether to add the beginning of sequence token. Defaults to False. max_batch_size (int, optional): The maximum batch size. Defaults to 4. @@ -632,6 +744,8 @@ def generate( inference_params (Optional["CommonInferenceParams"], optional): The inference parameters defined in Mcore's CommonInferenceParams. Defaults to None. text_only (bool, optional): Whether to return only the generated text as a string. Defaults to False. + output_path (Optional[Union[Path, str]], optional): The path to save the generated text or test dataset + predictions. Defaults to None. Returns: list[Union["InferenceRequest", str]]: A list of generated text, @@ -639,24 +753,63 @@ def generate( """ from nemo.collections.llm import inference + if input_dataset is not None: + input_path = input_dataset if isinstance(input_dataset, str) else input_dataset.test_path + with open(input_path) as f: + dataset = [json.loads(sample) for sample in f.readlines()] + inputs = [sample["input"] for sample in dataset] + elif prompts is not None: + inputs = prompts + else: + raise ValueError("Either prompts or input_dataset must be provided.") + inference_wrapped_model, mcore_tokenizer = inference.setup_model_and_tokenizer( path=path, trainer=trainer, params_dtype=params_dtype, inference_batch_times_seqlen_threshold=inference_batch_times_seqlen_threshold, ) - results = inference.generate( + + dp_size = trainer.strategy.distributed_sampler_kwargs['num_replicas'] + dp_rank = trainer.strategy.distributed_sampler_kwargs['rank'] + chunk_size = (len(inputs) + dp_size - 1) // dp_size + start_idx = dp_rank * chunk_size + end_idx = min(start_idx + chunk_size, len(inputs)) + inputs_on_this_dp_rank = inputs[start_idx:end_idx] + + results_on_this_dp_rank = inference.generate( model=inference_wrapped_model, tokenizer=mcore_tokenizer, - prompts=prompts, + prompts=inputs_on_this_dp_rank, encoder_prompts=encoder_prompts, add_BOS=add_BOS, max_batch_size=max_batch_size, random_seed=random_seed, inference_params=inference_params, ) + gathered_results = [None] * dp_size - return [r.generated_text if text_only else r for r in results] + all_gather_object( + gathered_results, + [r.generated_text if text_only else r for r in results_on_this_dp_rank], + group=parallel_state.get_data_parallel_group(), + ) + gathered_results = [result for sublist in gathered_results for result in sublist] + + assert len(gathered_results) == len(inputs) + + if output_path is not None and is_global_rank_zero(): + with open(output_path, "w") as f: + for sample, pred in zip(dataset if input_dataset else inputs, gathered_results): + if type(sample) == dict: + sample["label"] = sample.pop("output", None) + sample["prediction"] = pred if text_only else pred.generated_text + elif type(sample) == str: + sample = {"input": sample, "prediction": pred if text_only else pred.generated_text} + f.write(json.dumps(sample) + "\n") + logging.info(f"Predictions written to {output_path}") + + return gathered_results def _use_tokenizer(model: pl.LightningModule, data: pl.LightningDataModule, tokenizer: TokenizerType) -> None: diff --git a/nemo/collections/llm/deploy/__init__.py b/nemo/collections/llm/deploy/__init__.py new file mode 100644 index 000000000000..24c102bfa0d2 --- /dev/null +++ b/nemo/collections/llm/deploy/__init__.py @@ -0,0 +1,3 @@ +from nemo.collections.llm.deploy.base import get_trtllm_deployable, unset_environment_variables + +__all__ = ["unset_environment_variables", "get_trtllm_deployable"] diff --git a/nemo/collections/llm/deploy/base.py b/nemo/collections/llm/deploy/base.py new file mode 100644 index 000000000000..e21198f5884b --- /dev/null +++ b/nemo/collections/llm/deploy/base.py @@ -0,0 +1,117 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import subprocess +from pathlib import Path + +from nemo.utils import logging + + +def unset_environment_variables() -> None: + """ + SLURM_, PMI_, PMIX_ Variables are needed to be unset for trtllm export to work + on clusters. This method takes care of unsetting these env variables + """ + logging.info("Unsetting all SLURM_, PMI_, PMIX_ Variables") + + # Function to unset variables with a specific prefix + def unset_vars_with_prefix(prefix): + unset_vars = [] + cmd = f"env | grep ^{prefix} | cut -d= -f1" + result = subprocess.run(cmd, shell=True, capture_output=True, text=True) + vars_to_unset = result.stdout.strip().split('\n') + for var in vars_to_unset: + if var: # Check if the variable name is not empty + os.environ.pop(var, None) + unset_vars.append(var) + return unset_vars + + # Collect all unset variables across all prefixes + all_unset_vars = [] + + # Unset variables for each prefix + for prefix in ['SLURM_', 'PMI_', 'PMIX_']: + unset_vars = unset_vars_with_prefix(prefix) + all_unset_vars.extend(unset_vars) + + if all_unset_vars: + logging.info(f"Unset env variables: {', '.join(all_unset_vars)}") + else: + logging.info("No env variables were unset.") + + +def get_trtllm_deployable( + nemo_checkpoint, + model_type, + triton_model_repository, + num_gpus, + tensor_parallelism_size, + pipeline_parallelism_size, + max_input_len, + max_output_len, + max_batch_size, + dtype, + output_generation_logits, +): + """ + Exports the nemo checkpoint to trtllm and returns trt_llm_exporter that is used to deploy on PyTriton. + """ + from nemo.export.tensorrt_llm import TensorRTLLM + + if triton_model_repository is None: + trt_llm_path = "/tmp/trt_llm_model_dir/" + Path(trt_llm_path).mkdir(parents=True, exist_ok=True) + else: + trt_llm_path = triton_model_repository + + if nemo_checkpoint is None and triton_model_repository is None: + raise ValueError( + "The provided model repository is not a valid TensorRT-LLM model " + "directory. Please provide a --nemo_checkpoint or a TensorRT-LLM engine." + ) + + if nemo_checkpoint is None and not os.path.isdir(triton_model_repository): + raise ValueError( + "The provided model repository is not a valid TensorRT-LLM model " + "directory. Please provide a --nemo_checkpoint or a valid TensorRT-LLM engine." + ) + + if nemo_checkpoint is not None and model_type is None: + raise ValueError("Model type is required to be defined if a nemo checkpoint is provided.") + + trt_llm_exporter = TensorRTLLM( + model_dir=trt_llm_path, + load_model=(nemo_checkpoint is None), + ) + + if nemo_checkpoint is not None: + try: + logging.info("Export operation will be started to export the nemo checkpoint to TensorRT-LLM.") + trt_llm_exporter.export( + nemo_checkpoint_path=nemo_checkpoint, + model_type=model_type, + n_gpus=num_gpus, + tensor_parallelism_size=tensor_parallelism_size, + pipeline_parallelism_size=pipeline_parallelism_size, + max_input_len=max_input_len, + max_output_len=max_output_len, + max_batch_size=max_batch_size, + dtype=dtype, + gather_generation_logits=output_generation_logits, + ) + except Exception as error: + raise RuntimeError("An error has occurred during the model export. Error message: " + str(error)) + + return trt_llm_exporter diff --git a/nemo/collections/llm/evaluation/__init__.py b/nemo/collections/llm/evaluation/__init__.py new file mode 100644 index 000000000000..3012689bb8da --- /dev/null +++ b/nemo/collections/llm/evaluation/__init__.py @@ -0,0 +1,3 @@ +from nemo.collections.llm.evaluation.base import NeMoFWLMEval, wait_for_rest_service + +__all__ = ["NeMoFWLMEval", "wait_for_rest_service"] diff --git a/nemo/collections/llm/evaluation/base.py b/nemo/collections/llm/evaluation/base.py new file mode 100644 index 000000000000..b1734d6f4d43 --- /dev/null +++ b/nemo/collections/llm/evaluation/base.py @@ -0,0 +1,210 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +import requests +import torch +import torch.nn.functional as F +from lm_eval.api.instance import Instance +from lm_eval.api.model import LM +from requests.exceptions import RequestException + +from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer +from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer +from nemo.utils import logging + + +class NeMoFWLMEval(LM): + """ + NeMoFWLMEval is a wrapper class subclassing lm_eval.api.model.LM class, that defines how lm_eval interfaces with + NeMo model deployed on PyTriton server. + Created based on: https://github.com/EleutherAI/lm-evaluation-harness/blob/v0.4.4/docs/model_guide.md + """ + + def __init__(self, model_name, api_url, tokenizer, max_tokens_to_generate, temperature, top_p, top_k, add_bos): + self.model_name = model_name + self.api_url = api_url + self.tokenizer = tokenizer + self.max_tokens_to_generate = max_tokens_to_generate + self.temperature = temperature + self.top_p = top_p + self.top_k = top_k + self.add_bos = add_bos + super().__init__() + + def _generate_tokens_logits(self, payload, return_text: bool = False, return_logits: bool = False): + """ + A private method that sends post request to the model on PyTriton server and returns either generated text or + logits. + """ + # send a post request to /v1/completions/ endpoint with the payload + response = requests.post(f"{self.api_url}/v1/completions/", json=payload) + response_data = response.json() + + if 'error' in response_data: + raise Exception(f"API Error: {response_data['error']}") + + # Assuming the response is in OpenAI format + if return_text: + # in case of generate_until tasks return just the text + return response_data['choices'][0]['text'] + + if return_logits: + # in case of loglikelihood tasks return the logits + return response_data['choices'][0]['generation_logits'] + + def tokenizer_type(self, tokenizer): + """ + Returns the type of the tokenizer. + """ + if isinstance(tokenizer, AutoTokenizer): + return "AutoTokenizer" + elif isinstance(tokenizer, SentencePieceTokenizer): + return "SentencePieceTokenizer" + else: + raise ValueError( + "Tokenizer type is not one of SentencePieceTokenizer or HF's AutoTokenizer. Please check " + "how to handle special tokens for this tokenizer" + ) + + def loglikelihood(self, requests: list[Instance]): + """ + Defines the loglikelihood request. Takes input requests of type list[Instance] where Instance is a dataclass + defined in lm_eval.api.instance. Each Instance conists of the input prompt, output prompt, request type(here + loglikelihood) and other relevant args like few shot samples. + """ + special_tokens_kwargs = {} + tokenizer_type = self.tokenizer_type(self.tokenizer) + if tokenizer_type == "SentencePieceTokenizer": + special_tokens_kwargs['add_bos'] = self.add_bos + elif tokenizer_type == "AutoTokenizer": + special_tokens_kwargs['add_special_tokens'] = self.add_bos + + results = [] + for request in requests: + # get the input prompt from the request + context = request.arguments[0] + # get the output prompt from the request + continuation = request.arguments[1] + # get encoded tokens of continuation + continuation_enc = self.tokenizer.tokenizer.encode(continuation, **special_tokens_kwargs) + # for SentencePeice consider the encoded tokens from the 2nd token since first encoded token is space. + if self.tokenizer_type(self.tokenizer) == "SentencePieceTokenizer": + continuation_enc = continuation_enc[1:] + num_cont_tokens = len(continuation_enc) + # Update self.max_tokens_to_generate with number of continuation tokens (or output tokens) in the request + self.max_tokens_to_generate = num_cont_tokens + # Create payload to query the model deployed on PyTriton server + payload = { + "model": self.model_name, + "prompt": context, + "max_tokens": self.max_tokens_to_generate, + "temperature": self.temperature, + "top_p": self.top_p, + "top_k": self.top_k, + } + # Get the logits from the model + generation_logits = self._generate_tokens_logits(payload, return_logits=True) + # Convert generation_logits to torch tensor to easily get logprobs wo manual implementation of log_softmax + multi_logits = F.log_softmax(torch.tensor(generation_logits[0]), dim=-1) + # Convert encoded continuation tokens to torch tensor + cont_toks = torch.tensor(continuation_enc, dtype=torch.long).unsqueeze(0) + # Get the greedy token from the logits (i.e token with the highest prob) + greedy_tokens = multi_logits.argmax(dim=-1) + # Check if all greedy_tokens match the the actual continuation tokens + is_greedy = (greedy_tokens == cont_toks).all() + # Get the logits corresponding to the actual continuation tokens + logits = torch.gather(multi_logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1) + # result is tuple of logProb of generating the continuation token and is_greedy + result = (float(logits.sum()), bool(is_greedy)) + + results.append(result) + + return results + + def loglikelihood_rolling(self, requests: list[Instance]): + """ + Defines the loglikelihood_rolling request type. Yet to be implemented. + """ + pass + + def generate_until(self, inputs: list[Instance]): + """ + Defines the generate_until request type. Takes input requests of type list[Instance] where Instance is a + dataclass defined in lm_eval.api.instance. Each Instance conists of the input prompt, output prompt, request + type(here loglikelihood) and other relevant args like few shot samples. + """ + results = [] + for instance in inputs: + # Access the 'arguments' attribute of the Instance which contains the input prompt string + prompt = instance.arguments[0] + # Create payload to query the model deployed on PyTriton server + payload = { + "model": self.model_name, + "prompt": prompt, + "max_tokens": self.max_tokens_to_generate, + "temperature": self.temperature, + "top_p": self.top_p, + "top_k": self.top_k, + } + # Get the text generated by the model + generated_text = self._generate_tokens_logits(payload, return_text=True) + + results.append(generated_text) + + return results + + +def wait_for_rest_service(rest_url, max_retries=60, retry_interval=2): + """ + Wait for REST service to be ready. + + Args: + rest_url (str): URL of the REST service's health endpoint + max_retries (int): Maximum number of retry attempts. Defaul: 60. + retry_interval (int): Time to wait between retries in seconds. Default: 2. + + Returns: + bool: True if rest service is ready, False otherwise + """ + + def check_service(url): + """ + Check if the service is ready by making a GET request to its health endpoint. + + Args: + url (str): URL of the service's health endpoint + + Returns: + bool: True if the service is ready, False otherwise + """ + try: + response = requests.get(url, timeout=5) + return response.status_code == 200 + except RequestException: + return False + + for _ in range(max_retries): + rest_ready = check_service(rest_url) + + if rest_ready: + logging.info("REST service is ready.") + return True + + logging.info(f"REST Service not ready yet. Retrying in {retry_interval} seconds...") + time.sleep(retry_interval) + + logging.info("Timeout: REST service did not become ready.") + return False diff --git a/nemo/collections/llm/fn/activation.py b/nemo/collections/llm/fn/activation.py index 5970846d32b2..db82f95b4bcc 100644 --- a/nemo/collections/llm/fn/activation.py +++ b/nemo/collections/llm/fn/activation.py @@ -13,6 +13,7 @@ # limitations under the License. import torch +from megatron.core.jit import jit_fuser @torch.jit.script @@ -25,6 +26,11 @@ def openai_gelu(x): return gelu_impl(x) +@jit_fuser +def quick_gelu(x: torch.Tensor) -> torch.Tensor: + return x * torch.sigmoid(1.702 * x) + + # @torch.jit.script # remove until we have serialization def squared_relu(x): """Squared ReLU activation function.""" diff --git a/nemo/collections/llm/gpt/data/__init__.py b/nemo/collections/llm/gpt/data/__init__.py index b42c350bcaba..c8690fd0668f 100644 --- a/nemo/collections/llm/gpt/data/__init__.py +++ b/nemo/collections/llm/gpt/data/__init__.py @@ -15,7 +15,7 @@ from nemo.collections.llm.gpt.data.alpaca import AlpacaDataModule from nemo.collections.llm.gpt.data.dolly import DollyDataModule from nemo.collections.llm.gpt.data.fine_tuning import FineTuningDataModule -from nemo.collections.llm.gpt.data.hf_dataset import HfDatasetDataModule +from nemo.collections.llm.gpt.data.hf_dataset import HFDatasetDataModule from nemo.collections.llm.gpt.data.mock import MockDataModule from nemo.collections.llm.gpt.data.pre_training import PreTrainingDataModule, build_pretraining_datamodule from nemo.collections.llm.gpt.data.squad import SquadDataModule @@ -28,5 +28,5 @@ "MockDataModule", "PreTrainingDataModule", "build_pretraining_datamodule", - "HfDatasetDataModule", + "HFDatasetDataModule", ] diff --git a/nemo/collections/llm/gpt/data/api.py b/nemo/collections/llm/gpt/data/api.py index 74ecb5272ac2..374bee83b8b2 100644 --- a/nemo/collections/llm/gpt/data/api.py +++ b/nemo/collections/llm/gpt/data/api.py @@ -12,11 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl from nemo.collections.llm.gpt.data.dolly import DollyDataModule -from nemo.collections.llm.gpt.data.hf_dataset import HfDatasetDataModule +from nemo.collections.llm.gpt.data.hf_dataset import HFDatasetDataModule from nemo.collections.llm.gpt.data.mock import MockDataModule from nemo.collections.llm.gpt.data.squad import SquadDataModule @@ -42,7 +42,7 @@ def dolly() -> pl.LightningDataModule: @run.cli.factory @run.autoconvert def hf_dataset(dataset: str) -> pl.LightningDataModule: - return HfDatasetDataModule(dataset=dataset, global_batch_size=16, micro_batch_size=2) + return HFDatasetDataModule(dataset=dataset, global_batch_size=16, micro_batch_size=2) __all__ = ["mock", "squad", "dolly", "hf_dataset"] diff --git a/nemo/collections/llm/gpt/data/fine_tuning.py b/nemo/collections/llm/gpt/data/fine_tuning.py index 9d16ea8aa021..0d866bb600fe 100644 --- a/nemo/collections/llm/gpt/data/fine_tuning.py +++ b/nemo/collections/llm/gpt/data/fine_tuning.py @@ -17,7 +17,7 @@ from pathlib import Path from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union -import pytorch_lightning as pl +import lightning.pytorch as pl from torch.utils.data import DataLoader from nemo.collections.common.tokenizers import AutoTokenizer @@ -117,17 +117,28 @@ def prepare_data(self) -> None: """ Prepare packed sequence data """ - if self.packed_sequence_size > 0 and not self.train_path_packed.is_file(): + if self.packed_sequence_size > 0: from nemo.collections.llm.gpt.data.packed_sequence import prepare_packed_sequence_data - prepare_packed_sequence_data( - input_path=self.train_path, - output_path=self.train_path_packed, - packed_sequence_size=self.packed_sequence_size, - tokenizer=self.tokenizer, - max_seq_length=self.seq_length, - seed=self.seed, - ) + if not self.train_path_packed.is_file(): + prepare_packed_sequence_data( + input_path=self.train_path, + output_path=self.train_path_packed, + packed_sequence_size=self.packed_sequence_size, + tokenizer=self.tokenizer, + max_seq_length=self.seq_length, + seed=self.seed, + ) + + if not self.validation_path_packed.is_file(): + prepare_packed_sequence_data( + input_path=self.validation_path, + output_path=self.validation_path_packed, + packed_sequence_size=self.packed_sequence_size, + tokenizer=self.tokenizer, + max_seq_length=self.seq_length, + seed=self.seed, + ) def setup(self, stage: str): """Called by pytorch lightning in datamodule setup""" @@ -195,7 +206,7 @@ def val_dataloader(self) -> DataLoader: # pylint: disable=C0115,C0116 return self._create_dataloader( self._create_dataset( - self.validation_path, + self.validation_path if self.packed_sequence_size <= 0 else self.validation_path_packed, is_test=True, **self.dataset_kwargs, ), @@ -249,8 +260,8 @@ def train_path_packed(self) -> Path: """Path to training dataset file for packed sequence. The file path contains a reference to the tokenizer/model name since packed sequence dataset consists of tokenized indices.""" if self.packed_sequence_size > 0: - if self.packed_sequence_specs.packed_data_path is not None: - return self.packed_sequence_specs.packed_data_path + if self.packed_sequence_specs.packed_train_data_path is not None: + return self.packed_sequence_specs.packed_train_data_path tokenizer_model_name = self._extract_tokenizer_model_name() folder_name = self.dataset_root / "packed" / tokenizer_model_name folder_name.mkdir(parents=True, exist_ok=True) @@ -258,6 +269,20 @@ def train_path_packed(self) -> Path: else: raise ValueError("`train_path_packed` invalid since packed sequence size is not specified.") + @property + def validation_path_packed(self) -> Path: + """Path to validation dataset file for packed sequence. The file path contains a reference to the + tokenizer/model name since packed sequence dataset consists of tokenized indices.""" + if self.packed_sequence_size > 0: + if self.packed_sequence_specs.packed_val_data_path is not None: + return self.packed_sequence_specs.packed_val_data_path + tokenizer_model_name = self._extract_tokenizer_model_name() + folder_name = self.dataset_root / "packed" / tokenizer_model_name + folder_name.mkdir(parents=True, exist_ok=True) + return folder_name / f"validation_{self.packed_sequence_size}.npy" + else: + raise ValueError("`validation_path_packed` invalid since packed sequence size is not specified.") + @property def validation_path(self) -> Path: """Path to validation dataset file""" diff --git a/nemo/collections/llm/gpt/data/hf_dataset.py b/nemo/collections/llm/gpt/data/hf_dataset.py index 5c6b71c74797..0f45ecf265b7 100644 --- a/nemo/collections/llm/gpt/data/hf_dataset.py +++ b/nemo/collections/llm/gpt/data/hf_dataset.py @@ -12,13 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl import torch from torch.utils.data import DataLoader from nemo.lightning.pytorch.plugins import MegatronDataSampler -class HfDatasetDataModule(pl.LightningDataModule): +class HFDatasetDataModule(pl.LightningDataModule): def __init__( self, dataset, @@ -88,7 +88,7 @@ def train_dataloader(self, collate_fn=None): from nemo.lightning.data import add_megatron_sampler if collate_fn is None: - collate_fn = lambda x: HfDatasetDataModule.collate_fn(x, pad_token_id=self.pad_token_id) + collate_fn = lambda x: HFDatasetDataModule.collate_fn(x, pad_token_id=self.pad_token_id) return DataLoader( self.dataset, diff --git a/nemo/collections/llm/gpt/data/mock.py b/nemo/collections/llm/gpt/data/mock.py index 5678597eda0b..f6b4e26ca355 100644 --- a/nemo/collections/llm/gpt/data/mock.py +++ b/nemo/collections/llm/gpt/data/mock.py @@ -14,10 +14,10 @@ from typing import TYPE_CHECKING, Dict, List, Optional +import lightning.pytorch as pl import numpy as np -import pytorch_lightning as pl import torch -from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS +from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS from torch.utils import data from torch.utils.data import DataLoader, Dataset diff --git a/nemo/collections/llm/gpt/data/packed_sequence.py b/nemo/collections/llm/gpt/data/packed_sequence.py index 153e79f94391..345489ea0b63 100644 --- a/nemo/collections/llm/gpt/data/packed_sequence.py +++ b/nemo/collections/llm/gpt/data/packed_sequence.py @@ -101,15 +101,31 @@ class PackedSequenceSpecs: This field is set by llm.finetune api. """ - packed_data_path: str = None + packed_train_data_path: str = None """ - If specified, use the packed dataset from this file instead of the default path. + If specified, use this file for the packed training dataset instead of the default path. + """ + + packed_val_data_path: str = None + """ + If specified, use this file for the packed validation dataset instead of the default path. """ def __post_init__(self): - if self.packed_data_path is not None: - self.packed_data_path = Path(self.packed_data_path) + if self.packed_train_data_path is not None: + self.packed_train_data_path = Path(self.packed_train_data_path) + assert ( + self.packed_train_data_path.suffix == ".npy" + ), f"packed training data file must be a .npy file: {self.packed_train_data_path}" + assert ( + self.packed_train_data_path.exists() + ), f"packed training data file does not exist: {self.packed_train_data_path}" + + if self.packed_val_data_path is not None: + self.packed_val_data_path = Path(self.packed_val_data_path) + assert ( + self.packed_val_data_path.suffix == ".npy" + ), f"packed validation data file must be a .npy file: {self.packed_val_data_path}" assert ( - self.packed_data_path.suffix == ".npy" - ), f"packed data file must be a .npy file: {self.packed_data_path}" - assert self.packed_data_path.exists(), f"packed data file does not exist: {self.packed_data_path}" + self.packed_val_data_path.exists() + ), f"packed validation data file does not exist: {self.packed_val_data_path}" diff --git a/nemo/collections/llm/gpt/data/pre_training.py b/nemo/collections/llm/gpt/data/pre_training.py index cfacde118b89..f659ce72796c 100644 --- a/nemo/collections/llm/gpt/data/pre_training.py +++ b/nemo/collections/llm/gpt/data/pre_training.py @@ -18,8 +18,8 @@ from pathlib import Path from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union -import pytorch_lightning as pl -from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS +import lightning.pytorch as pl +from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS from torch.utils import data from nemo.lightning.data import WrappedDataLoader diff --git a/nemo/collections/llm/gpt/model/__init__.py b/nemo/collections/llm/gpt/model/__init__.py index b42ceac564bc..4e9448eaef2c 100644 --- a/nemo/collections/llm/gpt/model/__init__.py +++ b/nemo/collections/llm/gpt/model/__init__.py @@ -45,7 +45,7 @@ Gemma2Config27B, Gemma2Model, ) -from nemo.collections.llm.gpt.model.hf_auto_model_for_causal_lm import HfAutoModelForCausalLM +from nemo.collections.llm.gpt.model.hf_auto_model_for_causal_lm import HFAutoModelForCausalLM from nemo.collections.llm.gpt.model.llama import ( CodeLlamaConfig7B, CodeLlamaConfig13B, @@ -59,6 +59,8 @@ Llama31Config8B, Llama31Config70B, Llama31Config405B, + Llama32Config1B, + Llama32Config3B, LlamaConfig, LlamaModel, ) @@ -79,6 +81,7 @@ NemotronConfig, NemotronModel, ) +from nemo.collections.llm.gpt.model.phi3mini import Phi3Config, Phi3ConfigMini, Phi3Model from nemo.collections.llm.gpt.model.qwen2 import ( Qwen2Config, Qwen2Config1P5B, @@ -133,6 +136,8 @@ "Llama31Config8B", "Llama31Config70B", "Llama31Config405B", + "Llama32Config1B", + "Llama32Config3B", "NemotronConfig", "Nemotron3Config4B", "Nemotron3Config8B", @@ -140,6 +145,9 @@ "Nemotron3Config22B", "Nemotron4Config340B", "NemotronModel", + "Phi3Config", + "Phi3ConfigMini", + "Phi3Model", "CodeLlamaConfig7B", "CodeLlamaConfig13B", "CodeLlamaConfig34B", @@ -183,5 +191,5 @@ "transformer_engine_layer_spec", "transformer_engine_full_layer_spec", "local_layer_spec", - "HfAutoModelForCausalLM", + "HFAutoModelForCausalLM", ] diff --git a/nemo/collections/llm/gpt/model/base.py b/nemo/collections/llm/gpt/model/base.py index 6b158a33b226..e411077aca31 100644 --- a/nemo/collections/llm/gpt/model/base.py +++ b/nemo/collections/llm/gpt/model/base.py @@ -15,7 +15,7 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Callable, Dict, Literal, Optional, Union -import pytorch_lightning as L +import lightning.pytorch as L import torch import torch.distributed from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import GPTInferenceWrapper @@ -179,7 +179,7 @@ class GPTConfig(TransformerConfig, io.IOMixin): forward_step_fn: Callable = gpt_forward_step data_step_fn: Callable = gpt_data_step - def configure_model(self, tokenizer) -> "MCoreGPTModel": + def configure_model(self, tokenizer, pre_process=None, post_process=None) -> "MCoreGPTModel": vp_size = self.virtual_pipeline_model_parallel_size if vp_size: p_size = self.pipeline_model_parallel_size @@ -214,8 +214,8 @@ def configure_model(self, tokenizer) -> "MCoreGPTModel": rotary_percent=self.rotary_percent, rotary_base=self.rotary_base, seq_len_interpolation_factor=self.seq_len_interpolation_factor, - pre_process=parallel_state.is_pipeline_first_stage(), - post_process=parallel_state.is_pipeline_last_stage(), + pre_process=pre_process or parallel_state.is_pipeline_first_stage(), + post_process=post_process or parallel_state.is_pipeline_last_stage(), ) # If using full TE layer, need to set TP, CP group since the module call diff --git a/nemo/collections/llm/gpt/model/hf_auto_model_for_causal_lm.py b/nemo/collections/llm/gpt/model/hf_auto_model_for_causal_lm.py index eada3f4c3eb8..481dd9a0e187 100644 --- a/nemo/collections/llm/gpt/model/hf_auto_model_for_causal_lm.py +++ b/nemo/collections/llm/gpt/model/hf_auto_model_for_causal_lm.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl import torch import torch.nn.functional as F from transformers import AutoModelForCausalLM @@ -31,7 +31,7 @@ def masked_cross_entropy(logits, targets, mask=None): return F.cross_entropy(logits, targets) -class HfAutoModelForCausalLM(pl.LightningModule, io.IOMixin, fn.FNMixin): +class HFAutoModelForCausalLM(pl.LightningModule, io.IOMixin, fn.FNMixin): def __init__( self, model_name='gpt2', @@ -39,6 +39,9 @@ def __init__( tokenizer=None, loss_fn=masked_cross_entropy, model_transform=None, + model_accelerator=None, + trust_remote_code=False, + default_dtype=torch.bfloat16, ): super().__init__() self.save_hyperparameters() @@ -49,11 +52,14 @@ def __init__( self.load_pretrained_weights = load_pretrained_weights self.is_hf_model = True self.model_transform = model_transform + self.model_accelerator = model_accelerator + self.trust_remote_code = trust_remote_code + self.default_dtype = default_dtype @property def tokenizer(self): if self._tokenizer is None: - self._tokenizer = HfAutoModelForCausalLM.configure_tokenizer(self.model_name) + self._tokenizer = HFAutoModelForCausalLM.configure_tokenizer(self.model_name, self.trust_remote_code) return self._tokenizer @tokenizer.setter @@ -62,18 +68,27 @@ def tokenizer(self, value): self._tokenizer = value @staticmethod - def configure_tokenizer(model_name): - return AutoTokenizer(model_name) + def configure_tokenizer(model_name, trust_remote_code=False): + return AutoTokenizer(model_name, trust_remote_code=trust_remote_code) def configure_model(self): # create all your layers here if self.load_pretrained_weights: - self.model = AutoModelForCausalLM.from_pretrained(self.model_name, torch_dtype='auto') + self.model = AutoModelForCausalLM.from_pretrained( + self.model_name, torch_dtype='auto', trust_remote_code=self.trust_remote_code + ) else: from transformers import AutoConfig - config = AutoConfig.from_pretrained(self.model_name) - self.model = AutoModelForCausalLM.from_config(config) + config = AutoConfig.from_pretrained(self.model_name, trust_remote_code=self.trust_remote_code) + dtype = getattr(config, 'torch_dtype', self.default_dtype) + self.model = AutoModelForCausalLM.from_config( + config, torch_dtype=dtype, trust_remote_code=self.trust_remote_code + ) + + if self.model_accelerator is not None: + self.model_accelerator(self.model) + self.model.train() def forward(self, input_ids, attention_mask=None, labels=None, loss_mask=None): diff --git a/nemo/collections/llm/gpt/model/llama.py b/nemo/collections/llm/gpt/model/llama.py index a9d18220bcaf..a7e995addb83 100644 --- a/nemo/collections/llm/gpt/model/llama.py +++ b/nemo/collections/llm/gpt/model/llama.py @@ -14,6 +14,7 @@ import math from dataclasses import dataclass +from functools import partial from pathlib import Path from typing import TYPE_CHECKING, Annotated, Callable, Optional @@ -86,7 +87,7 @@ class Llama2Config70B(LlamaConfig): @dataclass -class Llama3Config(GPTConfig): +class Llama3Config(LlamaConfig): num_query_groups: int = 8 hidden_dropout: float = 0.0 attention_dropout: float = 0.0 @@ -115,8 +116,8 @@ class Llama31Config(Llama3Config): old_context_len: int = 8192 init_method_std: float = 0.02 - def configure_model(self, tokenizer) -> "MCoreGPTModel": - model = super().configure_model(tokenizer) + def configure_model(self, tokenizer, pre_process=None, post_process=None) -> "MCoreGPTModel": + model = super().configure_model(tokenizer, pre_process, post_process) # Apply rope scaling for Llama3.1 model model.rotary_pos_emb.inv_freq = apply_rope_scaling( model.rotary_pos_emb.inv_freq, @@ -182,6 +183,32 @@ class Llama31Config405B(Llama31Config): make_vocab_size_divisible_by: int = 128 +@dataclass +class Llama32Config1B(Llama31Config): + scale_factor: int = 32 + share_embeddings_and_output_weights: bool = True + rotary_base: int = 500_000 + num_layers: int = 16 + hidden_size: int = 2048 + ffn_hidden_size: int = 8192 + num_attention_heads: int = 32 + num_query_groups: int = 8 + make_vocab_size_divisible_by: int = 128 + + +@dataclass +class Llama32Config3B(Llama31Config): + scale_factor: int = 32 + share_embeddings_and_output_weights: bool = True + rotary_base: int = 500_000 + num_layers: int = 28 + hidden_size: int = 3072 + ffn_hidden_size: int = 8192 + num_attention_heads: int = 24 + num_query_groups: int = 8 + make_vocab_size_divisible_by: int = 128 + + @dataclass class CodeLlamaConfig7B(Llama2Config7B): rotary_base: int = 1_000_000 @@ -252,6 +279,9 @@ def convert_state(self, source, target): "model.norm.weight": "decoder.final_layernorm.weight", "lm_head.weight": "output_layer.weight", } + if getattr(source.config, "tie_word_embeddings", False): + # llama 3.2 1B and 3B models have no shared input output embeddings + del mapping["lm_head.weight"] return io.apply_transforms(source, target, mapping=mapping, transforms=[_import_qkv, _import_linear_fc1]) @@ -275,7 +305,7 @@ def make_vocab_size_divisible_by(vocab_size): if getattr(source, 'rope_scaling', None) is not None and source.rope_scaling.get('rope_type') == 'llama3': # Apply Llama3.1 customize rope scaling - cls = Llama31Config + cls = partial(Llama31Config, scale_factor=source.rope_scaling.get("factor", 8.0)) else: cls = LlamaConfig output = cls( @@ -289,7 +319,7 @@ def make_vocab_size_divisible_by(vocab_size): rotary_base=source.rope_theta, gated_linear_unit=True, make_vocab_size_divisible_by=make_vocab_size_divisible_by(source.vocab_size), - share_embeddings_and_output_weights=False, + share_embeddings_and_output_weights=getattr(source, "tie_word_embeddings", False), fp16=(dtype_from_hf(source) == torch.float16), bf16=(dtype_from_hf(source) == torch.bfloat16), params_dtype=dtype_from_hf(source), @@ -355,6 +385,7 @@ def config(self) -> "HFLlamaConfig": num_key_value_heads=source.num_query_groups, rope_theta=source.rotary_base, vocab_size=self.tokenizer.vocab_size, + tie_word_embeddings=source.share_embeddings_and_output_weights, ) @@ -509,6 +540,8 @@ def apply_rope_scaling( "Llama31Config8B", "Llama31Config70B", "Llama31Config405B", + "Llama32Config1B", + "Llama32Config3B", "CodeLlamaConfig7B", "CodeLlamaConfig13B", "CodeLlamaConfig34B", diff --git a/nemo/collections/llm/gpt/model/mistral.py b/nemo/collections/llm/gpt/model/mistral.py index a71042e2ba6f..0aa611b4454e 100644 --- a/nemo/collections/llm/gpt/model/mistral.py +++ b/nemo/collections/llm/gpt/model/mistral.py @@ -16,7 +16,6 @@ from pathlib import Path from typing import TYPE_CHECKING, Callable, List, Optional -import pytorch_lightning as pl import torch import torch.nn.functional as F from torch import nn diff --git a/nemo/collections/llm/gpt/model/phi3mini.py b/nemo/collections/llm/gpt/model/phi3mini.py new file mode 100644 index 000000000000..eb0b9c758dd7 --- /dev/null +++ b/nemo/collections/llm/gpt/model/phi3mini.py @@ -0,0 +1,258 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from pathlib import Path +from typing import Callable, Optional + +import torch +import torch.nn.functional as F +from torch import nn + +from nemo.collections.llm.gpt.model.base import GPTConfig, GPTModel +from nemo.lightning import OptimizerModule, io, teardown +from nemo.lightning.pytorch.utils import dtype_from_hf + + +@dataclass +class Phi3Config(GPTConfig): + # pylint: disable=C0115,C0116 + normalization: str = "RMSNorm" + activation_func: Callable = F.silu + gated_linear_unit: bool = True + position_embedding_type: str = "rope" + add_bias_linear: bool = False + seq_length: int = 4096 + attention_dropout: float = 0.0 + hidden_dropout: float = 0.0 + share_embeddings_and_output_weights: bool = False + + +@dataclass +class Phi3ConfigMini(Phi3Config): + # pylint: disable=C0115,C0116 + num_layers: int = 32 + hidden_size: int = 3072 + ffn_hidden_size: int = 8192 + num_attention_heads: int = 32 + num_query_groups: int = 32 + rotary_base: float = 10000.0 + vocab_size: int = 32064 + + +class Phi3Model(GPTModel): + # pylint: disable=C0115,C0116 + def __init__( + self, + config: Optional[Phi3Config] = None, + optim: Optional[OptimizerModule] = None, + tokenizer: Optional["TokenizerSpec"] = None, + model_transform: Optional[Callable[[nn.Module], nn.Module]] = None, + ): + super().__init__(config or Phi3Config(), optim=optim, tokenizer=tokenizer, model_transform=model_transform) + + +@io.model_importer(Phi3Model, "hf") +class HFPhi3Importer(io.ModelConnector["Phi3ForCausalLM", Phi3Model]): + # pylint: disable=C0115,C0116 + def init(self) -> Phi3Model: + return Phi3Model(self.config, tokenizer=self.tokenizer) + + def apply(self, output_path: Path) -> Path: + from transformers import Phi3ForCausalLM + + # Check if the source is valid model identifier or path + try: + source = Phi3ForCausalLM.from_pretrained(str(self), torch_dtype='auto') + except Exception as e: + raise ValueError(f"Failed to load the model from source '{self}': {e}") + + target = self.init() + trainer = self.nemo_setup(target) + self.convert_state(source, target) + self.nemo_save(output_path, trainer) + + print(f"Converted Phi3 model to Nemo, model saved to {output_path} in {source.dtype}.") + + teardown(trainer, target) + del trainer, target + + return output_path + + def convert_state(self, source, target): + # pylint: disable=C0115,C0116 + # Define mapping for mini-4k-instruct + mapping = { + "model.embed_tokens.weight": "embedding.word_embeddings.weight", + "model.layers.*.self_attn.o_proj.weight": "decoder.layers.*.self_attention.linear_proj.weight", + "model.layers.*.self_attn.qkv_proj.weight": "decoder.layers.*.self_attention.linear_qkv.weight", + "model.layers.*.mlp.gate_up_proj.weight": "decoder.layers.*.mlp.linear_fc1.weight", + "model.layers.*.mlp.down_proj.weight": "decoder.layers.*.mlp.linear_fc2.weight", + "model.layers.*.input_layernorm.weight": "decoder.layers.*.self_attention.linear_qkv.layer_norm_weight", + "model.layers.*.post_attention_layernorm.weight": "decoder.layers.*.mlp.linear_fc1.layer_norm_weight", + "model.norm.weight": "decoder.final_layernorm.weight", + "lm_head.weight": "output_layer.weight", + } + + return io.apply_transforms(source, target, mapping=mapping, transforms=[_import_qkv, _import_linear_fc1]) + + @property + def tokenizer(self): + # pylint: disable=C0115,C0116 + from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer + + return AutoTokenizer(self.save_hf_tokenizer_assets(str(self))) + + @property + def config(self) -> Phi3Config: + # pylint: disable=C0115,C0116 + from transformers import Phi3Config as HFPhi3Config + + source = HFPhi3Config.from_pretrained(str(self)) + + def make_vocab_size_divisible_by(vocab_size): + base = 128 + while vocab_size % base != 0: + base //= 2 + return base + + output = Phi3Config( + num_layers=source.num_hidden_layers, + hidden_size=source.hidden_size, + ffn_hidden_size=source.intermediate_size, + num_attention_heads=source.num_attention_heads, + init_method_std=source.initializer_range, + layernorm_epsilon=source.rms_norm_eps, + rotary_base=source.rope_theta, + gated_linear_unit=True, + make_vocab_size_divisible_by=make_vocab_size_divisible_by(source.vocab_size), + share_embeddings_and_output_weights=False, + fp16=(dtype_from_hf(source) == torch.float16), + bf16=(dtype_from_hf(source) == torch.bfloat16), + params_dtype=dtype_from_hf(source), + ) + print("output:", output) + return output + + +@io.model_exporter(Phi3Model, "hf") +class HFPhi3Exporter(io.ModelConnector[Phi3Model, "Phi3ForCausalLM"]): + # pylint: disable=C0115,C0116 + def init(self) -> "Phi3ForCausalLM": + from transformers import AutoModelForCausalLM + + return AutoModelForCausalLM.from_config(self.config) + + def apply(self, output_path: Path) -> Path: + target = self.init() + source, _ = self.nemo_load(str(self)) + target = self.convert_state(source, target) + + target.cpu().save_pretrained(output_path) + self.tokenizer.save_pretrained(output_path) + + return output_path + + def convert_state(self, source, target): + # pylint: disable=C0115,C0116 + mapping = { + "embedding.word_embeddings.weight": "model.embed_tokens.weight", + "decoder.layers.*.self_attention.linear_proj.weight": "model.layers.*.self_attn.o_proj.weight", + "decoder.layers.*.mlp.linear_fc2.weight": "model.layers.*.mlp.down_proj.weight", + "decoder.layers.*.self_attention.linear_qkv.layer_norm_weight": "model.layers.*.input_layernorm.weight", + "decoder.layers.*.mlp.linear_fc1.layer_norm_weight": "model.layers.*.post_attention_layernorm.weight", + "decoder.final_layernorm.weight": "model.norm.weight", + "output_layer.weight": "lm_head.weight", + } + + # Convert source weights to target dtype if needed + for name, param in source.state_dict().items(): + if param.dtype != target.state_dict()[name].dtype: + param.data = param.data.to(target.state_dict()[name].dtype) + + return io.apply_transforms(source, target, mapping=mapping) + + @property + def tokenizer(self): + # pylint: disable=C0115,C0116 + return io.load_context(str(self)).model.tokenizer.tokenizer + + @property + def config(self) -> "HFPhi3Config": + # pylint: disable=C0115,C0116 + source: Phi3Config = io.load_context(str(self)).model.config + + from transformers import Phi3Config as HFPhi3Config + + return HFPhi3Config( + num_hidden_layers=source.num_layers, + hidden_size=source.hidden_size, + intermediate_size=source.ffn_hidden_size, + num_attention_heads=source.num_attention_heads, + max_position_embeddings=source.seq_length, + initializer_range=0.02, + rms_norm_eps=1e-05, + num_key_value_heads=source.num_query_groups, + rope_theta=source.rotary_base, + vocab_size=self.tokenizer.vocab_size, + ) + + +@io.state_transform( + source_key="model.layers.*.self_attn.qkv_proj.weight", + target_key="decoder.layers.*.self_attention.linear_qkv.weight", +) +def _import_qkv(ctx: io.TransformCTX, qkv_weight): + megatron_config = ctx.target.config + + head_num = megatron_config.num_attention_heads + num_query_groups = megatron_config.num_query_groups + heads_per_group = head_num // num_query_groups + hidden_size = megatron_config.hidden_size + head_size = megatron_config.kv_channels + + old_tensor_shape = qkv_weight.size() + new_q_tensor_shape = (head_num, head_size, old_tensor_shape[1]) + new_kv_tensor_shape = (num_query_groups, head_size, old_tensor_shape[1]) + q, k, v = qkv_weight.split( + [head_num * head_size, num_query_groups * head_size, num_query_groups * head_size], dim=0 + ) + q = q.view(*new_q_tensor_shape) + k = k.view(*new_kv_tensor_shape) + v = v.view(*new_kv_tensor_shape) + + qkv_weights = torch.empty((0, head_size, old_tensor_shape[1])).type_as(qkv_weight) + for i in range(num_query_groups): + qkv_weights = torch.cat((qkv_weights, q[i * heads_per_group : (i + 1) * heads_per_group, :, :])) + qkv_weights = torch.cat((qkv_weights, k[i : i + 1, :, :])) + qkv_weights = torch.cat((qkv_weights, v[i : i + 1, :, :])) + assert qkv_weights.ndim == 3, qkv_weights.shape + assert qkv_weights.shape[0] == (heads_per_group + 2) * num_query_groups, qkv_weights.shape + assert qkv_weights.shape[1] == head_size, qkv_weights.shape + assert qkv_weights.shape[2] == old_tensor_shape[1], qkv_weights.shape + + qkv_weights = qkv_weights.reshape([head_size * (head_num + 2 * num_query_groups), hidden_size]) + + return qkv_weights + + +@io.state_transform( + source_key=("model.layers.*.mlp.gate_proj.weight", "model.layers.*.mlp.up_proj.weight"), # phi-3-mini-4k-instruct + target_key="decoder.layers.*.mlp.linear_fc1.weight", +) +def _import_linear_fc1(down, gate): + return torch.cat((down, gate), axis=0) + + +__all__ = ["Phi3Config", "Phi3ConfigMini", "Phi3Model"] diff --git a/nemo/collections/llm/gpt/model/ssm.py b/nemo/collections/llm/gpt/model/ssm.py index d38a690cb4ad..f4190114042e 100644 --- a/nemo/collections/llm/gpt/model/ssm.py +++ b/nemo/collections/llm/gpt/model/ssm.py @@ -86,7 +86,7 @@ class SSMConfig(TransformerConfig, io.IOMixin): data_step_fn: Callable = gpt_data_step tokenizer_model_path: str = None - def configure_model(self, tokenizer) -> "MCoreMambaModel": + def configure_model(self, tokenizer, pre_process=None, post_process=None) -> "MCoreMambaModel": return MCoreMambaModel( self, @@ -101,8 +101,8 @@ def configure_model(self, tokenizer) -> "MCoreMambaModel": rotary_percent=self.rotary_percent, rotary_base=self.rotary_base, seq_len_interpolation_factor=self.seq_len_interpolation_factor, - pre_process=parallel_state.is_pipeline_first_stage(), - post_process=parallel_state.is_pipeline_last_stage(), + pre_process=pre_process or parallel_state.is_pipeline_first_stage(), + post_process=post_process or parallel_state.is_pipeline_last_stage(), ) @@ -290,6 +290,7 @@ class BaseMambaConfig2_7B(SSMConfig): @dataclass class NVIDIAMambaConfig8B(SSMConfig): hybrid_override_pattern: str = "M" * 56 + num_attention_heads: int = 32 num_layers: int = 56 seq_length: int = 4096 hidden_size: int = 4096 diff --git a/nemo/collections/llm/inference/base.py b/nemo/collections/llm/inference/base.py index 55d865ec238b..795d6efadd3a 100644 --- a/nemo/collections/llm/inference/base.py +++ b/nemo/collections/llm/inference/base.py @@ -16,9 +16,10 @@ from pathlib import Path from typing import Optional, Union -import pytorch_lightning as pl +import lightning.pytorch as pl import torch import torch.distributed +from lightning.pytorch.trainer.states import TrainerFn from megatron.core.inference.common_inference_params import CommonInferenceParams from megatron.core.inference.engines.mcore_engine import MCoreEngine from megatron.core.inference.model_inference_wrappers.abstract_model_inference_wrapper import ( @@ -31,13 +32,12 @@ SimpleTextGenerationController, ) from megatron.core.transformer.module import MegatronModule -from pytorch_lightning.trainer.states import TrainerFn import nemo.lightning as nl -from nemo.collections.llm.peft import LoRA from nemo.lightning import io from nemo.lightning.ckpt_utils import ADAPTER_META_FILENAME, ckpt_to_context_subdir from nemo.lightning.io.pl import ckpt_to_weights_subdir +from nemo.lightning.pytorch.callbacks import PEFT from nemo.lightning.pytorch.strategies.megatron_strategy import MegatronStrategy from nemo.lightning.pytorch.strategies.utils import RestoreConfig @@ -161,9 +161,9 @@ def _setup_trainer_and_restore_model(path: Path, trainer: nl.Trainer, model: pl. trainer.strategy.trainer = trainer trainer.strategy.selective_restore() - lora: Union[io.TrainerContext, LoRA] = io.load_context(ckpt_to_context_subdir(path), "model.model_transform") - if isinstance(lora, LoRA): - model = lora(model) + peft: Union[io.TrainerContext, PEFT] = io.load_context(ckpt_to_context_subdir(path), "model.model_transform") + if isinstance(peft, PEFT): + model = peft(model) adapter_sharded_state_dict = {k: v for k, v in model.sharded_state_dict().items() if ".adapter." in k} adapter_state = trainer.strategy.checkpoint_io.load_checkpoint( ckpt_to_weights_subdir(path, is_saving=False), sharded_state_dict=adapter_sharded_state_dict @@ -242,7 +242,7 @@ def generate( text_generation_controller=text_generation_controller, max_batch_size=max_batch_size, random_seed=random_seed ) - common_inference_params = inference_params or CommonInferenceParams(num_tokens_to_generate=512) + common_inference_params = inference_params or CommonInferenceParams(num_tokens_to_generate=512, top_k=1) results = mcore_engine.generate( prompts=prompts, diff --git a/nemo/collections/llm/peft/__init__.py b/nemo/collections/llm/peft/__init__.py index 3dae5622b733..1dcc070a5a97 100644 --- a/nemo/collections/llm/peft/__init__.py +++ b/nemo/collections/llm/peft/__init__.py @@ -12,7 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -from nemo.collections.llm.peft.api import gpt_lora +from nemo.collections.llm.peft.api import gpt_lora, merge_lora +from nemo.collections.llm.peft.dora import DoRA from nemo.collections.llm.peft.lora import LoRA -__all__ = ["LoRA", "gpt_lora"] +PEFT_STR2CLS = { + "LoRA": LoRA, + "lora": LoRA, + "DoRA": DoRA, + "dora": DoRA, +} + +__all__ = ["LoRA", "DoRA", "gpt_lora", "PEFT_STR2CLS", "merge_lora"] diff --git a/nemo/collections/llm/peft/api.py b/nemo/collections/llm/peft/api.py index 85c0ae6cae41..a089a6d17515 100644 --- a/nemo/collections/llm/peft/api.py +++ b/nemo/collections/llm/peft/api.py @@ -12,9 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. -from nemo.collections.llm.peft.lora import LoRA +import json +from pathlib import Path +from typing import Tuple, Union + +import pytorch_lightning as pl +from megatron.core import dist_checkpointing +from pytorch_lightning.trainer.states import TrainerFn + +from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer +from nemo.collections.llm.peft.lora import LoRA, LoRAMerge from nemo.collections.llm.utils import factory +from nemo.lightning import MegatronStrategy, Trainer, _strategy_lib, io +from nemo.lightning.ckpt_utils import ADAPTER_META_FILENAME, ckpt_to_context_subdir +from nemo.lightning.io.pl import TrainerContext, ckpt_to_weights_subdir +from nemo.lightning.pytorch.callbacks import PEFT from nemo.lightning.pytorch.callbacks.peft import PEFT +from nemo.lightning.pytorch.strategies.utils import RestoreConfig +from nemo.utils import logging @factory @@ -22,4 +37,108 @@ def gpt_lora() -> PEFT: return LoRA() -__all__ = ["gpt_lora"] +def merge_lora( + lora_checkpoint_path: str, + output_path: str, +) -> None: + """ + Merges the LoRA adapter weights into the base model's weights. + + Python Usage: + ```python + if __name__ == '__main__': + llm.peft.merge_lora( + lora_checkpoint_path=your_lora_checkpoint_path, + output_path=your_output_path, + ) + ``` + + Args: + lora_checkpoint_path: The path to the LoRA checkpoint. + output_path: The path to save the merged checkpoint. + + """ + from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed + + trainer = Trainer( + devices=1, + accelerator="cpu", + strategy=MegatronStrategy(ddp="pytorch", setup_optimizers=False, plugins=bf16_mixed()), + ) + + model, lora = _load_base_model_and_lora(lora_checkpoint_path) + _setup_trainer_and_restore_model_and_adapter(Path(lora_checkpoint_path), trainer, model, lora) + + lora_merge = LoRAMerge() + merged_model = lora_merge(trainer.strategy.megatron_parallel) + merged_weights = {k: v for k, v in merged_model.sharded_state_dict().items() if ".adapter." not in k} + _save_merged_weight(output_path, merged_weights, model, trainer) + + +def _load_base_model_and_lora(lora_checkpoint_path: Path) -> Tuple[pl.LightningModule, LoRA]: + model = io.load_context(ckpt_to_context_subdir(lora_checkpoint_path), "model") + model.model_transform, model.__io__.model_transform = None, None + model.config.bf16 = False + lora: Union[io.TrainerContext, LoRA] = io.load_context( + ckpt_to_context_subdir(lora_checkpoint_path), "model.model_transform" + ) + assert isinstance(lora, LoRA), "LoRA config not found in checkpoint" + return model, lora + + +def _setup_trainer_and_restore_model_and_adapter( + lora_checkpoint_path: Path, trainer: Trainer, model: pl.LightningModule, lora: LoRA +) -> None: + if ( + adapter_meta_path := ckpt_to_weights_subdir(lora_checkpoint_path, is_saving=False) / ADAPTER_META_FILENAME + ).exists(): + with open(adapter_meta_path, "r") as f: + metadata = json.load(f) + restore_config = RestoreConfig( + path=metadata["model_ckpt_path"], + load_model_state=True, + load_optim_state=False, + ) + else: + raise ValueError(f"Cannot find adapter meta file in {lora_checkpoint_path}") + + trainer.strategy.restore_config = restore_config + trainer.strategy._setup_optimizers = False + trainer.ckpt_path = None + trainer.strategy.connect(model) + trainer.strategy.setup_environment() + + if not model.state_dict(): + with _strategy_lib.megatron_cpu_init_context(model.config): + model.configure_model() + + trainer.strategy.setup(trainer) # load base model ckpt + trainer.state.fn = TrainerFn.TESTING + trainer.strategy.setup_megatron_parallel(trainer=trainer) + trainer.strategy.trainer = trainer + model.trainer = trainer + + lora(model) + adapter_sharded_state_dict = { + k: v for k, v in trainer.strategy.megatron_parallel.sharded_state_dict().items() if ".adapter." in k + } + adapter_state = trainer.strategy.checkpoint_io.load_checkpoint( + ckpt_to_weights_subdir(lora_checkpoint_path, is_saving=False), sharded_state_dict=adapter_sharded_state_dict + ) + trainer.strategy.load_model_state_dict(adapter_state, strict=False) + + +def _save_merged_weight(output_path: str, merged_weights: dict, model: pl.LightningModule, trainer: Trainer): + weight_path = ckpt_to_weights_subdir(output_path, is_saving=True) + Path(weight_path).mkdir(parents=True, exist_ok=True) + dist_checkpointing.save(merged_weights, str(ckpt_to_weights_subdir(output_path, is_saving=True))) + if hasattr(model.tokenizer, "save_pretrained"): + model.tokenizer.save_pretrained("/tmp/nemo_tokenizer") + model.tokenizer = AutoTokenizer("/tmp/nemo_tokenizer") + if hasattr(trainer.model, "__io__") and hasattr(trainer.model.tokenizer, '__io__'): + trainer.model.__io__.tokenizer = trainer.model.tokenizer.__io__ + TrainerContext.from_trainer(trainer).io_dump(ckpt_to_context_subdir(output_path), yaml_attrs=["model"]) + logging.info(f"Merged checkpoint saved to {output_path}") + + +__all__ = ["gpt_lora", "merge_lora"] diff --git a/nemo/collections/llm/peft/dora.py b/nemo/collections/llm/peft/dora.py new file mode 100644 index 000000000000..d77d2a4dc0d4 --- /dev/null +++ b/nemo/collections/llm/peft/dora.py @@ -0,0 +1,261 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import re +from dataclasses import dataclass, field +from typing import List, Literal, Optional + +import torch +from megatron.core import parallel_state +from megatron.core.dist_checkpointing.mapping import ShardedStateDict +from megatron.core.tensor_parallel import ( + ColumnParallelLinear, + RowParallelLinear, + gather_from_tensor_model_parallel_region, +) +from megatron.core.utils import make_sharded_tensor_for_checkpoint, make_tp_sharded_tensor_for_checkpoint +from torch import nn + +from nemo.collections.llm.peft.lora import LinearAdapter +from nemo.collections.nlp.modules.common.megatron.adapters.parallel_adapters import ParallelLinearAdapter +from nemo.lightning.pytorch.callbacks.peft import PEFT, AdapterWrapper +from nemo.utils import logging +from nemo.utils.import_utils import safe_import_from + +TEColumnParallelLinear, HAVE_TE_COL_LINEAR = safe_import_from( + "megatron.core.extensions.transformer_engine", "TEColumnParallelLinear" +) +TELayerNormColumnParallelLinear, HAVE_TE_LN_COL_LINEAR = safe_import_from( + "megatron.core.extensions.transformer_engine", + "TELayerNormColumnParallelLinear", +) +TERowParallelLinear, HAVE_TE_ROW_LINEAR = safe_import_from( + "megatron.core.extensions.transformer_engine", "TERowParallelLinear" +) +HAVE_TE = all((HAVE_TE_COL_LINEAR, HAVE_TE_LN_COL_LINEAR, HAVE_TE_ROW_LINEAR)) + + +class ParallelLinearDoRAAdapter(ParallelLinearAdapter): + """ + Adapter class for DoRA to handle the additional weight_magnitude parameter + """ + + def init_weight_magnitude(self, value): + """ + Initialize weight_magnitude with shape (d,), where d is the output dim of the linear layer + """ + self.weight_magnitude = nn.Parameter(value, requires_grad=True) + + def get_weight_magnitude(self): + """ + Public function to get the weight magnitude parameter + """ + return self.weight_magnitude + + def sharded_state_dict( + self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[dict] = None + ) -> ShardedStateDict: + """ + Sharded state dict implementation for DoRA adapter. + Weight magnitude is TP sharded for linear_qkv and linear_fc1 only. + """ + sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata) + + magnitude_key = f"{prefix}weight_magnitude" + if self.input_is_parallel: + # RPL output is gathered, so weight_magnitude is not sharded for TP + magnitude_sharded_tensor = make_sharded_tensor_for_checkpoint( + self.weight_magnitude, magnitude_key, prepend_offsets=sharded_offsets + ) + else: + # CPL output is sharded, so weight_magnitude is sharded for TP + magnitude_sharded_tensor = make_tp_sharded_tensor_for_checkpoint( + self.weight_magnitude, magnitude_key, 0, prepend_offsets=sharded_offsets + ) + sharded_state_dict[magnitude_key] = magnitude_sharded_tensor + + return sharded_state_dict + + +class DoRALinear(AdapterWrapper): + """ + An adapter wrapper that is designed to be used with DoRA + It extends the AdapterWrapper class to provide a specific implementation of the forward method. + """ + + def __init__(self, to_wrap: nn.Module, adapter: ParallelLinearDoRAAdapter): + super().__init__(to_wrap, adapter) + self.adapter: ParallelLinearDoRAAdapter + self.scaling = adapter.alpha / adapter.dim + self.adapter.init_weight_magnitude(self._get_weight_norm()) + + def _get_weight_norm(self): + if self.adapter.input_is_parallel: + linear_out_weight = gather_from_tensor_model_parallel_region(self.adapter.linear_out.weight.T).T + linear_in_weight = self.adapter.linear_in.weight + else: + linear_out_weight = self.adapter.linear_out.weight + linear_in_weight = gather_from_tensor_model_parallel_region(self.adapter.linear_in.weight.T).T + + weight = self.to_wrap.weight + self.scaling * linear_out_weight @ linear_in_weight + return torch.linalg.norm(weight, dim=1).to(weight.dtype).detach() + + def forward(self, x): + """ + Forward method for DoRA + + mag_norm_scale * (linear_output + adapter_output) + = ||W_0 + B_0 A_0|| / ||W_0 + B A|| * (W_0 x + B A x) + = ||W_0 + B_0 A_0|| ((W_0 + B A) / ||W_0 + B A||) x + = m ((W_0 + B A) / ||W_0 + B A||) x + = equation 5 in DoRA paper + + When dropout is used, equation becomes + W_0 x + (m /||W_0 + B A|| - 1) W_0 dropout(x) + m /||W_0 + B A|| B A dropout(x) + = ... + = m /||W_0 + B A|| (W_0 x + B A dropout(x)) + (m /||W_0 + B A|| - 1) W_0 (dropout(x) - x) + + """ + linear_output, bias, layernorm_output = self.base_linear_forward(x) + adapter_output = self.adapter(layernorm_output.contiguous()) + + # mag_norm_scale is ||W_0 + B_0 A_0|| / ||W_0 + B A|| (scaling in front of BA not shown) + mag_norm_scale = (self.adapter.get_weight_magnitude() / self._get_weight_norm()).view(1, 1, -1) + if self.adapter.dropout is None or not self.training: + dropout_correction = 0 + else: + dropout_correction = (mag_norm_scale - 1) * self.base_linear_forward( + self.adapter.dropout(layernorm_output) - layernorm_output + )[0] + + return mag_norm_scale * (linear_output + adapter_output) + dropout_correction, bias + + +@dataclass +class DoRA(PEFT): + """ + Implements the DoRA (Weight-Decomposed LowRank Adaptation) module for parameter-efficient fine-tuning. + + DoRA decomposes pre-trained weight into magnitude and direction, and uses a low-rank projection in the + directional component to adapt the weights of a pre-trained model to a new downstream task. + This class facilitates the application of DoRA to specific modules within the model architecture. + + Args: + See LoRA class for a detailed explanation of the arguments. + + Example: + -------- + >>> from nemo.collections import llm + >>> lora = llm.peft.DoRA(target_modules=['linear_qkv', 'linear_proj'], dim=32, alpha=64) + >>> model = llm.Mistral7BModel(model_transform=lora) + >>> # (set up trainer and data) + >>> trainer.fit(model, data) + + References: + ----------- + Shih-Yang Liu, Chien-Yi Wang, Hongxu Yin, Pavlo Molchanov, Yu-Chiang Frank Wang, Kwang-Ting Cheng, + Min-Hung Chen (2024). DoRA: Weight-Decomposed Low-Rank Adaptation. arXiv preprint arXiv:2402.09353. + https://arxiv.org/abs/2402.09353 + ) + """ + + target_modules: List[str] = field( + default_factory=lambda: ['linear_qkv', 'linear_proj', 'linear_fc1', 'linear_fc2'] + ) + dim: int = 32 + alpha: int = 64 + dropout: float = 0.0 + dropout_position: Literal['pre', 'post'] = 'pre' + lora_A_init_method: str = "xavier" + lora_B_init_method: str = "zero" + + def __post_init__(self): + assert self.dropout_position == "pre", ( + "DoRA only supports pre-adapter dropout at this time." "Please set DoRA(..., dropout_position='pre')" + ) + + def transform(self, m: nn.Module, name=None, prefix=None): + """ + Applies DoRA to a specific module within the model architecture. + + Args: + m (nn.Module): The module to apply DoRA to. + name (str, optional): Name of the module (if applicable). Defaults to None. + prefix (str, optional): Prefix for the module name (if applicable). Defaults to None. + + Returns: + nn.Module: The modified module with DoRA applied, or the original module if not a target. + """ + + def wildcard_match(pattern, key): + if key is None: + return None + regex_pattern = re.compile("^" + pattern.replace("*", "(.*)") + "$") + match = regex_pattern.match(key) + return match is not None + + full_name = f"{prefix}.{name}" if prefix else name + if name in self.target_modules or any(wildcard_match(pattern, full_name) for pattern in self.target_modules): + if HAVE_TE and isinstance(m, TEColumnParallelLinear) or isinstance(m, TELayerNormColumnParallelLinear): + input_is_parallel = False + # m.in_features and m.out_features are divided by tp_size already, + # but in_features and out_features passed to ParallelLinearAdapter are not. + tp_size = parallel_state.get_tensor_model_parallel_world_size() + in_features = m.in_features + out_features = m.out_features * tp_size + # DoRA is applied after layernorm, so layernorm output must be returned + m.return_layernorm_output = True + # perf optimization for DoRA + SP (to check!) + if m.config.sequence_parallel and not m.ub_overlap_ag: + m.return_layernorm_output_gathered = True + elif HAVE_TE and isinstance(m, TERowParallelLinear): + input_is_parallel = True + tp_size = parallel_state.get_tensor_model_parallel_world_size() + in_features = m.in_features * tp_size + out_features = m.out_features + elif isinstance(m, ColumnParallelLinear): + input_is_parallel = False + in_features = m.input_size + out_features = m.output_size + elif isinstance(m, RowParallelLinear): + input_is_parallel = True + in_features = m.input_size + out_features = m.output_size + elif isinstance(m, nn.Linear): + return LinearAdapter( + m, dim=self.dim, alpha=self.alpha, dropout=self.dropout, lora_A_init_method=self.lora_A_init_method + ) + else: + raise NotImplementedError(f"Layer type is unrecognized for LoRA: {type(m)}") + + logging.info(f"Adding DoRA to: {full_name}") + adapter = ParallelLinearDoRAAdapter( + in_features, + out_features, + self.dim, + activation='identity', + norm_position=None, + norm_type=None, + column_init_method=self.lora_A_init_method, + row_init_method=self.lora_B_init_method, + gather_output=False, + input_is_parallel=input_is_parallel, + dropout=self.dropout, + dropout_position=self.dropout_position, + model_parallel_config=getattr(m, "config", None), + alpha=self.alpha, + ) + return DoRALinear(m, adapter) + return m diff --git a/nemo/collections/llm/peft/lora.py b/nemo/collections/llm/peft/lora.py index 77063b9d7e98..205cde071fa7 100644 --- a/nemo/collections/llm/peft/lora.py +++ b/nemo/collections/llm/peft/lora.py @@ -48,25 +48,8 @@ class to provide a specific implementation of the forward method. """ def forward(self, x): - linear_output = self.to_wrap(x) - assert isinstance( - linear_output, tuple - ), f"{self.to_wrap} should return a tuple but instead returns {linear_output}" - """ Four cases for the wrapped module's return values - 1. nothing: (out, None) - 2. return_bias: (out, bias) - 2. return_layernorm_output: ((out, ln_out), None) - 3. both: (out, bias, ln_out) - """ - if len(linear_output) == 2: - linear_output, bias = linear_output - if isinstance(linear_output, tuple) and len(linear_output) == 2: - linear_output, layernorm_output = linear_output - x = layernorm_output - elif len(linear_output) == 3: - linear_output, bias, layernorm_output = linear_output - x = layernorm_output - adapter_output = self.adapter(x.contiguous()) + linear_output, bias, layernorm_output = self.base_linear_forward(x) + adapter_output = self.adapter(layernorm_output.contiguous()) return linear_output + adapter_output, bias @@ -129,8 +112,8 @@ class LoRA(PEFT): target_modules (List[str], optional): A list of module names to apply LoRA to. Defaults to all linear layers ['linear_qkv', 'linear_proj', 'linear_fc1', 'linear_fc2']. - 'linear_qkv': Apply LoRA to the fused linear layer used for query, key, and value projections - in self-attention modules. - - 'linear_proj': Apply LoRA to the linear layer used for projecting the output of self-attention modules. + in self-attention. + - 'linear_proj': Apply LoRA to the linear layer used for projecting the output of self-attention. - 'linear_fc1': Apply LoRA to the first fully-connected layer in MLP. - 'linear_fc2': Apply LoRA to the second fully-connected layer in MLP. Target modules can also contain wildcards. For example, you can specify @@ -141,6 +124,7 @@ class LoRA(PEFT): dropout (float): Dropout rate for the low-rank projection. Defaults to 0.0. dropout_position (Literal['pre', 'post'], optional): Position for applying dropout. Can be 'pre' (before the low-rank projection) or 'post' (after). Defaults to 'post'. + a2a_experimental (bool): Enables the experimental All-to-All (A2A) communication strategy. Defaults to False. Example: -------- @@ -168,6 +152,7 @@ class LoRA(PEFT): dropout_position: Literal['pre', 'post'] = 'post' lora_A_init_method: str = "xavier" lora_B_init_method: str = "zero" + a2a_experimental: bool = False def transform(self, m: nn.Module, name=None, prefix=None): """ @@ -241,6 +226,47 @@ def wildcard_match(pattern, key): model_parallel_config=getattr(m, "config", None), alpha=self.alpha, is_expert=is_expert_linear(full_name), + a2a_experimental=self.a2a_experimental, ) return AdapterParallelAdd(m, adapter) return m + + +class LoRAMerge(PEFT): + """ + Implements the LoRA weight merge for parameter-efficient fine-tuning. + + Example: + -------- + >>> from nemo.collections.llm.peft.lora import LoRAMerge + >>> lora_merge = LoRAMerge() + >>> merged_model = lora_merge(trainer.strategy.megatron_parallel) + """ + + @torch.no_grad() + def transform(self, m: nn.Module, name=None, prefix=None): + """ + Merges the LoRA adapter with the base model weights. + + Args: + m (nn.Module): The module to apply LoRA merge to. + name (str, optional): Name of the module to merge. Defaults to None. + prefix (str, optional): Prefix for the module name. Defaults to None. + + Returns: + nn.Module: The modified module with the LoRA adapter merged into the base model weights. + """ + + if not isinstance(m, AdapterParallelAdd): + return m + logging.info(f'merging {(prefix if prefix else "") + "." + (name if name else "")}') + base_weight = m.to_wrap.weight + lora_weight = ( + m.adapter.alpha + / m.adapter.dim + * m.adapter.linear_out.weight.to(base_weight.device) + @ m.adapter.linear_in.weight.to(base_weight.device) + ) + merged_weight = base_weight + lora_weight + m.to_wrap.weight.data = merged_weight + return m diff --git a/nemo/collections/llm/quantization/quantizer.py b/nemo/collections/llm/quantization/quantizer.py index 2f3e0e1e986e..d41ba39f39ea 100644 --- a/nemo/collections/llm/quantization/quantizer.py +++ b/nemo/collections/llm/quantization/quantizer.py @@ -15,6 +15,7 @@ import os import shutil from dataclasses import dataclass +from pathlib import Path from typing import Optional, Union import torch @@ -23,10 +24,12 @@ from tqdm import tqdm from nemo.collections import llm -from nemo.lightning.ckpt_utils import CONTEXT_PATH +from nemo.collections.llm.inference import MCoreTokenizerWrappper, generate +from nemo.lightning.ckpt_utils import ckpt_to_context_subdir +from nemo.lightning.megatron_parallel import MegatronParallel from nemo.utils import logging -from .utils import get_unwrapped_mcore_model +from .utils import get_modelopt_decoder_type, get_unwrapped_mcore_model try: import modelopt.torch.quantization as mtq @@ -75,51 +78,31 @@ class QuantizationConfig: @dataclass class ExportConfig: - """Inference configuration for the quantized TensorRT-LLM engine""" + """Inference configuration for the quantized TensorRT-LLM checkpoint.""" - path: str + path: Union[Path, str] dtype: Union[str, int] = "bf16" decoder_type: Optional[str] = None inference_tensor_parallel: int = 1 inference_pipeline_parallel: int = 1 + generate_sample: bool = False - -def get_modelopt_decoder_type(config: llm.GPTConfig) -> str: - """Infers the modelopt decoder type from GPTConfig class""" - mapping = [ - (llm.Baichuan2Config, "baichuan"), - (llm.ChatGLMConfig, "chatglm"), - (llm.GemmaConfig, "gemma"), - (llm.LlamaConfig, "llama"), - (llm.MistralConfig7B, "llama"), - (llm.MixtralConfig, "llama"), - (llm.NemotronConfig, "gptnext"), - (llm.Qwen2Config, "qwen"), - # TODO: (llm.StarcoderConfig, ""), - (llm.Starcoder2Config, "gptnext"), - ] - - for config_class, decoder_type in mapping: - if isinstance(config, config_class): - return decoder_type - - logging.warning("Could not directly infer the decoder type") - # TODO: Add a reasonable behavior for GPTConfig (for instance based on position_embedding_type) - return "llama" + def __post_init__(self): + self.path = Path(self.path) class Quantizer: - """Post-training quantization (PTQ) and TRT-LLM export of NeMo 2.0 checkpoints. + """Post-training quantization (PTQ) and TensorRT-LLM export of NeMo 2.0 checkpoints. PTQ converts selected model layers to low-precision format (e.g., INT4, FP8) for efficient serving. The process consist of several steps: 1. Loading a Nemo model from disk using appropriate parallelism strategy 2. Calibrating the model to obtain appropriate algorithm-specific scaling factors - 3. Producing output directory + 3. Producing an output directory with a quantized checkpoint and a tokenizer The output directory produced is intended to be consumed by TensorRT-LLM toolbox - for efficient inference. This can be achieved using NeMo inference containers. + for efficient inference. This can be achieved using nemo.export.tensorrt_llm module. """ def __init__(self, quantization_config: QuantizationConfig, export_config: ExportConfig): @@ -142,16 +125,37 @@ def __init__(self, quantization_config: QuantizationConfig, export_config: Expor assert dtype in SUPPORTED_DTYPE, f"Unsupported export dtype: {dtype}" self.torch_dtype = torch_dtype_from_precision(dtype) - def _setup(self, model: llm.GPTModel) -> None: + @staticmethod + def _setup(model: MegatronParallel) -> None: """Setup model for quantization.""" # TODO: disable activation checkpointing model.config.vocab_size = model.tokenizer.vocab_size model.freeze() - def _get_decoder_type(self, config: llm.GPTConfig): - return self.export_config.decoder_type or get_modelopt_decoder_type(config) + def _get_decoder_type(self, model: MegatronParallel): + if self.export_config.decoder_type is not None: + return self.export_config.decoder_type + unwrapped_model = model + while not isinstance(unwrapped_model, llm.GPTModel): + unwrapped_model = unwrapped_model.module - def quantize(self, model: llm.GPTModel, forward_loop=None): + return get_modelopt_decoder_type(unwrapped_model) + + @staticmethod + def _generate_sample(model: MegatronParallel): + prompts = ["Born in north-east France, Soyer trained as a", "Born in California, Soyer trained as a"] + + mcore_tokenizer = MCoreTokenizerWrappper(model.tokenizer) + mcore_inference = model.get_inference_wrapper( + params_dtype=torch.bfloat16, inference_batch_times_seqlen_threshold=30 + ) + + generated = [r.generated_text for r in generate(mcore_inference, mcore_tokenizer, prompts)] + outputs = [prompt + generation for prompt, generation in zip(prompts, generated)] + + logging.info(f'Sample generation after PTQ (with prompts): {outputs}') + + def quantize(self, model: MegatronParallel, forward_loop=None): """Quantize the model and calibrate using given forward loop.""" if forward_loop is None: get_dataloader = create_data_iterator_getter( @@ -181,7 +185,7 @@ def quantize(self, model: llm.GPTModel, forward_loop=None): self._setup(model) unwrapped_model = get_unwrapped_mcore_model(model) - decoder_type = self._get_decoder_type(unwrapped_model.config) + decoder_type = self._get_decoder_type(model) quant_cfg = QUANT_CFG_CHOICES[algorithm] if "awq" in algorithm: weight_quantizer = quant_cfg["quant_cfg"]["*weight_quantizer"] @@ -226,11 +230,16 @@ def quantize(self, model: llm.GPTModel, forward_loop=None): if dist.get_rank() == 0: mtq.print_quant_summary(unwrapped_model) + if self.export_config.generate_sample: + logging.info("Generating a sample output after model quantization.") + self._generate_sample(model) + return model def create_megatron_forward_loop( self, get_dataloader, num_batches, seq_length=None, micro_batch_size=None, decoder_seq_length=None ): + """Create a forward loop for over a given data iterator.""" from megatron.core.pipeline_parallel.schedules import get_forward_backward_func forward_backward_func = get_forward_backward_func() @@ -261,40 +270,54 @@ def loop(model): return loop - def export(self, model: llm.GPTModel, model_dir: str) -> None: - assert self.export_config is not None, "Export config is not set" - # TODO: Add sample generate - # TODO: Support megatron_amp_O2 + @staticmethod + def _validate_quantized_checkpoint(checkpoint_dir: Path, tensor_parallelism_size: int) -> bool: + """Basic validation of the model structure.""" + + saved_config = (checkpoint_dir / 'config.json').exists() + saved_weights = True + for i in range(tensor_parallelism_size): + saved_weights &= (checkpoint_dir / f'rank{i}.safetensors').exists() + + export_successful = saved_config and saved_weights + if not export_successful: + logging.error("Failed to export the quantized model.") + return export_successful + + def export(self, model: MegatronParallel, model_dir: str) -> None: + """Export model to a TensorRT-LLM checkpoint.""" export_dir = self.export_config.path - use_nfs_workspace = (model.trainer._fabric.__io__.num_nodes > 1) or ( - model.config.pipeline_model_parallel_size > 1 - ) + inference_tp = self.export_config.inference_tensor_parallel + inference_pp = self.export_config.inference_pipeline_parallel + + use_nfs_workspace = model.config.pipeline_model_parallel_size > 1 export_tensorrt_llm_checkpoint( model=get_unwrapped_mcore_model(model), - decoder_type=self._get_decoder_type(model.config), + decoder_type=self._get_decoder_type(model), dtype=self.torch_dtype, export_dir=export_dir, - inference_tensor_parallel=self.export_config.inference_tensor_parallel, - inference_pipeline_parallel=self.export_config.inference_pipeline_parallel, + inference_tensor_parallel=inference_tp, + inference_pipeline_parallel=inference_pp, use_nfs_workspace=use_nfs_workspace, ) + dist.barrier() # Save the model context in order to restore its tokenizer later. The destination # path is "nemo_context" as this name is used in nemo.export to setup tokenizer. - shutil.copytree( - os.path.join(model_dir, CONTEXT_PATH), - os.path.join(export_dir, "nemo_context"), - dirs_exist_ok=True, - ) - logging.info(f"Model context saved.") - - logging.info(f"Export succeeded, model has been exported to {export_dir}.") + if dist.get_rank() == 0: + assert self._validate_quantized_checkpoint(export_dir, inference_tp) + shutil.copytree( + ckpt_to_context_subdir(model_dir), + os.path.join(export_dir, "nemo_context"), + dirs_exist_ok=True, + ) + logging.info(f"Export succeeded, model has been exported to {export_dir}.") def get_calib_data_iter( data: str = "cnn_dailymail", batch_size: int = 64, calib_size: int = 512, max_sequence_length: int = 512 ): - """Creates a sample data iterator for calibration""" + """Creates a sample data iterator for calibration.""" if data == "wikitext": dataset = load_dataset("wikitext", "wikitext-103-v1", split="train") text_column = "text" @@ -314,7 +337,9 @@ def get_calib_data_iter( def create_data_iterator_getter(model, dataset, seq_len, batch_size, calibration_size): - def _iterator(): + """Create a function that provides iterator over a given dataset.""" + + def _get_iterator(): CHARACTERS_PER_TOKEN = 4 dataloader = get_calib_data_iter( @@ -323,14 +348,13 @@ def _iterator(): batch_size=batch_size, calib_size=calibration_size, ) + + data = [] for batch in dataloader: batch = [model.tokenizer.text_to_ids(text)[:seq_len] for text in batch] batch = [ids + (seq_len - len(ids)) * [model.tokenizer.eos] for ids in batch] - yield torch.tensor(batch, device=model.device) + data.append(torch.tensor(batch, device=model.device)) - def _iterator_getter(): - dataloader = _iterator() - dataloader = [data for data in dataloader] - return iter(tqdm(dataloader)) + return iter(tqdm(data)) - return _iterator_getter + return _get_iterator diff --git a/nemo/collections/llm/quantization/utils.py b/nemo/collections/llm/quantization/utils.py index c4c533fe38d0..20739c872e80 100644 --- a/nemo/collections/llm/quantization/utils.py +++ b/nemo/collections/llm/quantization/utils.py @@ -18,12 +18,38 @@ from nemo import lightning as nl from nemo.collections import llm +from nemo.collections.llm.inference.base import _setup_trainer_and_restore_model from nemo.lightning.ckpt_utils import ckpt_to_context_subdir from nemo.utils import logging +def get_modelopt_decoder_type(model: llm.GPTModel) -> str: + """Infers the modelopt decoder type from GPTModel subclass.""" + mapping = [ + (llm.Baichuan2Model, "baichuan"), + (llm.ChatGLMModel, "chatglm"), + (llm.Gemma2Model, "gemma2"), + (llm.GemmaModel, "gemma"), + (llm.LlamaModel, "llama"), + (llm.MistralModel, "llama"), + (llm.MixtralModel, "llama"), + (llm.NemotronModel, "gptnext"), + (llm.Qwen2Model, "qwen"), + (llm.StarcoderModel, "gptnext"), + (llm.Starcoder2Model, "gptnext"), + (llm.Phi3Model, "phi3"), + ] + + for config_class, decoder_type in mapping: + if isinstance(model, config_class): + return decoder_type + + logging.warning("Could not infer the decoder type") + return None + + def quantizable_model_config(model_cfg: llm.GPTConfig) -> llm.GPTConfig: - """Modify model config for TensorRT Model Optimizer""" + """Modify model config for TensorRT-Model-Optimizer quantization""" from nemo.collections.nlp.models.language_modeling.megatron.gpt_layer_modelopt_spec import ( get_gpt_layer_modelopt_spec, @@ -42,25 +68,47 @@ def quantizable_model_config(model_cfg: llm.GPTConfig) -> llm.GPTConfig: return model_cfg -def load_with_modelopt_layer_spec(nemo_checkpoint_path: str, calib_tp: int = 1, calib_pp: int = 1) -> llm.GPTModel: +def load_with_modelopt_layer_spec( + nemo_checkpoint_path: str, calib_tp: int = 1, calib_pp: int = 1, inference_only: bool = True +): + """Loads a model from a NeMo 2.0 checkpoint using modelopt layer spec.""" + # TODO: setting ddp="pytorch" and deleting model.optim is a hackish way to disable DDP initialization. + # Needs a systematic solution. + if inference_only: + strategy = nl.MegatronStrategy( + tensor_model_parallel_size=calib_tp, + pipeline_model_parallel_size=calib_pp, + pipeline_dtype=torch.bfloat16, + ckpt_load_optimizer=False, + ckpt_parallel_save_optim=False, + setup_optimizers=False, + lazy_init=True, + ddp="pytorch", + ) + else: + strategy = nl.MegatronStrategy( + tensor_model_parallel_size=calib_tp, pipeline_model_parallel_size=calib_pp, pipeline_dtype=torch.bfloat16 + ) + trainer = nl.Trainer( devices=calib_tp, num_nodes=calib_pp, - strategy=nl.MegatronStrategy( - tensor_model_parallel_size=calib_tp, pipeline_model_parallel_size=calib_pp, pipeline_dtype=torch.bfloat16 - ), - plugins=nl.MegatronMixedPrecision(precision='bf16', pipeline_dtype=torch.bfloat16, autocast_enabled=True), + strategy=strategy, + plugins=nl.MegatronMixedPrecision(precision='bf16', params_dtype=torch.bfloat16, autocast_enabled=True), ) - fabric = trainer.to_fabric() - fabric.launch() - model_path = Path(nemo_checkpoint_path) - model = nl.io.load_context(ckpt_to_context_subdir(model_path)).model + model = nl.io.load_context(path=ckpt_to_context_subdir(model_path), subpath="model") model.config = quantizable_model_config(model.config) - return fabric.load_model(nemo_checkpoint_path, model=model) + + if inference_only: + del model.optim + + _setup_trainer_and_restore_model(nemo_checkpoint_path, trainer, model) + return model -def get_unwrapped_mcore_model(model: llm.GPTModel): +def get_unwrapped_mcore_model(model): + """Unwraps NeMo 2.0 to base MCore model.""" from megatron.core.models.gpt import GPTModel as MCoreGPTModel unwrapped_model = model diff --git a/nemo/collections/llm/recipes/__init__.py b/nemo/collections/llm/recipes/__init__.py index 8f772e3da5b7..1db88f633e89 100644 --- a/nemo/collections/llm/recipes/__init__.py +++ b/nemo/collections/llm/recipes/__init__.py @@ -33,6 +33,8 @@ llama31_8b, llama31_70b, llama31_405b, + llama32_1b, + llama32_3b, mamba2_1_3b, mamba2_2_7b, mamba2_8b, @@ -56,6 +58,7 @@ nemotron4_15b_16k, nemotron4_15b_64k, nemotron4_340b, + phi3_mini_4k_instruct, qwen2, qwen2_1p5b, qwen2_7b, @@ -72,6 +75,7 @@ ) from nemo.collections.llm.recipes.log.default import default_log, default_resume from nemo.collections.llm.recipes.optim import adam +from nemo.collections.llm.recipes.run.executor import torchrun __all__ = [ "baichuan2_7b", @@ -87,6 +91,8 @@ "llama31_8b", "llama31_70b", "llama31_405b", + "llama32_1b", + "llama32_3b", "mamba2_130m", "mamba2_370m", "mamba2_780m", @@ -111,6 +117,7 @@ "nemotron4_15b_16k", "nemotron4_15b_64k", "nemotron4_340b", + "phi3_mini_4k_instruct", "t5_220m", "t5_3b", "t5_11b", @@ -132,4 +139,5 @@ "adam", "default_log", "default_resume", + "torchrun", ] diff --git a/nemo/collections/llm/recipes/baichuan2_7b.py b/nemo/collections/llm/recipes/baichuan2_7b.py index 20de2c73f9dd..1350cbaa7edd 100644 --- a/nemo/collections/llm/recipes/baichuan2_7b.py +++ b/nemo/collections/llm/recipes/baichuan2_7b.py @@ -15,17 +15,17 @@ from typing import Callable, Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch +from lightning.pytorch.callbacks.callback import Callback from megatron.core.distributed import DistributedDataParallelConfig -from pytorch_lightning.callbacks.callback import Callback from nemo import lightning as nl from nemo.collections.llm import Baichuan2Config7B, Baichuan2Model from nemo.collections.llm.api import finetune, pretrain from nemo.collections.llm.gpt.data.mock import MockDataModule -from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.peft import PEFT_STR2CLS from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing @@ -254,8 +254,10 @@ def finetune_recipe( name (str): Name of the fine-tuning run. num_nodes (int): Number of compute nodes to use. num_gpus_per_node (int): Number of GPUs per node. - peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. - packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training efficiency. Default sequence length is 2048. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. + Allowed values: 'lora'/'dora'/'none'/None. + packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training + efficiency. Default sequence length is 2048. Returns: run.Partial: Partial configuration for fine-tuning. @@ -279,8 +281,8 @@ def finetune_recipe( if peft_scheme is None or peft_scheme.lower() == 'none': recipe.trainer.strategy.tensor_model_parallel_size = 2 recipe.optim.config.lr = 5e-6 - elif peft_scheme.lower() == 'lora': - recipe.peft = run.Config(LoRA) + elif peft_scheme.lower() in ['lora', 'dora']: + recipe.peft = run.Config(PEFT_STR2CLS[peft_scheme.lower()]) recipe.optim.config.lr = 1e-4 else: raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") diff --git a/nemo/collections/llm/recipes/callbacks/__init__.py b/nemo/collections/llm/recipes/callbacks/__init__.py new file mode 100644 index 000000000000..d9155f923f18 --- /dev/null +++ b/nemo/collections/llm/recipes/callbacks/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/llm/recipes/callbacks/common.py b/nemo/collections/llm/recipes/callbacks/common.py new file mode 100644 index 000000000000..72a1b3a0c640 --- /dev/null +++ b/nemo/collections/llm/recipes/callbacks/common.py @@ -0,0 +1,53 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +from nemo_run import Config, cli + +from nemo.utils.import_utils import safe_import + +res_module, HAVE_RES = safe_import('nvidia_resiliency_ext.ptl_resiliency') + + +@cli.factory(is_target_default=True) +def straggler_det_callback( + straggler_report_time_interval: Optional[int] = 300, stop_if_detected_straggler: Optional[bool] = True +) -> Config[res_module.StragglerDetectionCallback]: + """ + This callback is used to detect slower ranks participating in a PyTorch distributed workload. + This callback is obtained from nvidia-resiliency-ext. + Performance scores are scalar values from 0.0 (worst) to 1.0 (best), reflecting each rank's performance. + A performance score can be interpreted as the ratio of current performance to reference performance. + Depending on the reference used, there are two types of performance scores: + Relative performance score: The best-performing rank in the workload is used as a reference. + Individual performance score: The best historical performance of the rank is used as a reference. + If the performance score drops below the threshold which is set to 0.7, it is deemed as a straggler. + To detect the stragglers, users can enable this callback which reports the performance scores every 5mins. + Args: + straggler_report_time_interval (int): Performance score reporting frequency in seconds, Default is 300 seconds. + stop_if_detected_straggler (bool): Whether to stop training if a straggler is detection. Default is True. + """ + + return Config( + res_module.StragglerDetectionCallback, + report_time_interval=straggler_report_time_interval, + calc_relative_gpu_perf=True, + calc_individual_gpu_perf=True, + num_gpu_perf_scores_to_print=5, + gpu_relative_perf_threshold=0.7, + gpu_individual_perf_threshold=0.7, + stop_if_detected=stop_if_detected_straggler, + enable_ptl_logging=True, + ) diff --git a/nemo/collections/llm/recipes/chatglm3_6b.py b/nemo/collections/llm/recipes/chatglm3_6b.py index ef815a0851fc..2cd424ce5bf6 100644 --- a/nemo/collections/llm/recipes/chatglm3_6b.py +++ b/nemo/collections/llm/recipes/chatglm3_6b.py @@ -15,17 +15,17 @@ from typing import Callable, Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch +from lightning.pytorch.callbacks.callback import Callback from megatron.core.distributed import DistributedDataParallelConfig -from pytorch_lightning.callbacks.callback import Callback from nemo import lightning as nl from nemo.collections.llm import ChatGLM3Config6B, ChatGLMModel from nemo.collections.llm.api import finetune, pretrain from nemo.collections.llm.gpt.data.mock import MockDataModule -from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.peft import PEFT_STR2CLS from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing @@ -254,8 +254,10 @@ def finetune_recipe( name (str): Name of the fine-tuning run. num_nodes (int): Number of compute nodes to use. num_gpus_per_node (int): Number of GPUs per node. - peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. - packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training efficiency. Default sequence length is 2048. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. + Allowed values: 'lora'/'dora'/'none'/None. + packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training + efficiency. Default sequence length is 2048. Returns: run.Partial: Partial configuration for fine-tuning. @@ -279,8 +281,8 @@ def finetune_recipe( if peft_scheme is None or peft_scheme.lower() == 'none': recipe.trainer.strategy.tensor_model_parallel_size = 2 recipe.optim.config.lr = 5e-6 - elif peft_scheme.lower() == 'lora': - recipe.peft = run.Config(LoRA) + elif peft_scheme.lower() in ['lora', 'dora']: + recipe.peft = run.Config(PEFT_STR2CLS[peft_scheme.lower()]) recipe.optim.config.lr = 1e-4 else: raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") diff --git a/nemo/collections/llm/recipes/finetune_default.py b/nemo/collections/llm/recipes/finetune_default.py index a060046a8bdf..e8af7f67bdbd 100644 --- a/nemo/collections/llm/recipes/finetune_default.py +++ b/nemo/collections/llm/recipes/finetune_default.py @@ -14,16 +14,18 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch import nemo.lightning as nl from nemo.collections import llm from nemo.collections.llm.gpt.data.packed_sequence import PackedSequenceSpecs +from nemo.collections.llm.peft import DoRA, LoRA from nemo.collections.llm.recipes.log.default import tensorboard_logger from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed +from nemo.lightning.pytorch.callbacks import PEFT def default_finetune_recipe( @@ -158,3 +160,41 @@ def nemo_resume(model_id: str) -> run.Config[nl.AutoResume]: nl.AutoResume, restore_config=run.Config(nl.RestoreConfig, path=f"nemo://{model_id}"), ) + + +@run.cli.factory(name='lora') +def lora() -> run.Config[PEFT]: + """ + Factory function to create a LoRA configuration. + + Returns: + run.Config[PEFT]: Configuration for the LoRA class. + + Examples: + CLI usage: + $ nemo llm finetune -f llama3_8b peft=lora + + Python API usage: + >>> lora_config = lora() + >>> print(lora_config) + """ + return run.Config(LoRA) + + +@run.cli.factory(name='dora') +def dora() -> run.Config[PEFT]: + """ + Factory function to create a DoRA configuration. + + Returns: + run.Config[PEFT]: Configuration for the DoRA class. + + Examples: + CLI usage: + $ nemo llm finetune -f llama3_8b peft=dora + + Python API usage: + >>> dora_config = dora() + >>> print(dora_config) + """ + return run.Config(DoRA) diff --git a/nemo/collections/llm/recipes/gemma2.py b/nemo/collections/llm/recipes/gemma2.py index 6fd1be83c183..2a690dc556d8 100644 --- a/nemo/collections/llm/recipes/gemma2.py +++ b/nemo/collections/llm/recipes/gemma2.py @@ -14,11 +14,11 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch +from lightning.pytorch.callbacks.callback import Callback from megatron.core.distributed import DistributedDataParallelConfig -from pytorch_lightning.callbacks.callback import Callback from nemo import lightning as nl from nemo.collections.llm.gpt.model.gemma2 import Gemma2Config2B, Gemma2Config9B, Gemma2Config27B, Gemma2Model diff --git a/nemo/collections/llm/recipes/gemma2_27b.py b/nemo/collections/llm/recipes/gemma2_27b.py index 6f852f0fe6cf..d6b41c0a221c 100644 --- a/nemo/collections/llm/recipes/gemma2_27b.py +++ b/nemo/collections/llm/recipes/gemma2_27b.py @@ -14,13 +14,13 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch from nemo.collections.llm.api import finetune, pretrain from nemo.collections.llm.gpt.data.mock import MockDataModule -from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.peft import PEFT_STR2CLS from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe from nemo.collections.llm.recipes.gemma2 import gemma2_model, gemma2_trainer from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger @@ -191,8 +191,10 @@ def finetune_recipe( name (str): Name of the fine-tuning run. num_nodes (int): Number of compute nodes to use. num_gpus_per_node (int): Number of GPUs per node. - peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. - packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training efficiency. Default sequence length is 2048. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. + Allowed values: 'lora'/'dora'/'none'/None. + packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training + efficiency. Default sequence length is 2048. Returns: run.Partial: Partial configuration for fine-tuning. @@ -220,8 +222,8 @@ def finetune_recipe( recipe.optim.config.lr = 5e-6 recipe.trainer.strategy.tensor_model_parallel_size = 8 recipe.trainer.strategy.pipeline_model_parallel_size = 2 - elif peft_scheme.lower() == 'lora': - recipe.peft = run.Config(LoRA) + elif peft_scheme.lower() in ['lora', 'dora']: + recipe.peft = run.Config(PEFT_STR2CLS[peft_scheme.lower()]) recipe.trainer.strategy.tensor_model_parallel_size = 4 recipe.optim.config.lr = 1e-4 else: diff --git a/nemo/collections/llm/recipes/gemma2_2b.py b/nemo/collections/llm/recipes/gemma2_2b.py index 98c795591774..138140d0515d 100644 --- a/nemo/collections/llm/recipes/gemma2_2b.py +++ b/nemo/collections/llm/recipes/gemma2_2b.py @@ -14,13 +14,13 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch from nemo.collections.llm.api import finetune, pretrain from nemo.collections.llm.gpt.data.mock import MockDataModule -from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.peft import PEFT_STR2CLS from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe from nemo.collections.llm.recipes.gemma2 import gemma2_model, gemma2_trainer from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger @@ -191,8 +191,10 @@ def finetune_recipe( name (str): Name of the fine-tuning run. num_nodes (int): Number of compute nodes to use. num_gpus_per_node (int): Number of GPUs per node. - peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. - packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training efficiency. Default sequence length is 2048. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. + Allowed values: 'lora'/'dora'/'none'/None. + packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training + efficiency. Default sequence length is 2048. Returns: run.Partial: Partial configuration for fine-tuning. @@ -218,8 +220,8 @@ def finetune_recipe( if peft_scheme is None or peft_scheme.lower() == 'none': recipe.optim.config.lr = 5e-6 - elif peft_scheme.lower() == 'lora': - recipe.peft = run.Config(LoRA) + elif peft_scheme.lower() in ['lora', 'dora']: + recipe.peft = run.Config(PEFT_STR2CLS[peft_scheme.lower()]) recipe.optim.config.lr = 1e-4 else: raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") diff --git a/nemo/collections/llm/recipes/gemma2_9b.py b/nemo/collections/llm/recipes/gemma2_9b.py index a211d8cfa838..c49ac0246307 100644 --- a/nemo/collections/llm/recipes/gemma2_9b.py +++ b/nemo/collections/llm/recipes/gemma2_9b.py @@ -14,13 +14,13 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch from nemo.collections.llm.api import finetune, pretrain from nemo.collections.llm.gpt.data.mock import MockDataModule -from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.peft import PEFT_STR2CLS from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe from nemo.collections.llm.recipes.gemma2 import gemma2_model, gemma2_trainer from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger @@ -191,8 +191,10 @@ def finetune_recipe( name (str): Name of the fine-tuning run. num_nodes (int): Number of compute nodes to use. num_gpus_per_node (int): Number of GPUs per node. - peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. - packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training efficiency. Default sequence length is 2048. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. + Allowed values: 'lora'/'dora'/'none'/None. + packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training + efficiency. Default sequence length is 2048. Returns: run.Partial: Partial configuration for fine-tuning. @@ -219,8 +221,8 @@ def finetune_recipe( if peft_scheme is None or peft_scheme.lower() == 'none': recipe.optim.config.lr = 5e-6 recipe.trainer.strategy.tensor_model_parallel_size = 4 - elif peft_scheme.lower() == 'lora': - recipe.peft = run.Config(LoRA) + elif peft_scheme.lower() in ['lora', 'dora']: + recipe.peft = run.Config(PEFT_STR2CLS[peft_scheme.lower()]) recipe.optim.config.lr = 1e-4 else: raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") diff --git a/nemo/collections/llm/recipes/gemma_2b.py b/nemo/collections/llm/recipes/gemma_2b.py index 8b2111e9f7c4..8bdf89696d56 100644 --- a/nemo/collections/llm/recipes/gemma_2b.py +++ b/nemo/collections/llm/recipes/gemma_2b.py @@ -14,17 +14,17 @@ import os from typing import Callable, Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch +from lightning.pytorch.callbacks.callback import Callback from megatron.core.distributed import DistributedDataParallelConfig -from pytorch_lightning.callbacks.callback import Callback from nemo import lightning as nl from nemo.collections.llm import GemmaConfig2B, GemmaModel from nemo.collections.llm.api import finetune, pretrain from nemo.collections.llm.gpt.data.mock import MockDataModule -from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.peft import PEFT_STR2CLS from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing @@ -253,8 +253,10 @@ def finetune_recipe( name (str): Name of the fine-tuning run. num_nodes (int): Number of compute nodes to use. num_gpus_per_node (int): Number of GPUs per node. - peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. - packed_sequence (Optional[bool]): If true, fine-tuning sequences will be packed into batches up to the given maximum seq_length for better efficiency. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. + Allowed values: 'lora'/'dora'/'none'/None. + packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training + efficiency. Default sequence length is 2048. Returns: run.Partial: Partial configuration for fine-tuning. @@ -284,8 +286,8 @@ def finetune_recipe( if peft_scheme is None or peft_scheme.lower() == 'none': recipe.trainer.strategy.context_parallel_size = 2 recipe.optim.config.lr = 5e-6 - elif peft_scheme.lower() == 'lora': - recipe.peft = run.Config(LoRA) + elif peft_scheme.lower() in ['lora', 'dora']: + recipe.peft = run.Config(PEFT_STR2CLS[peft_scheme.lower()]) recipe.optim.config.lr = 1e-4 else: raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") diff --git a/nemo/collections/llm/recipes/gemma_7b.py b/nemo/collections/llm/recipes/gemma_7b.py index 44efb3fe56b8..46c91e27575a 100644 --- a/nemo/collections/llm/recipes/gemma_7b.py +++ b/nemo/collections/llm/recipes/gemma_7b.py @@ -14,17 +14,17 @@ import os from typing import Callable, Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch +from lightning.pytorch.callbacks.callback import Callback from megatron.core.distributed import DistributedDataParallelConfig -from pytorch_lightning.callbacks.callback import Callback from nemo import lightning as nl from nemo.collections.llm import GemmaConfig7B, GemmaModel from nemo.collections.llm.api import finetune, pretrain from nemo.collections.llm.gpt.data.mock import MockDataModule -from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.peft import PEFT_STR2CLS from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing @@ -256,8 +256,10 @@ def finetune_recipe( name (str): Name of the fine-tuning run. num_nodes (int): Number of compute nodes to use. num_gpus_per_node (int): Number of GPUs per node. - peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. - packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training efficiency. Default sequence length is 2048. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. + Allowed values: 'lora'/'dora'/'none'/None. + packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training + efficiency. Default sequence length is 2048. Returns: run.Partial: Partial configuration for fine-tuning. @@ -287,8 +289,8 @@ def finetune_recipe( if peft_scheme is None or peft_scheme.lower() == 'none': recipe.trainer.strategy.tensor_model_parallel_size = 2 recipe.optim.config.lr = 5e-6 - elif peft_scheme.lower() == 'lora': - recipe.peft = run.Config(LoRA) + elif peft_scheme.lower() in ['lora', 'dora']: + recipe.peft = run.Config(PEFT_STR2CLS[peft_scheme.lower()]) recipe.optim.config.lr = 1e-4 else: raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") diff --git a/nemo/collections/llm/recipes/gpt3_175b.py b/nemo/collections/llm/recipes/gpt3_175b.py index 5932ce5346b9..189f0ca6baf1 100644 --- a/nemo/collections/llm/recipes/gpt3_175b.py +++ b/nemo/collections/llm/recipes/gpt3_175b.py @@ -15,11 +15,11 @@ from typing import Callable, Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch +from lightning.pytorch.callbacks.callback import Callback from megatron.core.distributed import DistributedDataParallelConfig -from pytorch_lightning.callbacks.callback import Callback from nemo import lightning as nl from nemo.collections.llm.api import pretrain diff --git a/nemo/collections/llm/recipes/hf_auto_model_for_causal_lm.py b/nemo/collections/llm/recipes/hf_auto_model_for_causal_lm.py index f5a52cd351be..5d2bea23686c 100644 --- a/nemo/collections/llm/recipes/hf_auto_model_for_causal_lm.py +++ b/nemo/collections/llm/recipes/hf_auto_model_for_causal_lm.py @@ -15,15 +15,15 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch -from pytorch_lightning.callbacks.callback import Callback +from lightning.pytorch.callbacks.callback import Callback from nemo import lightning as nl from nemo.collections.llm.api import finetune, pretrain from nemo.collections.llm.gpt.data.mock import MockDataModule -from nemo.collections.llm.gpt.model.hf_auto_model_for_causal_lm import HfAutoModelForCausalLM +from nemo.collections.llm.gpt.model.hf_auto_model_for_causal_lm import HFAutoModelForCausalLM from nemo.collections.llm.peft.lora import LoRA from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger from nemo.collections.llm.recipes.optim.adam import pytorch_adam_with_cosine_annealing @@ -35,23 +35,23 @@ @run.cli.factory(name=NAME) def model(model_name, load_pretrained_weights) -> run.Config[pl.LightningModule]: """ - Factory function to create HfAutoModelForCausalLM model configurations. + Factory function to create HFAutoModelForCausalLM model configurations. Args: model_name (str): Model id on HF. Returns: - run.Config[pl.LightningModule]: Configuration for the HfAutoModelForCausalLM. + run.Config[pl.LightningModule]: Configuration for the HFAutoModelForCausalLM. Examples: CLI usage: - $ nemo llm pretrain --factory 'HfAutoModelForCausalLM(model_name="mistralai/Mistral-Nemo-Instruct-2407")' + $ nemo llm pretrain --factory 'HFAutoModelForCausalLM(model_name="mistralai/Mistral-Nemo-Instruct-2407")' Python API usage: >>> model_config = model(model_name="mistralai/Mistral-Nemo-Instruct-2407") >>> print(model_config) """ - return run.Config(HfAutoModelForCausalLM, model_name=model_name, load_pretrained_weights=load_pretrained_weights) + return run.Config(HFAutoModelForCausalLM, model_name=model_name, load_pretrained_weights=load_pretrained_weights) def trainer( @@ -69,7 +69,7 @@ def trainer( gradient_clip_val: float = 1.0, ) -> run.Config[nl.Trainer]: """ - Configure the NeMo Lightning Trainer for HfAutoModelForCausalLM. + Configure the NeMo Lightning Trainer for HFAutoModelForCausalLM. This function sets up the distributed training strategy and other training parameters. @@ -91,7 +91,7 @@ def trainer( Examples: CLI usage: - $ nemo llm pretrain trainer=HfAutoModelForCausalLM ... + $ nemo llm pretrain trainer=HFAutoModelForCausalLM ... Python API usage: >>> trainer_config = trainer(num_nodes=2, num_gpus_per_node=8) @@ -131,7 +131,7 @@ def pretrain_recipe( model_name: str = '', ) -> run.Partial: """ - Create a pre-training recipe for a HfAutoModelForCausalLM model. + Create a pre-training recipe for a HFAutoModelForCausalLM model. This function sets up a complete configuration for pre-training, including model, trainer, data, logging, optimization, and resumption settings. @@ -148,7 +148,7 @@ def pretrain_recipe( Examples: CLI usage: - $ nemo llm pretrain --factory 'HfAutoModelForCausalLM(model_name="mistralai/Mistral-Nemo-Instruct-2407")' + $ nemo llm pretrain --factory 'HFAutoModelForCausalLM(model_name="mistralai/Mistral-Nemo-Instruct-2407")' Python API usage: >>> recipe = pretrain_recipe(name="auto_pretrain", num_nodes=2, model_name="mistralai/Mistral-Nemo-Instruct-2407") @@ -179,7 +179,7 @@ def finetune_recipe( model_name: str = '', ) -> run.Partial: """ - Create a fine-tuning recipe for a HfAutoModelForCausalLM model. + Create a fine-tuning recipe for a HFAutoModelForCausalLM model. This function sets up a complete configuration for fine-tuning, including model, trainer, data, logging, optimization, and resumption settings. diff --git a/nemo/collections/llm/recipes/llama31_405b.py b/nemo/collections/llm/recipes/llama31_405b.py index 31c83713b6e7..5f08d82bd888 100644 --- a/nemo/collections/llm/recipes/llama31_405b.py +++ b/nemo/collections/llm/recipes/llama31_405b.py @@ -15,18 +15,18 @@ from typing import Callable, Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch +from lightning.pytorch.callbacks.callback import Callback from megatron.core.distributed import DistributedDataParallelConfig -from pytorch_lightning.callbacks.callback import Callback from nemo import lightning as nl from nemo.collections.llm.api import finetune, pretrain from nemo.collections.llm.gpt.data.mock import MockDataModule from nemo.collections.llm.gpt.data.packed_sequence import PackedSequenceSpecs from nemo.collections.llm.gpt.model.llama import Llama31Config405B, LlamaModel -from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.peft import PEFT_STR2CLS from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing @@ -266,7 +266,8 @@ def finetune_recipe( name (str): Name of the fine-tuning run. num_nodes (int): Number of compute nodes to use. num_gpus_per_node (int): Number of GPUs per node. - peft_scheme (Optional[str]): Name of the peft scheme to use for finetuning. Allowed values: 'lora'/'none'/None. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. + Allowed values: 'lora'/'dora'/'none'/None. seq_length (int): Maximum number of tokens per microbatch. packed_sequence (Optional[bool]): If true, fine-tuning sequences will be packed into batches up to the given maximum seq_length for better efficiency. By default, this value equals performance_mode. @@ -296,7 +297,7 @@ def finetune_recipe( if num_nodes is None: if peft_scheme is None or peft_scheme.lower() == 'none': num_nodes = 12 - elif peft_scheme.lower() == 'lora': + elif peft_scheme.lower() in ['lora', 'dora']: num_nodes = 3 recipe = default_finetune_recipe( @@ -307,11 +308,10 @@ def finetune_recipe( recipe.trainer.strategy.pipeline_model_parallel_size = 14 recipe.data.global_batch_size = 6 recipe.optim.config.lr = 5e-6 - elif peft_scheme.lower() == 'lora': - recipe.peft = run.Config(LoRA) + elif peft_scheme.lower() in ['lora', 'dora']: + recipe.peft = run.Config(PEFT_STR2CLS[peft_scheme.lower()]) recipe.peft.dim = 16 recipe.peft.alpha = 32 - recipe.peft.target_modules = ['linear_qkv'] recipe.optim.config.use_distributed_optimizer = False # some settings currently do not function correctly with LoRA @@ -349,7 +349,8 @@ def finetune_performance_optimizations( Args: recipe (run.Partial): Base fine-tuning recipe to which performance optimizations will be added - peft_scheme (str): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. + Allowed values: 'lora'/'dora'/'none'/None. Returns: run.Partial: Partial configuration for performance-optimized fine-tuning. @@ -387,6 +388,7 @@ def finetune_performance_optimizations( recipe.trainer.strategy.tensor_model_parallel_size = 4 recipe.trainer.strategy.pipeline_model_parallel_size = 6 recipe.trainer.strategy.virtual_pipeline_model_parallel_size = 7 + recipe.peft.target_modules = ['linear_qkv'] recipe.trainer.strategy.sequence_parallel = True diff --git a/nemo/collections/llm/recipes/llama31_70b.py b/nemo/collections/llm/recipes/llama31_70b.py index 91e4e10c83e6..3120fedd7923 100644 --- a/nemo/collections/llm/recipes/llama31_70b.py +++ b/nemo/collections/llm/recipes/llama31_70b.py @@ -15,18 +15,18 @@ from typing import Callable, Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch +from lightning.pytorch.callbacks.callback import Callback from megatron.core.distributed import DistributedDataParallelConfig -from pytorch_lightning.callbacks.callback import Callback from nemo import lightning as nl from nemo.collections.llm.api import finetune, pretrain from nemo.collections.llm.gpt.data.mock import MockDataModule from nemo.collections.llm.gpt.data.packed_sequence import PackedSequenceSpecs from nemo.collections.llm.gpt.model.llama import Llama31Config70B, LlamaModel -from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.peft import PEFT_STR2CLS from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing @@ -266,7 +266,8 @@ def finetune_recipe( name (str): Name of the fine-tuning run. num_nodes (int): Number of compute nodes to use. num_gpus_per_node (int): Number of GPUs per node. - peft_scheme (Optional[str]): Name of the peft scheme to use for finetuning. Allowed values: 'lora'/'none'/None. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. + Allowed values: 'lora'/'dora'/'none'/None. seq_length (int): Maximum number of tokens per microbatch. packed_sequence (Optional[bool]): If true, fine-tuning sequences will be packed into batches up to the given maximum seq_length for better efficiency. By default, this value equals performance_mode. @@ -300,7 +301,7 @@ def finetune_recipe( if num_nodes is None: if peft_scheme is None or peft_scheme.lower() == 'none': num_nodes = 4 - elif peft_scheme.lower() == 'lora': + elif peft_scheme.lower() in ['lora', 'dora']: num_nodes = 1 recipe = default_finetune_recipe( @@ -310,11 +311,10 @@ def finetune_recipe( recipe.trainer.strategy.tensor_model_parallel_size = 8 recipe.trainer.strategy.pipeline_model_parallel_size = 4 recipe.optim.config.lr = 5e-6 - elif peft_scheme.lower() == 'lora': - recipe.peft = run.Config(LoRA) + elif peft_scheme.lower() in ['lora', 'dora']: + recipe.peft = run.Config(PEFT_STR2CLS[peft_scheme.lower()]) recipe.peft.dim = 16 recipe.peft.alpha = 32 - recipe.peft.target_modules = ['linear_qkv'] recipe.optim.config.use_distributed_optimizer = False # some settings currently do not function correctly with LoRA @@ -350,7 +350,8 @@ def finetune_performance_optimizations( Args: recipe (run.Partial): Base fine-tuning recipe to which performance optimizations will be added - peft_scheme (str): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. + Allowed values: 'lora'/'dora'/'none'/None. Returns: run.Partial: Partial configuration for performance-optimized fine-tuning. @@ -388,6 +389,7 @@ def finetune_performance_optimizations( recipe.trainer.strategy.tensor_model_parallel_size = 2 recipe.trainer.strategy.pipeline_model_parallel_size = 4 recipe.trainer.strategy.virtual_pipeline_model_parallel_size = 5 + recipe.peft.target_modules = ['linear_qkv'] recipe.trainer.strategy.sequence_parallel = True diff --git a/nemo/collections/llm/recipes/llama31_8b.py b/nemo/collections/llm/recipes/llama31_8b.py index a4f0082e8535..62514940b678 100644 --- a/nemo/collections/llm/recipes/llama31_8b.py +++ b/nemo/collections/llm/recipes/llama31_8b.py @@ -15,18 +15,18 @@ from typing import Callable, Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch +from lightning.pytorch.callbacks.callback import Callback from megatron.core.distributed import DistributedDataParallelConfig -from pytorch_lightning.callbacks.callback import Callback from nemo import lightning as nl from nemo.collections.llm.api import finetune, pretrain from nemo.collections.llm.gpt.data.mock import MockDataModule from nemo.collections.llm.gpt.data.packed_sequence import PackedSequenceSpecs from nemo.collections.llm.gpt.model.llama import Llama31Config8B, LlamaModel -from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.peft import PEFT_STR2CLS from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing @@ -266,7 +266,8 @@ def finetune_recipe( name (str): Name of the fine-tuning run. num_nodes (int): Number of compute nodes to use. num_gpus_per_node (int): Number of GPUs per node. - peft_scheme (Optional[str]): Name of the peft scheme to use for finetuning. Allowed values: 'lora'/'none'/None. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. + Allowed values: 'lora'/'dora'/'none'/None. seq_length (int): Maximum number of tokens per microbatch. packed_sequence (Optional[bool]): If true, fine-tuning sequences will be packed into batches up to the given maximum seq_length for better efficiency. By default, this value equals performance_mode. @@ -303,11 +304,10 @@ def finetune_recipe( if peft_scheme is None or peft_scheme.lower() == 'none': recipe.trainer.strategy.tensor_model_parallel_size = 2 recipe.optim.config.lr = 5e-6 - elif peft_scheme.lower() == 'lora': - recipe.peft = run.Config(LoRA) + elif peft_scheme.lower() in ['lora', 'dora']: + recipe.peft = run.Config(PEFT_STR2CLS[peft_scheme.lower()]) recipe.peft.dim = 8 recipe.peft.alpha = 16 - recipe.peft.target_modules = ['linear_qkv'] recipe.optim.config.use_distributed_optimizer = False # some settings currently do not function correctly with LoRA @@ -342,7 +342,8 @@ def finetune_performance_optimizations( Args: recipe (run.Partial): Base fine-tuning recipe to which performance optimizations will be added - peft_scheme (str): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. + Allowed values: 'lora'/'dora'/'none'/None. Returns: run.Partial: Partial configuration for performance-optimized fine-tuning. @@ -372,6 +373,8 @@ def finetune_performance_optimizations( tp_comm_overlap=False, ) ) + else: + recipe.peft.target_modules = ['linear_qkv'] recipe.trainer.callbacks.append(run.Config(TimingCallback)) recipe.trainer.callbacks.append( diff --git a/nemo/collections/llm/recipes/llama32_1b.py b/nemo/collections/llm/recipes/llama32_1b.py new file mode 100644 index 000000000000..32675adf3686 --- /dev/null +++ b/nemo/collections/llm/recipes/llama32_1b.py @@ -0,0 +1,270 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Callable, Optional + +import lightning.pytorch as pl +import nemo_run as run +import torch +from lightning.pytorch.callbacks.callback import Callback +from megatron.core.distributed import DistributedDataParallelConfig + +from nemo import lightning as nl +from nemo.collections.llm.api import finetune, pretrain +from nemo.collections.llm.gpt.data.mock import MockDataModule +from nemo.collections.llm.gpt.data.packed_sequence import PackedSequenceSpecs +from nemo.collections.llm.gpt.model.llama import Llama32Config1B, LlamaModel +from nemo.collections.llm.peft import PEFT_STR2CLS +from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe +from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger +from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing +from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed +from nemo.utils.exp_manager import TimingCallback + +NAME = "llama32_1b" + + +@run.cli.factory(name=NAME) +def model() -> run.Config[pl.LightningModule]: + """ + Factory function to create a Llama3.2 1B model configuration. + + Returns: + run.Config[pl.LightningModule]: Configuration for the Llama3.2 1B model. + + Examples: + CLI usage: + $ nemo llm pretrain model=llama32_1b ... + + Python API usage: + >>> model_config = model() + >>> print(model_config) + """ + conf = run.Config(Llama32Config1B) + conf.seq_length = 8192 + return run.Config(LlamaModel, config=conf) + + +def trainer( + tensor_parallelism: int = 1, + pipeline_parallelism: int = 1, + pipeline_parallelism_type: Optional[torch.dtype] = None, + virtual_pipeline_parallelism: Optional[int] = None, + context_parallelism: int = 1, + sequence_parallelism: bool = False, + num_nodes: int = 1, + num_gpus_per_node: int = 8, + max_steps: int = 1168251, + callbacks: Optional[list[run.Config[Callback]]] = None, +) -> run.Config[nl.Trainer]: + """ + Configure the NeMo Lightning Trainer for Llama3.2 1B model. + + Args: + tensor_parallelism (int): Degree of tensor model parallelism. + pipeline_parallelism (int): Degree of pipeline model parallelism. + pipeline_parallelism_type (Optional[torch.dtype]): Data type for pipeline parallelism. + virtual_pipeline_parallelism (Optional[int]): Size of virtual pipeline parallelism. + context_parallelism (int): Degree of context parallelism. + sequence_parallelism (bool): Whether to use sequence parallelism. + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + max_steps (int): Maximum number of training steps. + callbacks (Optional[list[run.Config[Callback]]]): List of callback configurations. + + Returns: + run.Config[nl.Trainer]: Configuration for the NeMo Lightning Trainer. + + Examples: + CLI usage: + $ nemo llm pretrain trainer=llama32_1b ... + + Python API usage: + >>> trainer_config = trainer(num_nodes=1, num_gpus_per_node=1) + >>> print(trainer_config) + + Note: + This configuration uses extensive parallelism to handle the large model size efficiently. + """ + strategy = run.Config( + nl.MegatronStrategy, + tensor_model_parallel_size=tensor_parallelism, + pipeline_model_parallel_size=pipeline_parallelism, + pipeline_dtype=pipeline_parallelism_type, + virtual_pipeline_model_parallel_size=virtual_pipeline_parallelism, + context_parallel_size=context_parallelism, + sequence_parallel=sequence_parallelism, + gradient_as_bucket_view=True, + ckpt_async_save=True, + ckpt_parallel_load=True, + ddp=run.Config( + DistributedDataParallelConfig, + check_for_nan_in_grad=True, + grad_reduce_in_fp32=True, + overlap_grad_reduce=True, + overlap_param_gather=True, + average_in_collective=True, + ), + ) + + trainer = run.Config( + nl.Trainer, + accelerator="gpu", + accumulate_grad_batches=1, + callbacks=callbacks, + devices=num_gpus_per_node, + limit_test_batches=50, + limit_val_batches=32, + log_every_n_steps=10, + max_steps=max_steps, + num_nodes=num_nodes, + plugins=bf16_mixed(), + strategy=strategy, + use_distributed_sampler=False, + val_check_interval=2000, + ) + + return trainer + + +@run.cli.factory(target=pretrain, name=NAME) +def pretrain_recipe( + dir: Optional[str] = None, + name: str = "default", + num_nodes: int = 1, + num_gpus_per_node: int = 8, + fn: Callable = pretrain, +) -> run.Partial: + """ + Create a pre-training recipe for Llama3.2 1B model. + + This function sets up a complete configuration for pre-training, including + model, trainer, data, logging, optimization, and resumption settings. + + Args: + dir (Optional[str]): Directory for saving logs and checkpoints. + name (str): Name of the pre-training run. + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + fn (Callable): The pre-training function to use. + + Returns: + run.Partial: Partial configuration for pre-training. + + Examples: + CLI usage: + $ nemo llm pretrain --factory llama32_1b + $ nemo llm pretrain --factory "llama32_1b(num_nodes=1, name='my_1b_pretrain')" + + Python API usage: + >>> recipe = pretrain_recipe(name="llama32_1b_pretrain", num_nodes=1) + >>> print(recipe) + + Note: + This recipe is optimized for the large 8B model and requires significant computational resources. + """ + recipe = run.Partial( + fn, + model=model(), + trainer=trainer( + num_nodes=num_nodes, + num_gpus_per_node=num_gpus_per_node, + callbacks=[run.Config(TimingCallback)], + ), + data=run.Config(MockDataModule, seq_length=8192, global_batch_size=512, micro_batch_size=1), + log=default_log(dir=dir, name=name, tensorboard_logger=tensorboard_logger(name=name)), + optim=distributed_fused_adam_with_cosine_annealing(max_lr=3e-4), + resume=default_resume(), + ) + + return recipe + + +@run.cli.factory(target=finetune, name=NAME) +def finetune_recipe( + dir: Optional[str] = None, + name: str = "default", + num_nodes: int = 1, + num_gpus_per_node: int = 8, + peft_scheme: Optional[str] = 'lora', + seq_length: Optional[int] = None, + packed_sequence: Optional[bool] = None, +) -> run.Partial: + """ + Create a fine-tuning recipe for Llama3.2 1B model. + + This function sets up a complete configuration for fine-tuning, including + model, trainer, data, logging, optimization, and resumption settings. + The recipe uses LoRA (Low-Rank Adaptation) for efficient fine-tuning, unless peft_scheme is set to None. + + Args: + dir (Optional[str]): Directory for saving logs and checkpoints. + name (str): Name of the fine-tuning run. + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. + Allowed values: 'lora'/'dora'/'none'/None. + seq_length (int): Maximum number of tokens per microbatch. + packed_sequence (Optional[bool]): If true, fine-tuning sequences will be packed into batches up to the given + maximum seq_length for better efficiency. + + Returns: + run.Partial: Partial configuration for fine-tuning. + + Examples: + CLI usage: + $ nemo llm finetune --factory llama32_1b + + Python API usage: + >>> recipe = finetune_recipe(name="llama32_1b_finetune", num_nodes=1) + >>> print(recipe) + + Note: + This recipe uses the SQuAD dataset for fine-tuning. For more information + on fine-tuning LLMs with NeMo, see the fine-tuning guide in the + `examples/llm/finetune/` directory. + """ + + # For unpacked sequence, most samples in SQuAD dataset are shorter than 2K + if seq_length is None: + seq_length = 4096 if packed_sequence else 2048 + + recipe = default_finetune_recipe( + model(), "meta-llama/Llama-3.2-1B", dir, name, num_nodes, num_gpus_per_node, packed_sequence + ) + if peft_scheme is None or peft_scheme.lower() == 'none': + recipe.trainer.strategy.tensor_model_parallel_size = 1 + recipe.optim.config.lr = 5e-6 + elif peft_scheme.lower() in ['lora', 'dora']: + recipe.peft = run.Config(PEFT_STR2CLS[peft_scheme.lower()]) + recipe.peft.dim = 8 + recipe.peft.alpha = 16 + recipe.optim.config.use_distributed_optimizer = False + + # some settings currently do not function correctly with LoRA + recipe.model.config.cross_entropy_loss_fusion = False + + recipe.optim.config.lr = 1e-4 + else: + raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") + + # Sequence length settings in the model and dataset must agree + recipe.model.config.seq_length = seq_length + recipe.data.seq_length = seq_length + if packed_sequence: + recipe.data.dataset_kwargs = {'pad_to_max_length': True} + recipe.data.packed_sequence_specs = run.Config(PackedSequenceSpecs, packed_sequence_size=seq_length) + + return recipe diff --git a/nemo/collections/llm/recipes/llama32_3b.py b/nemo/collections/llm/recipes/llama32_3b.py new file mode 100644 index 000000000000..d78ea0b50983 --- /dev/null +++ b/nemo/collections/llm/recipes/llama32_3b.py @@ -0,0 +1,270 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Callable, Optional + +import lightning.pytorch as pl +import nemo_run as run +import torch +from lightning.pytorch.callbacks.callback import Callback +from megatron.core.distributed import DistributedDataParallelConfig + +from nemo import lightning as nl +from nemo.collections.llm.api import finetune, pretrain +from nemo.collections.llm.gpt.data.mock import MockDataModule +from nemo.collections.llm.gpt.data.packed_sequence import PackedSequenceSpecs +from nemo.collections.llm.gpt.model.llama import Llama32Config3B, LlamaModel +from nemo.collections.llm.peft import PEFT_STR2CLS +from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe +from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger +from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing +from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed +from nemo.utils.exp_manager import TimingCallback + +NAME = "llama32_3b" + + +@run.cli.factory(name=NAME) +def model() -> run.Config[pl.LightningModule]: + """ + Factory function to create a Llama3.2 3B model configuration. + + Returns: + run.Config[pl.LightningModule]: Configuration for the Llama3.2 3B model. + + Examples: + CLI usage: + $ nemo llm pretrain model=llama32_3b ... + + Python API usage: + >>> model_config = model() + >>> print(model_config) + """ + conf = run.Config(Llama32Config3B) + conf.seq_length = 8192 + return run.Config(LlamaModel, config=conf) + + +def trainer( + tensor_parallelism: int = 1, + pipeline_parallelism: int = 1, + pipeline_parallelism_type: Optional[torch.dtype] = None, + virtual_pipeline_parallelism: Optional[int] = None, + context_parallelism: int = 1, + sequence_parallelism: bool = False, + num_nodes: int = 1, + num_gpus_per_node: int = 8, + max_steps: int = 1168251, + callbacks: Optional[list[run.Config[Callback]]] = None, +) -> run.Config[nl.Trainer]: + """ + Configure the NeMo Lightning Trainer for Llama3.2 3B model. + + Args: + tensor_parallelism (int): Degree of tensor model parallelism. + pipeline_parallelism (int): Degree of pipeline model parallelism. + pipeline_parallelism_type (Optional[torch.dtype]): Data type for pipeline parallelism. + virtual_pipeline_parallelism (Optional[int]): Size of virtual pipeline parallelism. + context_parallelism (int): Degree of context parallelism. + sequence_parallelism (bool): Whether to use sequence parallelism. + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + max_steps (int): Maximum number of training steps. + callbacks (Optional[list[run.Config[Callback]]]): List of callback configurations. + + Returns: + run.Config[nl.Trainer]: Configuration for the NeMo Lightning Trainer. + + Examples: + CLI usage: + $ nemo llm pretrain trainer=llama32_3b ... + + Python API usage: + >>> trainer_config = trainer(num_nodes=1, num_gpus_per_node=1) + >>> print(trainer_config) + + Note: + This configuration uses extensive parallelism to handle the large model size efficiently. + """ + strategy = run.Config( + nl.MegatronStrategy, + tensor_model_parallel_size=tensor_parallelism, + pipeline_model_parallel_size=pipeline_parallelism, + pipeline_dtype=pipeline_parallelism_type, + virtual_pipeline_model_parallel_size=virtual_pipeline_parallelism, + context_parallel_size=context_parallelism, + sequence_parallel=sequence_parallelism, + gradient_as_bucket_view=True, + ckpt_async_save=True, + ckpt_parallel_load=True, + ddp=run.Config( + DistributedDataParallelConfig, + check_for_nan_in_grad=True, + grad_reduce_in_fp32=True, + overlap_grad_reduce=True, + overlap_param_gather=True, + average_in_collective=True, + ), + ) + + trainer = run.Config( + nl.Trainer, + accelerator="gpu", + accumulate_grad_batches=1, + callbacks=callbacks, + devices=num_gpus_per_node, + limit_test_batches=50, + limit_val_batches=32, + log_every_n_steps=10, + max_steps=max_steps, + num_nodes=num_nodes, + plugins=bf16_mixed(), + strategy=strategy, + use_distributed_sampler=False, + val_check_interval=2000, + ) + + return trainer + + +@run.cli.factory(target=pretrain, name=NAME) +def pretrain_recipe( + dir: Optional[str] = None, + name: str = "default", + num_nodes: int = 1, + num_gpus_per_node: int = 8, + fn: Callable = pretrain, +) -> run.Partial: + """ + Create a pre-training recipe for Llama3.2 3B model. + + This function sets up a complete configuration for pre-training, including + model, trainer, data, logging, optimization, and resumption settings. + + Args: + dir (Optional[str]): Directory for saving logs and checkpoints. + name (str): Name of the pre-training run. + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + fn (Callable): The pre-training function to use. + + Returns: + run.Partial: Partial configuration for pre-training. + + Examples: + CLI usage: + $ nemo llm pretrain --factory llama32_3b + $ nemo llm pretrain --factory "llama32_3b(num_nodes=1, name='my_3b_pretrain')" + + Python API usage: + >>> recipe = pretrain_recipe(name="llama32_3b_pretrain", num_nodes=1) + >>> print(recipe) + + Note: + This recipe is optimized for the large 8B model and requires significant computational resources. + """ + recipe = run.Partial( + fn, + model=model(), + trainer=trainer( + num_nodes=num_nodes, + num_gpus_per_node=num_gpus_per_node, + callbacks=[run.Config(TimingCallback)], + ), + data=run.Config(MockDataModule, seq_length=8192, global_batch_size=512, micro_batch_size=1), + log=default_log(dir=dir, name=name, tensorboard_logger=tensorboard_logger(name=name)), + optim=distributed_fused_adam_with_cosine_annealing(max_lr=3e-4), + resume=default_resume(), + ) + + return recipe + + +@run.cli.factory(target=finetune, name=NAME) +def finetune_recipe( + dir: Optional[str] = None, + name: str = "default", + num_nodes: int = 1, + num_gpus_per_node: int = 8, + peft_scheme: Optional[str] = 'lora', + seq_length: Optional[int] = None, + packed_sequence: Optional[bool] = None, +) -> run.Partial: + """ + Create a fine-tuning recipe for Llama3.2 3B model. + + This function sets up a complete configuration for fine-tuning, including + model, trainer, data, logging, optimization, and resumption settings. + The recipe uses LoRA (Low-Rank Adaptation) for efficient fine-tuning, unless peft_scheme is set to None. + + Args: + dir (Optional[str]): Directory for saving logs and checkpoints. + name (str): Name of the fine-tuning run. + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. + Allowed values: 'lora'/'dora'/'none'/None. + seq_length (int): Maximum number of tokens per microbatch. + packed_sequence (Optional[bool]): If true, fine-tuning sequences will be packed into batches up to the given + maximum seq_length for better efficiency. + + Returns: + run.Partial: Partial configuration for fine-tuning. + + Examples: + CLI usage: + $ nemo llm finetune --factory llama32_3b + + Python API usage: + >>> recipe = finetune_recipe(name="llama32_3b_finetune", num_nodes=1) + >>> print(recipe) + + Note: + This recipe uses the SQuAD dataset for fine-tuning. For more information + on fine-tuning LLMs with NeMo, see the fine-tuning guide in the + `examples/llm/finetune/` directory. + """ + + # For unpacked sequence, most samples in SQuAD dataset are shorter than 2K + if seq_length is None: + seq_length = 4096 if packed_sequence else 2048 + + recipe = default_finetune_recipe( + model(), "meta-llama/Llama-3.2-3B", dir, name, num_nodes, num_gpus_per_node, packed_sequence + ) + if peft_scheme is None or peft_scheme.lower() == 'none': + recipe.trainer.strategy.tensor_model_parallel_size = 1 + recipe.optim.config.lr = 5e-6 + elif peft_scheme.lower() in ['lora', 'dora']: + recipe.peft = run.Config(PEFT_STR2CLS[peft_scheme.lower()]) + recipe.peft.dim = 8 + recipe.peft.alpha = 16 + recipe.optim.config.use_distributed_optimizer = False + + # some settings currently do not function correctly with LoRA + recipe.model.config.cross_entropy_loss_fusion = False + + recipe.optim.config.lr = 1e-4 + else: + raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") + + # Sequence length settings in the model and dataset must agree + recipe.model.config.seq_length = seq_length + recipe.data.seq_length = seq_length + if packed_sequence: + recipe.data.dataset_kwargs = {'pad_to_max_length': True} + recipe.data.packed_sequence_specs = run.Config(PackedSequenceSpecs, packed_sequence_size=seq_length) + + return recipe diff --git a/nemo/collections/llm/recipes/llama3_70b.py b/nemo/collections/llm/recipes/llama3_70b.py index d43302a0a0ee..8b61bff80e01 100644 --- a/nemo/collections/llm/recipes/llama3_70b.py +++ b/nemo/collections/llm/recipes/llama3_70b.py @@ -15,18 +15,18 @@ from typing import Callable, Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch +from lightning.pytorch.callbacks.callback import Callback from megatron.core.distributed import DistributedDataParallelConfig -from pytorch_lightning.callbacks.callback import Callback from nemo import lightning as nl from nemo.collections.llm.api import finetune, pretrain from nemo.collections.llm.gpt.data.mock import MockDataModule from nemo.collections.llm.gpt.data.packed_sequence import PackedSequenceSpecs from nemo.collections.llm.gpt.model.llama import Llama3Config70B, LlamaModel -from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.peft import PEFT_STR2CLS from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing @@ -263,7 +263,8 @@ def finetune_recipe( name (str): Name of the fine-tuning run. num_nodes (int): Number of compute nodes to use. num_gpus_per_node (int): Number of GPUs per node. - peft_scheme (Optional[str]): Name of the peft scheme to use for finetuning. Allowed values: 'lora'/'none'/None. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. + Allowed values: 'lora'/'dora'/'none'/None. seq_length (int): Maximum number of tokens per microbatch. packed_sequence (Optional[bool]): If true, fine-tuning sequences will be packed into batches up to the given maximum seq_length for better efficiency. By default, this value equals performance_mode. @@ -297,7 +298,7 @@ def finetune_recipe( if num_nodes is None: if peft_scheme is None or peft_scheme.lower() == 'none': num_nodes = 4 - elif peft_scheme.lower() == 'lora': + elif peft_scheme.lower() in ['lora', 'dora']: num_nodes = 1 recipe = default_finetune_recipe( @@ -307,11 +308,10 @@ def finetune_recipe( recipe.trainer.strategy.tensor_model_parallel_size = 8 recipe.trainer.strategy.pipeline_model_parallel_size = 4 recipe.optim.config.lr = 5e-6 - elif peft_scheme.lower() == 'lora': - recipe.peft = run.Config(LoRA) + elif peft_scheme.lower() in ['lora', 'dora']: + recipe.peft = run.Config(PEFT_STR2CLS[peft_scheme.lower()]) recipe.peft.dim = 16 recipe.peft.alpha = 32 - recipe.peft.target_modules = ['linear_qkv'] recipe.optim.config.use_distributed_optimizer = False # some settings currently do not function correctly with LoRA @@ -347,7 +347,8 @@ def finetune_performance_optimizations( Args: recipe (run.Partial): Base fine-tuning recipe to which performance optimizations will be added - peft_scheme (str): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. + Allowed values: 'lora'/'dora'/'none'/None. Returns: run.Partial: Partial configuration for performance-optimized fine-tuning. @@ -385,6 +386,7 @@ def finetune_performance_optimizations( recipe.trainer.strategy.tensor_model_parallel_size = 2 recipe.trainer.strategy.pipeline_model_parallel_size = 4 recipe.trainer.strategy.virtual_pipeline_model_parallel_size = 5 + recipe.peft.target_modules = ['linear_qkv'] recipe.trainer.strategy.sequence_parallel = True diff --git a/nemo/collections/llm/recipes/llama3_70b_16k.py b/nemo/collections/llm/recipes/llama3_70b_16k.py index 928f961f7cf3..0a394d386afd 100644 --- a/nemo/collections/llm/recipes/llama3_70b_16k.py +++ b/nemo/collections/llm/recipes/llama3_70b_16k.py @@ -15,8 +15,8 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch from nemo.collections.llm.api import finetune, pretrain diff --git a/nemo/collections/llm/recipes/llama3_70b_64k.py b/nemo/collections/llm/recipes/llama3_70b_64k.py index ffadf5ca8084..e035424d3506 100644 --- a/nemo/collections/llm/recipes/llama3_70b_64k.py +++ b/nemo/collections/llm/recipes/llama3_70b_64k.py @@ -15,8 +15,8 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch from nemo.collections.llm.api import finetune, pretrain diff --git a/nemo/collections/llm/recipes/llama3_8b.py b/nemo/collections/llm/recipes/llama3_8b.py index 4f6f6ce17443..36b20c12ddb2 100644 --- a/nemo/collections/llm/recipes/llama3_8b.py +++ b/nemo/collections/llm/recipes/llama3_8b.py @@ -15,18 +15,18 @@ from typing import Callable, Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch +from lightning.pytorch.callbacks.callback import Callback from megatron.core.distributed import DistributedDataParallelConfig -from pytorch_lightning.callbacks.callback import Callback from nemo import lightning as nl from nemo.collections.llm.api import finetune, pretrain from nemo.collections.llm.gpt.data.mock import MockDataModule from nemo.collections.llm.gpt.data.packed_sequence import PackedSequenceSpecs from nemo.collections.llm.gpt.model.llama import Llama3Config8B, LlamaModel -from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.peft import PEFT_STR2CLS from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing @@ -250,7 +250,8 @@ def finetune_recipe( name (str): Name of the fine-tuning run. num_nodes (int): Number of compute nodes to use. num_gpus_per_node (int): Number of GPUs per node. - peft_scheme (Optional[str]): Name of the peft scheme to use for finetuning. Allowed values: 'lora'/'none'/None. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. + Allowed values: 'lora'/'dora'/'none'/None. seq_length (int): Maximum number of tokens per microbatch. packed_sequence (Optional[bool]): If true, fine-tuning sequences will be packed into batches up to the given maximum seq_length for better efficiency. By default, this value equals performance_mode. @@ -287,11 +288,10 @@ def finetune_recipe( if peft_scheme is None or peft_scheme.lower() == 'none': recipe.trainer.strategy.tensor_model_parallel_size = 2 recipe.optim.config.lr = 5e-6 - elif peft_scheme.lower() == 'lora': - recipe.peft = run.Config(LoRA) + elif peft_scheme.lower() in ['lora', 'dora']: + recipe.peft = run.Config(PEFT_STR2CLS[peft_scheme.lower()]) recipe.peft.dim = 8 recipe.peft.alpha = 16 - recipe.peft.target_modules = ['linear_qkv'] recipe.optim.config.use_distributed_optimizer = False # some settings currently do not function correctly with LoRA @@ -326,7 +326,8 @@ def finetune_performance_optimizations( Args: recipe (run.Partial): Base fine-tuning recipe to which performance optimizations will be added - peft_scheme (str): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. + Allowed values: 'lora'/'dora'/'none'/None. Returns: run.Partial: Partial configuration for performance-optimized fine-tuning. @@ -356,6 +357,8 @@ def finetune_performance_optimizations( tp_comm_overlap=False, ) ) + else: + recipe.peft.target_modules = ['linear_qkv'] recipe.trainer.callbacks.append(run.Config(TimingCallback)) recipe.trainer.callbacks.append( diff --git a/nemo/collections/llm/recipes/llama3_8b_16k.py b/nemo/collections/llm/recipes/llama3_8b_16k.py index d6c1677a3b4b..b81d01c6ec9a 100644 --- a/nemo/collections/llm/recipes/llama3_8b_16k.py +++ b/nemo/collections/llm/recipes/llama3_8b_16k.py @@ -15,8 +15,8 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch from nemo.collections.llm.api import finetune, pretrain diff --git a/nemo/collections/llm/recipes/llama3_8b_64k.py b/nemo/collections/llm/recipes/llama3_8b_64k.py index 692347ea8dd0..ff176fb372bb 100644 --- a/nemo/collections/llm/recipes/llama3_8b_64k.py +++ b/nemo/collections/llm/recipes/llama3_8b_64k.py @@ -15,8 +15,8 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch from nemo.collections.llm.api import finetune, pretrain diff --git a/nemo/collections/llm/recipes/log/default.py b/nemo/collections/llm/recipes/log/default.py index d83580a1a543..023e4e459d5f 100644 --- a/nemo/collections/llm/recipes/log/default.py +++ b/nemo/collections/llm/recipes/log/default.py @@ -16,8 +16,8 @@ from datetime import timedelta from typing import Optional +from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger from nemo_run import Config, cli -from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger from nemo import lightning as nl diff --git a/nemo/collections/llm/recipes/mamba2_130m.py b/nemo/collections/llm/recipes/mamba2_130m.py index 08640604a112..e70fec03b3fb 100644 --- a/nemo/collections/llm/recipes/mamba2_130m.py +++ b/nemo/collections/llm/recipes/mamba2_130m.py @@ -15,11 +15,11 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch +from lightning.pytorch.callbacks.callback import Callback from megatron.core.distributed import DistributedDataParallelConfig -from pytorch_lightning.callbacks.callback import Callback from nemo import lightning as nl from nemo.collections import llm @@ -67,6 +67,7 @@ def model(tokenizer_model: str = None) -> run.Config[pl.LightningModule]: ) +@run.cli.factory(target=finetune, name=NAME) def trainer( tensor_parallelism: int = 1, pipeline_parallelism: int = 1, @@ -76,7 +77,11 @@ def trainer( sequence_parallelism: bool = False, num_nodes: int = 1, num_gpus_per_node: int = 8, - max_steps: int = 1168251, + max_steps: int = 100, + val_check_interval: int = 100, + limit_test_batches: int = 50, + limit_val_batches: int = 32, + log_every_n_steps: int = 10, callbacks: Optional[list[run.Config[Callback]]] = None, ) -> run.Config[nl.Trainer]: """ @@ -137,15 +142,15 @@ def trainer( accumulate_grad_batches=1, callbacks=callbacks, devices=num_gpus_per_node, - limit_test_batches=50, - limit_val_batches=32, - log_every_n_steps=10, max_steps=max_steps, num_nodes=num_nodes, plugins=bf16_mixed(), strategy=strategy, use_distributed_sampler=False, - val_check_interval=2000, + val_check_interval=val_check_interval, + limit_test_batches=limit_test_batches, + limit_val_batches=limit_val_batches, + log_every_n_steps=log_every_n_steps, ) return trainer @@ -158,6 +163,16 @@ def pretrain_recipe( tokenizer_model: str = None, num_nodes: int = 1, num_gpus_per_node: int = 8, + tensor_parallelism: int = 1, + pipeline_parallelism: int = 1, + max_steps: int = 100, + val_check_interval: int = 100, + limit_test_batches: int = 50, + limit_val_batches: int = 32, + log_every_n_steps: int = 10, + seq_length: int = 4096, + gbs: int = 8, + mbs: int = 1, fn=pretrain, ) -> run.Partial: """ @@ -193,16 +208,23 @@ def pretrain_recipe( fn, model=model(), trainer=trainer( + max_steps=max_steps, num_nodes=num_nodes, + tensor_parallelism=tensor_parallelism, + pipeline_parallelism=pipeline_parallelism, num_gpus_per_node=num_gpus_per_node, + val_check_interval=val_check_interval, + limit_test_batches=limit_test_batches, + limit_val_batches=limit_val_batches, + log_every_n_steps=log_every_n_steps, callbacks=[run.Config(TimingCallback)], ), data=run.Config( MockDataModule, - seq_length=4096, - global_batch_size=8, - micro_batch_size=1, - tokenizer=tokenizer(tokenizer_model=tokenizer_model), + seq_length=seq_length, + global_batch_size=gbs, + micro_batch_size=mbs, + tokenizer=tokenizer(), ), log=default_log(dir=dir, name=name, tensorboard_logger=tensorboard_logger(name=name)), optim=distributed_fused_adam_with_cosine_annealing(max_lr=3e-4), @@ -218,6 +240,14 @@ def finetune_recipe( tokenizer_model: str = None, num_nodes: int = 1, num_gpus_per_node: int = 8, + tensor_model_parallel_size: int = 1, + pipeline_model_parallel_size: int = 1, + seq_length: int = 4096, + max_steps: int = 100, + val_check_interval: int = 100, + limit_test_batches: int = 50, + limit_val_batches: int = 32, + log_every_n_steps: int = 10, gbs: int = 8, mbs: int = 1, peft_scheme: Optional[str] = 'none', @@ -266,8 +296,8 @@ def finetune_recipe( ) strategy = run.Config( nl.MegatronStrategy, - tensor_model_parallel_size=1, - pipeline_model_parallel_size=1, + tensor_model_parallel_size=tensor_model_parallel_size, + pipeline_model_parallel_size=pipeline_model_parallel_size, gradient_as_bucket_view=True, ckpt_load_optimizer=False, ckpt_save_optimizer=False, @@ -283,10 +313,11 @@ def finetune_recipe( accelerator="gpu", accumulate_grad_batches=1, devices=num_gpus_per_node, - limit_test_batches=10, - limit_val_batches=10, - log_every_n_steps=20, - max_steps=100, + max_steps=max_steps, + val_check_interval=val_check_interval, + limit_test_batches=limit_test_batches, + limit_val_batches=limit_val_batches, + log_every_n_steps=log_every_n_steps, num_nodes=num_nodes, plugins=run.Config( nl.MegatronMixedPrecision, @@ -296,7 +327,6 @@ def finetune_recipe( callbacks=[checkpoint_callback], strategy=strategy, use_distributed_sampler=False, - val_check_interval=20, ) recipe = run.Partial( llm.finetune, @@ -304,7 +334,7 @@ def finetune_recipe( trainer=trainer, data=run.Config( llm.SquadDataModule, - seq_length=2048, + seq_length=seq_length, global_batch_size=gbs, micro_batch_size=mbs, tokenizer=tokenizer(tokenizer_model=tokenizer_model), diff --git a/nemo/collections/llm/recipes/mamba2_1_3b.py b/nemo/collections/llm/recipes/mamba2_1_3b.py index 58eaf049b059..aaa263078686 100644 --- a/nemo/collections/llm/recipes/mamba2_1_3b.py +++ b/nemo/collections/llm/recipes/mamba2_1_3b.py @@ -15,11 +15,11 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch +from lightning.pytorch.callbacks.callback import Callback from megatron.core.distributed import DistributedDataParallelConfig -from pytorch_lightning.callbacks.callback import Callback from nemo import lightning as nl from nemo.collections import llm @@ -67,6 +67,7 @@ def model(tokenizer_model: str = None) -> run.Config[pl.LightningModule]: ) +@run.cli.factory(target=finetune, name=NAME) def trainer( tensor_parallelism: int = 1, pipeline_parallelism: int = 1, @@ -76,7 +77,11 @@ def trainer( sequence_parallelism: bool = False, num_nodes: int = 1, num_gpus_per_node: int = 8, - max_steps: int = 1168251, + max_steps: int = 100, + val_check_interval: int = 100, + limit_test_batches: int = 50, + limit_val_batches: int = 32, + log_every_n_steps: int = 10, callbacks: Optional[list[run.Config[Callback]]] = None, ) -> run.Config[nl.Trainer]: """ @@ -137,15 +142,15 @@ def trainer( accumulate_grad_batches=1, callbacks=callbacks, devices=num_gpus_per_node, - limit_test_batches=50, - limit_val_batches=32, - log_every_n_steps=10, max_steps=max_steps, num_nodes=num_nodes, plugins=bf16_mixed(), strategy=strategy, use_distributed_sampler=False, - val_check_interval=2000, + val_check_interval=val_check_interval, + limit_test_batches=limit_test_batches, + limit_val_batches=limit_val_batches, + log_every_n_steps=log_every_n_steps, ) return trainer @@ -157,7 +162,17 @@ def pretrain_recipe( name: str = "default", tokenizer_model: str = None, num_nodes: int = 1, - num_gpus_per_node: int = 8, + num_gpus_per_node: int = 1, + tensor_parallelism: int = 1, + pipeline_parallelism: int = 1, + max_steps: int = 100, + val_check_interval: int = 100, + limit_test_batches: int = 50, + limit_val_batches: int = 32, + log_every_n_steps: int = 10, + seq_length: int = 4096, + gbs: int = 8, + mbs: int = 1, fn=pretrain, ) -> run.Partial: """ @@ -191,17 +206,24 @@ def pretrain_recipe( """ return run.Partial( fn, - model=model(), + model=model(tokenizer_model=tokenizer_model), trainer=trainer( + max_steps=max_steps, num_nodes=num_nodes, + tensor_parallelism=tensor_parallelism, + pipeline_parallelism=pipeline_parallelism, num_gpus_per_node=num_gpus_per_node, + val_check_interval=val_check_interval, + limit_test_batches=limit_test_batches, + limit_val_batches=limit_val_batches, + log_every_n_steps=log_every_n_steps, callbacks=[run.Config(TimingCallback)], ), data=run.Config( MockDataModule, - seq_length=4096, - global_batch_size=8, - micro_batch_size=1, + seq_length=seq_length, + global_batch_size=gbs, + micro_batch_size=mbs, tokenizer=tokenizer(tokenizer_model=tokenizer_model), ), log=default_log(dir=dir, name=name, tensorboard_logger=tensorboard_logger(name=name)), @@ -217,7 +239,15 @@ def finetune_recipe( resume_path: str = None, tokenizer_model: str = None, num_nodes: int = 1, - num_gpus_per_node: int = 8, + num_gpus_per_node: int = 1, + tensor_model_parallel_size: int = 1, + pipeline_model_parallel_size: int = 1, + seq_length: int = 4096, + max_steps: int = 100, + val_check_interval: int = 100, + limit_test_batches: int = 50, + limit_val_batches: int = 32, + log_every_n_steps: int = 10, gbs: int = 8, mbs: int = 1, peft_scheme: Optional[str] = 'none', @@ -266,8 +296,8 @@ def finetune_recipe( ) strategy = run.Config( nl.MegatronStrategy, - tensor_model_parallel_size=1, - pipeline_model_parallel_size=1, + tensor_model_parallel_size=tensor_model_parallel_size, + pipeline_model_parallel_size=pipeline_model_parallel_size, gradient_as_bucket_view=True, ckpt_load_optimizer=False, ckpt_save_optimizer=False, @@ -283,10 +313,11 @@ def finetune_recipe( accelerator="gpu", accumulate_grad_batches=1, devices=num_gpus_per_node, - limit_test_batches=10, - limit_val_batches=10, - log_every_n_steps=20, - max_steps=100, + max_steps=max_steps, + val_check_interval=val_check_interval, + limit_test_batches=limit_test_batches, + limit_val_batches=limit_val_batches, + log_every_n_steps=log_every_n_steps, num_nodes=num_nodes, plugins=run.Config( nl.MegatronMixedPrecision, @@ -296,7 +327,6 @@ def finetune_recipe( callbacks=[checkpoint_callback], strategy=strategy, use_distributed_sampler=False, - val_check_interval=20, ) recipe = run.Partial( llm.finetune, @@ -304,7 +334,7 @@ def finetune_recipe( trainer=trainer, data=run.Config( llm.SquadDataModule, - seq_length=2048, + seq_length=seq_length, global_batch_size=gbs, micro_batch_size=mbs, tokenizer=tokenizer(tokenizer_model=tokenizer_model), diff --git a/nemo/collections/llm/recipes/mamba2_2_7b.py b/nemo/collections/llm/recipes/mamba2_2_7b.py index 5cb37c6a02a5..b4fd5b487b6a 100644 --- a/nemo/collections/llm/recipes/mamba2_2_7b.py +++ b/nemo/collections/llm/recipes/mamba2_2_7b.py @@ -15,11 +15,11 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch +from lightning.pytorch.callbacks.callback import Callback from megatron.core.distributed import DistributedDataParallelConfig -from pytorch_lightning.callbacks.callback import Callback from nemo import lightning as nl from nemo.collections import llm @@ -67,6 +67,7 @@ def model(tokenizer_model: str = None) -> run.Config[pl.LightningModule]: ) +@run.cli.factory(target=finetune, name=NAME) def trainer( tensor_parallelism: int = 1, pipeline_parallelism: int = 1, @@ -76,7 +77,11 @@ def trainer( sequence_parallelism: bool = False, num_nodes: int = 1, num_gpus_per_node: int = 8, - max_steps: int = 1168251, + max_steps: int = 100, + val_check_interval: int = 100, + limit_test_batches: int = 50, + limit_val_batches: int = 32, + log_every_n_steps: int = 10, callbacks: Optional[list[run.Config[Callback]]] = None, ) -> run.Config[nl.Trainer]: """ @@ -137,15 +142,15 @@ def trainer( accumulate_grad_batches=1, callbacks=callbacks, devices=num_gpus_per_node, - limit_test_batches=50, - limit_val_batches=32, - log_every_n_steps=10, max_steps=max_steps, num_nodes=num_nodes, plugins=bf16_mixed(), strategy=strategy, use_distributed_sampler=False, - val_check_interval=2000, + val_check_interval=val_check_interval, + limit_test_batches=limit_test_batches, + limit_val_batches=limit_val_batches, + log_every_n_steps=log_every_n_steps, ) return trainer @@ -158,6 +163,16 @@ def pretrain_recipe( tokenizer_model: str = None, num_nodes: int = 1, num_gpus_per_node: int = 8, + tensor_parallelism: int = 1, + pipeline_parallelism: int = 1, + max_steps: int = 100, + val_check_interval: int = 100, + limit_test_batches: int = 50, + limit_val_batches: int = 32, + log_every_n_steps: int = 10, + seq_length: int = 4096, + gbs: int = 8, + mbs: int = 1, fn=pretrain, ) -> run.Partial: """ @@ -193,16 +208,23 @@ def pretrain_recipe( fn, model=model(), trainer=trainer( + max_steps=max_steps, num_nodes=num_nodes, + tensor_parallelism=tensor_parallelism, + pipeline_parallelism=pipeline_parallelism, num_gpus_per_node=num_gpus_per_node, + val_check_interval=val_check_interval, + limit_test_batches=limit_test_batches, + limit_val_batches=limit_val_batches, + log_every_n_steps=log_every_n_steps, callbacks=[run.Config(TimingCallback)], ), data=run.Config( MockDataModule, - seq_length=4096, - global_batch_size=8, - micro_batch_size=1, - tokenizer=tokenizer(tokenizer_model=tokenizer_model), + seq_length=seq_length, + global_batch_size=gbs, + micro_batch_size=mbs, + tokenizer=tokenizer(), ), log=default_log(dir=dir, name=name, tensorboard_logger=tensorboard_logger(name=name)), optim=distributed_fused_adam_with_cosine_annealing(max_lr=3e-4), @@ -218,6 +240,14 @@ def finetune_recipe( tokenizer_model: str = None, num_nodes: int = 1, num_gpus_per_node: int = 8, + tensor_model_parallel_size: int = 1, + pipeline_model_parallel_size: int = 1, + seq_length: int = 4096, + max_steps: int = 100, + val_check_interval: int = 100, + limit_test_batches: int = 50, + limit_val_batches: int = 32, + log_every_n_steps: int = 10, gbs: int = 8, mbs: int = 1, peft_scheme: Optional[str] = 'none', @@ -266,8 +296,8 @@ def finetune_recipe( ) strategy = run.Config( nl.MegatronStrategy, - tensor_model_parallel_size=1, - pipeline_model_parallel_size=1, + tensor_model_parallel_size=tensor_model_parallel_size, + pipeline_model_parallel_size=pipeline_model_parallel_size, gradient_as_bucket_view=True, ckpt_load_optimizer=False, ckpt_save_optimizer=False, @@ -283,10 +313,11 @@ def finetune_recipe( accelerator="gpu", accumulate_grad_batches=1, devices=num_gpus_per_node, - limit_test_batches=10, - limit_val_batches=10, - log_every_n_steps=20, - max_steps=100, + max_steps=max_steps, + val_check_interval=val_check_interval, + limit_test_batches=limit_test_batches, + limit_val_batches=limit_val_batches, + log_every_n_steps=log_every_n_steps, num_nodes=num_nodes, plugins=run.Config( nl.MegatronMixedPrecision, @@ -296,7 +327,6 @@ def finetune_recipe( callbacks=[checkpoint_callback], strategy=strategy, use_distributed_sampler=False, - val_check_interval=20, ) recipe = run.Partial( llm.finetune, @@ -304,7 +334,7 @@ def finetune_recipe( trainer=trainer, data=run.Config( llm.SquadDataModule, - seq_length=2048, + seq_length=seq_length, global_batch_size=gbs, micro_batch_size=mbs, tokenizer=tokenizer(tokenizer_model=tokenizer_model), diff --git a/nemo/collections/llm/recipes/mamba2_370m.py b/nemo/collections/llm/recipes/mamba2_370m.py index bb8bddc4045a..6fa619b33486 100644 --- a/nemo/collections/llm/recipes/mamba2_370m.py +++ b/nemo/collections/llm/recipes/mamba2_370m.py @@ -15,11 +15,11 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch +from lightning.pytorch.callbacks.callback import Callback from megatron.core.distributed import DistributedDataParallelConfig -from pytorch_lightning.callbacks.callback import Callback from nemo import lightning as nl from nemo.collections import llm @@ -67,6 +67,7 @@ def model(tokenizer_model: str = None) -> run.Config[pl.LightningModule]: ) +@run.cli.factory(target=finetune, name=NAME) def trainer( tensor_parallelism: int = 1, pipeline_parallelism: int = 1, @@ -76,7 +77,11 @@ def trainer( sequence_parallelism: bool = False, num_nodes: int = 1, num_gpus_per_node: int = 8, - max_steps: int = 1168251, + max_steps: int = 100, + val_check_interval: int = 100, + limit_test_batches: int = 50, + limit_val_batches: int = 32, + log_every_n_steps: int = 10, callbacks: Optional[list[run.Config[Callback]]] = None, ) -> run.Config[nl.Trainer]: """ @@ -137,15 +142,15 @@ def trainer( accumulate_grad_batches=1, callbacks=callbacks, devices=num_gpus_per_node, - limit_test_batches=50, - limit_val_batches=32, - log_every_n_steps=10, max_steps=max_steps, num_nodes=num_nodes, plugins=bf16_mixed(), strategy=strategy, use_distributed_sampler=False, - val_check_interval=2000, + val_check_interval=val_check_interval, + limit_test_batches=limit_test_batches, + limit_val_batches=limit_val_batches, + log_every_n_steps=log_every_n_steps, ) return trainer @@ -158,6 +163,16 @@ def pretrain_recipe( tokenizer_model: str = None, num_nodes: int = 1, num_gpus_per_node: int = 8, + tensor_parallelism: int = 1, + pipeline_parallelism: int = 1, + max_steps: int = 100, + val_check_interval: int = 100, + limit_test_batches: int = 50, + limit_val_batches: int = 32, + log_every_n_steps: int = 10, + seq_length: int = 4096, + gbs: int = 8, + mbs: int = 1, fn=pretrain, ) -> run.Partial: """ @@ -193,16 +208,23 @@ def pretrain_recipe( fn, model=model(), trainer=trainer( + max_steps=max_steps, num_nodes=num_nodes, + tensor_parallelism=tensor_parallelism, + pipeline_parallelism=pipeline_parallelism, num_gpus_per_node=num_gpus_per_node, + val_check_interval=val_check_interval, + limit_test_batches=limit_test_batches, + limit_val_batches=limit_val_batches, + log_every_n_steps=log_every_n_steps, callbacks=[run.Config(TimingCallback)], ), data=run.Config( MockDataModule, - seq_length=4096, - global_batch_size=8, - micro_batch_size=1, - tokenizer=tokenizer(tokenizer_model=tokenizer_model), + seq_length=seq_length, + global_batch_size=gbs, + micro_batch_size=mbs, + tokenizer=tokenizer(), ), log=default_log(dir=dir, name=name, tensorboard_logger=tensorboard_logger(name=name)), optim=distributed_fused_adam_with_cosine_annealing(max_lr=3e-4), @@ -218,6 +240,14 @@ def finetune_recipe( tokenizer_model: str = None, num_nodes: int = 1, num_gpus_per_node: int = 8, + tensor_model_parallel_size: int = 1, + pipeline_model_parallel_size: int = 1, + seq_length: int = 4096, + max_steps: int = 100, + val_check_interval: int = 100, + limit_test_batches: int = 50, + limit_val_batches: int = 32, + log_every_n_steps: int = 10, gbs: int = 8, mbs: int = 1, peft_scheme: Optional[str] = 'none', @@ -266,8 +296,8 @@ def finetune_recipe( ) strategy = run.Config( nl.MegatronStrategy, - tensor_model_parallel_size=1, - pipeline_model_parallel_size=1, + tensor_model_parallel_size=tensor_model_parallel_size, + pipeline_model_parallel_size=pipeline_model_parallel_size, gradient_as_bucket_view=True, ckpt_load_optimizer=False, ckpt_save_optimizer=False, @@ -283,10 +313,11 @@ def finetune_recipe( accelerator="gpu", accumulate_grad_batches=1, devices=num_gpus_per_node, - limit_test_batches=10, - limit_val_batches=10, - log_every_n_steps=20, - max_steps=100, + max_steps=max_steps, + val_check_interval=val_check_interval, + limit_test_batches=limit_test_batches, + limit_val_batches=limit_val_batches, + log_every_n_steps=log_every_n_steps, num_nodes=num_nodes, plugins=run.Config( nl.MegatronMixedPrecision, @@ -296,7 +327,6 @@ def finetune_recipe( callbacks=[checkpoint_callback], strategy=strategy, use_distributed_sampler=False, - val_check_interval=20, ) recipe = run.Partial( llm.finetune, @@ -304,7 +334,7 @@ def finetune_recipe( trainer=trainer, data=run.Config( llm.SquadDataModule, - seq_length=2048, + seq_length=seq_length, global_batch_size=gbs, micro_batch_size=mbs, tokenizer=tokenizer(tokenizer_model=tokenizer_model), diff --git a/nemo/collections/llm/recipes/mamba2_780m.py b/nemo/collections/llm/recipes/mamba2_780m.py index 2f6ab6717ae1..45d28f82f779 100644 --- a/nemo/collections/llm/recipes/mamba2_780m.py +++ b/nemo/collections/llm/recipes/mamba2_780m.py @@ -15,11 +15,11 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch +from lightning.pytorch.callbacks.callback import Callback from megatron.core.distributed import DistributedDataParallelConfig -from pytorch_lightning.callbacks.callback import Callback from nemo import lightning as nl from nemo.collections import llm @@ -67,6 +67,7 @@ def model(tokenizer_model: str = None) -> run.Config[pl.LightningModule]: ) +@run.cli.factory(target=finetune, name=NAME) def trainer( tensor_parallelism: int = 1, pipeline_parallelism: int = 1, @@ -76,7 +77,11 @@ def trainer( sequence_parallelism: bool = False, num_nodes: int = 1, num_gpus_per_node: int = 8, - max_steps: int = 1168251, + max_steps: int = 100, + val_check_interval: int = 100, + limit_test_batches: int = 50, + limit_val_batches: int = 32, + log_every_n_steps: int = 10, callbacks: Optional[list[run.Config[Callback]]] = None, ) -> run.Config[nl.Trainer]: """ @@ -137,15 +142,15 @@ def trainer( accumulate_grad_batches=1, callbacks=callbacks, devices=num_gpus_per_node, - limit_test_batches=50, - limit_val_batches=32, - log_every_n_steps=10, max_steps=max_steps, num_nodes=num_nodes, plugins=bf16_mixed(), strategy=strategy, use_distributed_sampler=False, - val_check_interval=2000, + val_check_interval=val_check_interval, + limit_test_batches=limit_test_batches, + limit_val_batches=limit_val_batches, + log_every_n_steps=log_every_n_steps, ) return trainer @@ -158,6 +163,16 @@ def pretrain_recipe( tokenizer_model: str = None, num_nodes: int = 1, num_gpus_per_node: int = 8, + tensor_parallelism: int = 1, + pipeline_parallelism: int = 1, + max_steps: int = 100, + val_check_interval: int = 100, + limit_test_batches: int = 50, + limit_val_batches: int = 32, + log_every_n_steps: int = 10, + seq_length: int = 4096, + gbs: int = 8, + mbs: int = 1, fn=pretrain, ) -> run.Partial: """ @@ -193,16 +208,23 @@ def pretrain_recipe( fn, model=model(), trainer=trainer( + max_steps=max_steps, num_nodes=num_nodes, + tensor_parallelism=tensor_parallelism, + pipeline_parallelism=pipeline_parallelism, num_gpus_per_node=num_gpus_per_node, + val_check_interval=val_check_interval, + limit_test_batches=limit_test_batches, + limit_val_batches=limit_val_batches, + log_every_n_steps=log_every_n_steps, callbacks=[run.Config(TimingCallback)], ), data=run.Config( MockDataModule, - seq_length=4096, - global_batch_size=8, - micro_batch_size=1, - tokenizer=tokenizer(tokenizer_model=tokenizer_model), + seq_length=seq_length, + global_batch_size=gbs, + micro_batch_size=mbs, + tokenizer=tokenizer(), ), log=default_log(dir=dir, name=name, tensorboard_logger=tensorboard_logger(name=name)), optim=distributed_fused_adam_with_cosine_annealing(max_lr=3e-4), @@ -218,6 +240,14 @@ def finetune_recipe( tokenizer_model: str = None, num_nodes: int = 1, num_gpus_per_node: int = 8, + tensor_model_parallel_size: int = 1, + pipeline_model_parallel_size: int = 1, + seq_length: int = 4096, + max_steps: int = 100, + val_check_interval: int = 100, + limit_test_batches: int = 50, + limit_val_batches: int = 32, + log_every_n_steps: int = 10, gbs: int = 8, mbs: int = 1, peft_scheme: Optional[str] = 'none', @@ -266,8 +296,8 @@ def finetune_recipe( ) strategy = run.Config( nl.MegatronStrategy, - tensor_model_parallel_size=1, - pipeline_model_parallel_size=1, + tensor_model_parallel_size=tensor_model_parallel_size, + pipeline_model_parallel_size=pipeline_model_parallel_size, gradient_as_bucket_view=True, ckpt_load_optimizer=False, ckpt_save_optimizer=False, @@ -283,10 +313,11 @@ def finetune_recipe( accelerator="gpu", accumulate_grad_batches=1, devices=num_gpus_per_node, - limit_test_batches=10, - limit_val_batches=10, - log_every_n_steps=20, - max_steps=100, + max_steps=max_steps, + val_check_interval=val_check_interval, + limit_test_batches=limit_test_batches, + limit_val_batches=limit_val_batches, + log_every_n_steps=log_every_n_steps, num_nodes=num_nodes, plugins=run.Config( nl.MegatronMixedPrecision, @@ -296,7 +327,6 @@ def finetune_recipe( callbacks=[checkpoint_callback], strategy=strategy, use_distributed_sampler=False, - val_check_interval=20, ) recipe = run.Partial( llm.finetune, @@ -304,7 +334,7 @@ def finetune_recipe( trainer=trainer, data=run.Config( llm.SquadDataModule, - seq_length=2048, + seq_length=seq_length, global_batch_size=gbs, micro_batch_size=mbs, tokenizer=tokenizer(tokenizer_model=tokenizer_model), diff --git a/nemo/collections/llm/recipes/mamba2_8b.py b/nemo/collections/llm/recipes/mamba2_8b.py index 58883deba732..8f8384b45059 100644 --- a/nemo/collections/llm/recipes/mamba2_8b.py +++ b/nemo/collections/llm/recipes/mamba2_8b.py @@ -15,11 +15,11 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch +from lightning.pytorch.callbacks.callback import Callback from megatron.core.distributed import DistributedDataParallelConfig -from pytorch_lightning.callbacks.callback import Callback from nemo import lightning as nl from nemo.collections import llm @@ -67,6 +67,7 @@ def model(tokenizer_model: str = None) -> run.Config[pl.LightningModule]: ) +@run.cli.factory(name=NAME) def trainer( tensor_parallelism: int = 8, pipeline_parallelism: int = 1, @@ -76,7 +77,11 @@ def trainer( sequence_parallelism: bool = False, num_nodes: int = 1, num_gpus_per_node: int = 8, - max_steps: int = 1168251, + max_steps: int = 100, + val_check_interval: int = 100, + limit_test_batches: int = 50, + limit_val_batches: int = 32, + log_every_n_steps: int = 10, callbacks: Optional[list[run.Config[Callback]]] = None, ) -> run.Config[nl.Trainer]: """ @@ -137,15 +142,15 @@ def trainer( accumulate_grad_batches=1, callbacks=callbacks, devices=num_gpus_per_node, - limit_test_batches=50, - limit_val_batches=32, - log_every_n_steps=10, max_steps=max_steps, num_nodes=num_nodes, plugins=bf16_mixed(), strategy=strategy, use_distributed_sampler=False, - val_check_interval=2000, + val_check_interval=val_check_interval, + limit_test_batches=limit_test_batches, + limit_val_batches=limit_val_batches, + log_every_n_steps=log_every_n_steps, ) return trainer @@ -158,6 +163,16 @@ def pretrain_recipe( tokenizer_model: str = None, num_nodes: int = 1, num_gpus_per_node: int = 8, + tensor_parallelism: int = 8, + pipeline_parallelism: int = 1, + max_steps: int = 100, + val_check_interval: int = 100, + limit_test_batches: int = 50, + limit_val_batches: int = 32, + log_every_n_steps: int = 10, + seq_length: int = 4096, + gbs: int = 8, + mbs: int = 1, fn=pretrain, ) -> run.Partial: """ @@ -191,17 +206,24 @@ def pretrain_recipe( """ return run.Partial( fn, - model=model(), + model=model(tokenizer_model=tokenizer_model), trainer=trainer( + max_steps=max_steps, num_nodes=num_nodes, + tensor_parallelism=tensor_parallelism, + pipeline_parallelism=pipeline_parallelism, num_gpus_per_node=num_gpus_per_node, + val_check_interval=val_check_interval, + limit_test_batches=limit_test_batches, + limit_val_batches=limit_val_batches, + log_every_n_steps=log_every_n_steps, callbacks=[run.Config(TimingCallback)], ), data=run.Config( MockDataModule, - seq_length=4096, - global_batch_size=8, - micro_batch_size=1, + seq_length=seq_length, + global_batch_size=gbs, + micro_batch_size=mbs, tokenizer=tokenizer(tokenizer_model=tokenizer_model), ), log=default_log(dir=dir, name=name, tensorboard_logger=tensorboard_logger(name=name)), @@ -218,6 +240,14 @@ def finetune_recipe( name: str = "default", num_nodes: int = 1, num_gpus_per_node: int = 8, + tensor_model_parallel_size: int = 8, + pipeline_model_parallel_size: int = 1, + seq_length: int = 4096, + max_steps: int = 100, + val_check_interval: int = 100, + limit_test_batches: int = 50, + limit_val_batches: int = 32, + log_every_n_steps: int = 10, gbs: int = 8, mbs: int = 1, peft_scheme: Optional[str] = 'none', @@ -266,8 +296,8 @@ def finetune_recipe( ) strategy = run.Config( nl.MegatronStrategy, - tensor_model_parallel_size=8, - pipeline_model_parallel_size=1, + tensor_model_parallel_size=tensor_model_parallel_size, + pipeline_model_parallel_size=pipeline_model_parallel_size, gradient_as_bucket_view=True, ckpt_load_optimizer=False, ckpt_save_optimizer=False, @@ -283,10 +313,11 @@ def finetune_recipe( accelerator="gpu", accumulate_grad_batches=1, devices=num_gpus_per_node, - limit_test_batches=10, - limit_val_batches=10, - log_every_n_steps=20, - max_steps=100, + max_steps=max_steps, + val_check_interval=val_check_interval, + limit_test_batches=limit_test_batches, + limit_val_batches=limit_val_batches, + log_every_n_steps=log_every_n_steps, num_nodes=num_nodes, plugins=run.Config( nl.MegatronMixedPrecision, @@ -296,7 +327,6 @@ def finetune_recipe( callbacks=[checkpoint_callback], strategy=strategy, use_distributed_sampler=False, - val_check_interval=20, ) recipe = run.Partial( llm.finetune, @@ -304,7 +334,7 @@ def finetune_recipe( trainer=trainer, data=run.Config( llm.SquadDataModule, - seq_length=2048, + seq_length=seq_length, global_batch_size=gbs, micro_batch_size=mbs, tokenizer=tokenizer(tokenizer_model=tokenizer_model), diff --git a/nemo/collections/llm/recipes/mamba2_hybrid_8b.py b/nemo/collections/llm/recipes/mamba2_hybrid_8b.py index eff37da46fca..b91c8e228bc9 100644 --- a/nemo/collections/llm/recipes/mamba2_hybrid_8b.py +++ b/nemo/collections/llm/recipes/mamba2_hybrid_8b.py @@ -15,11 +15,11 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch +from lightning.pytorch.callbacks.callback import Callback from megatron.core.distributed import DistributedDataParallelConfig -from pytorch_lightning.callbacks.callback import Callback from nemo import lightning as nl from nemo.collections import llm @@ -39,7 +39,7 @@ def tokenizer(tokenizer_model: str = None) -> run.Config[pl.LightningModule]: return run.Config( get_nmt_tokenizer, - library='megatronNVIDIAMambaConfig8B', + library='megatron', model_name="GPTSentencePieceTokenizer", tokenizer_model=tokenizer_model, use_fast=True, @@ -69,6 +69,7 @@ def model(tokenizer_model: str = None) -> run.Config[pl.LightningModule]: ) +@run.cli.factory(target=finetune, name=NAME) def trainer( tensor_parallelism: int = 8, pipeline_parallelism: int = 1, @@ -78,7 +79,11 @@ def trainer( sequence_parallelism: bool = False, num_nodes: int = 1, num_gpus_per_node: int = 8, - max_steps: int = 1168251, + max_steps: int = 100, + val_check_interval: int = 100, + limit_test_batches: int = 50, + limit_val_batches: int = 32, + log_every_n_steps: int = 10, callbacks: Optional[list[run.Config[Callback]]] = None, ) -> run.Config[nl.Trainer]: """ @@ -139,15 +144,15 @@ def trainer( accumulate_grad_batches=1, callbacks=callbacks, devices=num_gpus_per_node, - limit_test_batches=50, - limit_val_batches=32, - log_every_n_steps=10, max_steps=max_steps, num_nodes=num_nodes, plugins=bf16_mixed(), strategy=strategy, use_distributed_sampler=False, - val_check_interval=2000, + val_check_interval=val_check_interval, + limit_test_batches=limit_test_batches, + limit_val_batches=limit_val_batches, + log_every_n_steps=log_every_n_steps, ) return trainer @@ -160,6 +165,16 @@ def pretrain_recipe( tokenizer_model: str = None, num_nodes: int = 1, num_gpus_per_node: int = 8, + tensor_parallelism: int = 8, + pipeline_parallelism: int = 1, + max_steps: int = 100, + val_check_interval: int = 100, + limit_test_batches: int = 50, + limit_val_batches: int = 32, + log_every_n_steps: int = 10, + seq_length: int = 4096, + gbs: int = 8, + mbs: int = 1, fn=pretrain, ) -> run.Partial: """ @@ -193,17 +208,24 @@ def pretrain_recipe( """ return run.Partial( fn, - model=model(), + model=model(tokenizer_model=tokenizer_model), trainer=trainer( + max_steps=max_steps, num_nodes=num_nodes, + tensor_parallelism=tensor_parallelism, + pipeline_parallelism=pipeline_parallelism, num_gpus_per_node=num_gpus_per_node, + val_check_interval=val_check_interval, + limit_test_batches=limit_test_batches, + limit_val_batches=limit_val_batches, + log_every_n_steps=log_every_n_steps, callbacks=[run.Config(TimingCallback)], ), data=run.Config( MockDataModule, - seq_length=4096, - global_batch_size=8, - micro_batch_size=1, + seq_length=seq_length, + global_batch_size=gbs, + micro_batch_size=mbs, tokenizer=tokenizer(tokenizer_model=tokenizer_model), ), log=default_log(dir=dir, name=name, tensorboard_logger=tensorboard_logger(name=name)), @@ -220,6 +242,14 @@ def finetune_recipe( name: str = "default", num_nodes: int = 1, num_gpus_per_node: int = 8, + tensor_model_parallel_size: int = 8, + pipeline_model_parallel_size: int = 1, + seq_length: int = 4096, + max_steps: int = 100, + val_check_interval: int = 100, + limit_test_batches: int = 50, + limit_val_batches: int = 32, + log_every_n_steps: int = 10, gbs: int = 8, mbs: int = 1, peft_scheme: Optional[str] = 'none', @@ -268,8 +298,8 @@ def finetune_recipe( ) strategy = run.Config( nl.MegatronStrategy, - tensor_model_parallel_size=8, - pipeline_model_parallel_size=1, + tensor_model_parallel_size=tensor_model_parallel_size, + pipeline_model_parallel_size=pipeline_model_parallel_size, gradient_as_bucket_view=True, ckpt_load_optimizer=False, ckpt_save_optimizer=False, @@ -285,10 +315,11 @@ def finetune_recipe( accelerator="gpu", accumulate_grad_batches=1, devices=num_gpus_per_node, - limit_test_batches=10, - limit_val_batches=10, - log_every_n_steps=20, - max_steps=100, + max_steps=max_steps, + val_check_interval=val_check_interval, + limit_test_batches=limit_test_batches, + limit_val_batches=limit_val_batches, + log_every_n_steps=log_every_n_steps, num_nodes=num_nodes, plugins=run.Config( nl.MegatronMixedPrecision, @@ -298,7 +329,6 @@ def finetune_recipe( callbacks=[checkpoint_callback], strategy=strategy, use_distributed_sampler=False, - val_check_interval=20, ) recipe = run.Partial( llm.finetune, @@ -306,7 +336,7 @@ def finetune_recipe( trainer=trainer, data=run.Config( llm.SquadDataModule, - seq_length=2048, + seq_length=seq_length, global_batch_size=gbs, micro_batch_size=mbs, tokenizer=tokenizer(tokenizer_model=tokenizer_model), diff --git a/nemo/collections/llm/recipes/mistral_7b.py b/nemo/collections/llm/recipes/mistral_7b.py index 7685bcd3ace6..9e2d2e256fbe 100644 --- a/nemo/collections/llm/recipes/mistral_7b.py +++ b/nemo/collections/llm/recipes/mistral_7b.py @@ -15,18 +15,17 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch +from lightning.pytorch.callbacks.callback import Callback from megatron.core.distributed import DistributedDataParallelConfig -from pytorch_lightning.callbacks.callback import Callback from nemo import lightning as nl from nemo.collections.llm.api import finetune, pretrain from nemo.collections.llm.gpt.data.mock import MockDataModule -from nemo.collections.llm.gpt.data.squad import SquadDataModule from nemo.collections.llm.gpt.model.mistral import MistralConfig7B, MistralModel -from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.peft import PEFT_STR2CLS from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing @@ -207,8 +206,10 @@ def finetune_recipe( name (str): Name of the fine-tuning run. num_nodes (int): Number of compute nodes to use. num_gpus_per_node (int): Number of GPUs per node. - peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. - packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training efficiency. Default sequence length is 2048. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. + Allowed values: 'lora'/'dora'/'none'/None. + packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training + efficiency. Default sequence length is 2048. Returns: run.Partial: Partial configuration for fine-tuning. @@ -237,8 +238,8 @@ def finetune_recipe( if peft_scheme is None or peft_scheme.lower() == 'none': recipe.trainer.strategy.tensor_model_parallel_size = 2 recipe.optim.config.lr = 5e-6 - elif peft_scheme.lower() == 'lora': - recipe.peft = run.Config(LoRA) + elif peft_scheme.lower() in ['lora', 'dora']: + recipe.peft = run.Config(PEFT_STR2CLS[peft_scheme.lower()]) recipe.optim.config.lr = 1e-4 else: raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") diff --git a/nemo/collections/llm/recipes/mistral_nemo_12b.py b/nemo/collections/llm/recipes/mistral_nemo_12b.py index e6616826d9a8..a10f8ae804b8 100644 --- a/nemo/collections/llm/recipes/mistral_nemo_12b.py +++ b/nemo/collections/llm/recipes/mistral_nemo_12b.py @@ -15,18 +15,17 @@ from typing import Callable, Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch +from lightning.pytorch.callbacks.callback import Callback from megatron.core.distributed import DistributedDataParallelConfig -from pytorch_lightning.callbacks.callback import Callback from nemo import lightning as nl from nemo.collections.llm.api import finetune, pretrain from nemo.collections.llm.gpt.data.mock import MockDataModule -from nemo.collections.llm.gpt.data.squad import SquadDataModule from nemo.collections.llm.gpt.model.mistral import MistralModel, MistralNeMoConfig12B -from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.peft import PEFT_STR2CLS from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing @@ -255,8 +254,10 @@ def finetune_recipe( name (str): Name of the fine-tuning run. num_nodes (int): Number of compute nodes to use. num_gpus_per_node (int): Number of GPUs per node. - peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. - packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training efficiency. Default sequence length is 2048. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. + Allowed values: 'lora'/'dora'/'none'/None. + packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training + efficiency. Default sequence length is 2048. Returns: run.Partial: Partial configuration for fine-tuning. @@ -285,8 +286,10 @@ def finetune_recipe( ) if peft_scheme is None or peft_scheme.lower() == 'none': recipe.optim.config.lr = 5e-6 - elif peft_scheme.lower() == 'lora': - recipe.peft = run.Config(LoRA, target_modules=['linear_qkv', 'linear_proj'], dim=32) + elif peft_scheme.lower() in ['lora', 'dora']: + recipe.peft = run.Config( + PEFT_STR2CLS[peft_scheme.lower()], target_modules=['linear_qkv', 'linear_proj'], dim=32 + ) recipe.optim.config.lr = 1e-4 else: raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") diff --git a/nemo/collections/llm/recipes/mixtral_8x22b.py b/nemo/collections/llm/recipes/mixtral_8x22b.py index f768bf0499b1..ec1641a08d80 100644 --- a/nemo/collections/llm/recipes/mixtral_8x22b.py +++ b/nemo/collections/llm/recipes/mixtral_8x22b.py @@ -15,18 +15,17 @@ from typing import Callable, Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch +from lightning.pytorch.callbacks.callback import Callback from megatron.core.distributed import DistributedDataParallelConfig -from pytorch_lightning.callbacks.callback import Callback from nemo import lightning as nl from nemo.collections.llm.api import finetune, pretrain from nemo.collections.llm.gpt.data.mock import MockDataModule -from nemo.collections.llm.gpt.data.squad import SquadDataModule from nemo.collections.llm.gpt.model.mixtral import MixtralConfig8x22B, MixtralModel -from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.peft import PEFT_STR2CLS from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing @@ -227,7 +226,7 @@ def pretrain_performance_optimizations(recipe: run.Partial) -> run.Partial: ), run.Config( MegatronCommOverlapCallback, - overlap_param_gather_with_optimizer_step=False, # Currently disabled due to an issue with checkpointing + overlap_param_gather_with_optimizer_step=False, # Currently disabled due to issue with checkpointing align_param_gather=True, ), ] @@ -259,8 +258,10 @@ def finetune_recipe( name (str): Name of the fine-tuning run. num_nodes (int): Number of compute nodes to use. num_gpus_per_node (int): Number of GPUs per node. - peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. - packed_sequence (Optional[bool]): If true, fine-tuning sequences will be packed into batches up to the given maximum seq_length for better efficiency. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. + Allowed values: 'lora'/'dora'/'none'/None. + packed_sequence (Optional[bool]): If true, fine-tuning sequences will be packed into batches up to the given + maximum seq_length for better efficiency. Returns: run.Partial: Partial configuration for fine-tuning. @@ -286,8 +287,10 @@ def finetune_recipe( recipe.trainer.strategy.pipeline_model_parallel_size = 4 recipe.trainer.strategy.virtual_pipeline_model_parallel_size = 14 recipe.optim.config.lr = 5e-6 - elif peft_scheme.lower() == 'lora': - recipe.peft = run.Config(LoRA, target_modules=['linear_qkv', 'linear_proj'], dim=32) + elif peft_scheme.lower() in ['lora', 'dora']: + recipe.peft = run.Config( + PEFT_STR2CLS[peft_scheme.lower()], target_modules=['linear_qkv', 'linear_proj'], dim=32 + ) recipe.optim.config.lr = 1e-4 else: raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") diff --git a/nemo/collections/llm/recipes/mixtral_8x7b.py b/nemo/collections/llm/recipes/mixtral_8x7b.py index d4286a15843f..d06e22fc2180 100644 --- a/nemo/collections/llm/recipes/mixtral_8x7b.py +++ b/nemo/collections/llm/recipes/mixtral_8x7b.py @@ -15,18 +15,17 @@ from typing import Callable, Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch +from lightning.pytorch.callbacks.callback import Callback from megatron.core.distributed import DistributedDataParallelConfig -from pytorch_lightning.callbacks.callback import Callback from nemo import lightning as nl from nemo.collections.llm.api import finetune, pretrain from nemo.collections.llm.gpt.data.mock import MockDataModule -from nemo.collections.llm.gpt.data.squad import SquadDataModule from nemo.collections.llm.gpt.model.mixtral import MixtralConfig8x7B, MixtralModel -from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.peft import PEFT_STR2CLS from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing @@ -222,7 +221,7 @@ def pretrain_performance_optimizations(recipe: run.Partial) -> run.Partial: run.Config(MegatronTokenDropCallback), run.Config( MegatronCommOverlapCallback, - overlap_param_gather_with_optimizer_step=False, # Currently disabled due to an issue with checkpointing. + overlap_param_gather_with_optimizer_step=False, # Currently disabled due to issue with checkpointing. align_param_gather=True, ), ] @@ -254,8 +253,10 @@ def finetune_recipe( name (str): Name of the fine-tuning run. num_nodes (int): Number of compute nodes to use. num_gpus_per_node (int): Number of GPUs per node. - peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. - packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training efficiency. Default sequence length is 2048. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. + Allowed values: 'lora'/'dora'/'none'/None. + packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training + efficiency. Default sequence length is 2048. Returns: run.Partial: Partial configuration for fine-tuning. @@ -280,8 +281,10 @@ def finetune_recipe( recipe.trainer.strategy.pipeline_model_parallel_size = 4 recipe.trainer.strategy.virtual_pipeline_model_parallel_size = 8 recipe.optim.config.lr = 5e-6 - elif peft_scheme.lower() == 'lora': - recipe.peft = run.Config(LoRA, target_modules=['linear_qkv', 'linear_proj'], dim=32) + elif peft_scheme.lower() in ['lora', 'dora']: + recipe.peft = run.Config( + PEFT_STR2CLS[peft_scheme.lower()], target_modules=['linear_qkv', 'linear_proj'], dim=32 + ) recipe.optim.config.lr = 1e-4 else: raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") diff --git a/nemo/collections/llm/recipes/mixtral_8x7b_16k.py b/nemo/collections/llm/recipes/mixtral_8x7b_16k.py index 7cbfaf723544..499280cc8542 100644 --- a/nemo/collections/llm/recipes/mixtral_8x7b_16k.py +++ b/nemo/collections/llm/recipes/mixtral_8x7b_16k.py @@ -15,8 +15,8 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch from nemo.collections.llm.api import finetune, pretrain diff --git a/nemo/collections/llm/recipes/mixtral_8x7b_64k.py b/nemo/collections/llm/recipes/mixtral_8x7b_64k.py index 3606be5ec12b..e0702f7b2a63 100644 --- a/nemo/collections/llm/recipes/mixtral_8x7b_64k.py +++ b/nemo/collections/llm/recipes/mixtral_8x7b_64k.py @@ -15,8 +15,8 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch from nemo.collections.llm.api import finetune, pretrain diff --git a/nemo/collections/llm/recipes/nemotron.py b/nemo/collections/llm/recipes/nemotron.py index 104c3798567a..7982665eb3d5 100644 --- a/nemo/collections/llm/recipes/nemotron.py +++ b/nemo/collections/llm/recipes/nemotron.py @@ -14,11 +14,11 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch +from lightning.pytorch.callbacks.callback import Callback from megatron.core.distributed import DistributedDataParallelConfig -from pytorch_lightning.callbacks.callback import Callback from nemo import lightning as nl from nemo.collections.llm.gpt.model.nemotron import ( diff --git a/nemo/collections/llm/recipes/nemotron3_22b.py b/nemo/collections/llm/recipes/nemotron3_22b.py index 724e21f002e3..4c763301bc52 100644 --- a/nemo/collections/llm/recipes/nemotron3_22b.py +++ b/nemo/collections/llm/recipes/nemotron3_22b.py @@ -12,15 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Optional +from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch from nemo.collections.llm.api import finetune, pretrain from nemo.collections.llm.gpt.data.mock import MockDataModule -from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.peft import PEFT_STR2CLS from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger from nemo.collections.llm.recipes.nemotron import nemotron_model, nemotron_trainer @@ -239,8 +239,10 @@ def finetune_recipe( name (str): Name of the fine-tuning run. num_nodes (int): Number of compute nodes to use. num_gpus_per_node (int): Number of GPUs per node. - peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. - packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training efficiency. Default sequence length is 2048. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. + Allowed values: 'lora'/'dora'/'none'/None. + packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training + efficiency. Default sequence length is 2048. Returns: run.Partial: Partial configuration for fine-tuning. @@ -265,8 +267,8 @@ def finetune_recipe( if peft_scheme is None or peft_scheme.lower() == 'none': recipe.trainer.strategy.tensor_model_parallel_size = 8 recipe.optim.config.lr = 5e-6 - elif peft_scheme.lower() == 'lora': - recipe.peft = run.Config(LoRA) + elif peft_scheme.lower() in ['lora', 'dora']: + recipe.peft = run.Config(PEFT_STR2CLS[peft_scheme.lower()]) recipe.optim.config.lr = 1e-4 else: raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") diff --git a/nemo/collections/llm/recipes/nemotron3_22b_16k.py b/nemo/collections/llm/recipes/nemotron3_22b_16k.py index 81f4253ad37a..5ae58d1a757d 100644 --- a/nemo/collections/llm/recipes/nemotron3_22b_16k.py +++ b/nemo/collections/llm/recipes/nemotron3_22b_16k.py @@ -14,8 +14,8 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch from nemo.collections.llm.api import pretrain diff --git a/nemo/collections/llm/recipes/nemotron3_22b_64k.py b/nemo/collections/llm/recipes/nemotron3_22b_64k.py index 676694697e4c..22f6291cfadb 100644 --- a/nemo/collections/llm/recipes/nemotron3_22b_64k.py +++ b/nemo/collections/llm/recipes/nemotron3_22b_64k.py @@ -14,8 +14,8 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch from nemo.collections.llm.api import pretrain diff --git a/nemo/collections/llm/recipes/nemotron3_4b.py b/nemo/collections/llm/recipes/nemotron3_4b.py index e1c2ef345d7e..fc6f09a09358 100644 --- a/nemo/collections/llm/recipes/nemotron3_4b.py +++ b/nemo/collections/llm/recipes/nemotron3_4b.py @@ -14,13 +14,13 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch from nemo.collections.llm.api import finetune, pretrain from nemo.collections.llm.gpt.data.mock import MockDataModule -from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.peft import PEFT_STR2CLS from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger from nemo.collections.llm.recipes.nemotron import nemotron_model, nemotron_trainer @@ -191,8 +191,10 @@ def finetune_recipe( name (str): Name of the fine-tuning run. num_nodes (int): Number of compute nodes to use. num_gpus_per_node (int): Number of GPUs per node. - peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. - packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training efficiency. Default sequence length is 2048. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. + Allowed values: 'lora'/'dora'/'none'/None. + packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training + efficiency. Default sequence length is 2048. Returns: run.Partial: Partial configuration for fine-tuning. @@ -216,8 +218,8 @@ def finetune_recipe( ) if peft_scheme is None or peft_scheme.lower() == 'none': recipe.optim.config.lr = 5e-6 - elif peft_scheme.lower() == 'lora': - recipe.peft = run.Config(LoRA) + elif peft_scheme.lower() in ['lora', 'dora']: + recipe.peft = run.Config(PEFT_STR2CLS[peft_scheme.lower()]) recipe.optim.config.lr = 1e-4 else: raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") diff --git a/nemo/collections/llm/recipes/nemotron3_8b.py b/nemo/collections/llm/recipes/nemotron3_8b.py index 202efe658d83..f60463330cad 100644 --- a/nemo/collections/llm/recipes/nemotron3_8b.py +++ b/nemo/collections/llm/recipes/nemotron3_8b.py @@ -12,17 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Optional +from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch from nemo import lightning as nl from nemo.collections.llm.api import finetune, pretrain from nemo.collections.llm.gpt.data.mock import MockDataModule -from nemo.collections.llm.gpt.data.squad import SquadDataModule -from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.peft import PEFT_STR2CLS from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger from nemo.collections.llm.recipes.nemotron import nemotron_model, nemotron_trainer @@ -256,8 +255,10 @@ def finetune_recipe( name (str): Name of the fine-tuning run. num_nodes (int): Number of compute nodes to use. num_gpus_per_node (int): Number of GPUs per node. - peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. - packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training efficiency. Default sequence length is 2048. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. + Allowed values: 'lora'/'dora'/'none'/None. + packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training + efficiency. Default sequence length is 2048. Returns: run.Partial: Partial configuration for fine-tuning. @@ -282,8 +283,8 @@ def finetune_recipe( if peft_scheme is None or peft_scheme.lower() == 'none': recipe.trainer.strategy.tensor_model_parallel_size = 2 recipe.optim.config.lr = 5e-6 - elif peft_scheme.lower() == 'lora': - recipe.peft = run.Config(LoRA) + elif peft_scheme.lower() in ['lora', 'dora']: + recipe.peft = run.Config(PEFT_STR2CLS[peft_scheme.lower()]) recipe.optim.config.lr = 1e-4 else: raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") diff --git a/nemo/collections/llm/recipes/nemotron4_15b.py b/nemo/collections/llm/recipes/nemotron4_15b.py index 0f15c47c67b9..49f92fcc1616 100644 --- a/nemo/collections/llm/recipes/nemotron4_15b.py +++ b/nemo/collections/llm/recipes/nemotron4_15b.py @@ -12,15 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Optional +from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch from nemo.collections.llm.api import finetune, pretrain from nemo.collections.llm.gpt.data.mock import MockDataModule -from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.peft import PEFT_STR2CLS from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger from nemo.collections.llm.recipes.nemotron import nemotron_model, nemotron_trainer @@ -228,8 +228,10 @@ def finetune_recipe( name (str): Name of the fine-tuning run. num_nodes (int): Number of compute nodes to use. num_gpus_per_node (int): Number of GPUs per node. - peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. - packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training efficiency. Default sequence length is 2048. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. + Allowed values: 'lora'/'dora'/'none'/None. + packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training + efficiency. Default sequence length is 2048. Returns: run.Partial: Partial configuration for fine-tuning. @@ -254,8 +256,8 @@ def finetune_recipe( if peft_scheme is None or peft_scheme.lower() == 'none': recipe.trainer.strategy.tensor_model_parallel_size = 4 recipe.optim.config.lr = 5e-6 - elif peft_scheme.lower() == 'lora': - recipe.peft = run.Config(LoRA) + elif peft_scheme.lower() in ['lora', 'dora']: + recipe.peft = run.Config(PEFT_STR2CLS[peft_scheme.lower()]) recipe.optim.config.lr = 1e-4 else: raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") diff --git a/nemo/collections/llm/recipes/nemotron4_15b_16k.py b/nemo/collections/llm/recipes/nemotron4_15b_16k.py index 75eced72761f..e16c2b03b032 100644 --- a/nemo/collections/llm/recipes/nemotron4_15b_16k.py +++ b/nemo/collections/llm/recipes/nemotron4_15b_16k.py @@ -14,8 +14,8 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch from nemo.collections.llm.api import pretrain diff --git a/nemo/collections/llm/recipes/nemotron4_15b_64k.py b/nemo/collections/llm/recipes/nemotron4_15b_64k.py index 8286778aa7ba..2cedfbed398b 100644 --- a/nemo/collections/llm/recipes/nemotron4_15b_64k.py +++ b/nemo/collections/llm/recipes/nemotron4_15b_64k.py @@ -14,8 +14,8 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch from nemo.collections.llm.api import pretrain diff --git a/nemo/collections/llm/recipes/nemotron4_340b.py b/nemo/collections/llm/recipes/nemotron4_340b.py index c02950109669..14d4c0f32d11 100644 --- a/nemo/collections/llm/recipes/nemotron4_340b.py +++ b/nemo/collections/llm/recipes/nemotron4_340b.py @@ -12,16 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Optional +from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch -from nemo import lightning as nl from nemo.collections.llm.api import finetune, pretrain from nemo.collections.llm.gpt.data.mock import MockDataModule -from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.peft import PEFT_STR2CLS from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger from nemo.collections.llm.recipes.nemotron import nemotron_model, nemotron_trainer @@ -240,8 +239,10 @@ def finetune_recipe( name (str): Name of the fine-tuning run. num_nodes (int): Number of compute nodes to use. num_gpus_per_node (int): Number of GPUs per node. - peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. - packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training efficiency. Default sequence length is 2048. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. + Allowed values: 'lora'/'dora'/'none'/None. + packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training + efficiency. Default sequence length is 2048. Returns: run.Partial: Partial configuration for fine-tuning. @@ -268,8 +269,8 @@ def finetune_recipe( recipe.trainer.strategy.tensor_model_parallel_size = 8 recipe.trainer.strategy.pipeline_model_parallel_size = 12 recipe.optim.config.lr = 5e-6 - elif peft_scheme.lower() == 'lora': - recipe.peft = run.Config(LoRA) + elif peft_scheme.lower() in ['lora', 'dora']: + recipe.peft = run.Config(PEFT_STR2CLS[peft_scheme.lower()]) recipe.trainer.strategy.tensor_model_parallel_size = 8 recipe.trainer.strategy.pipeline_model_parallel_size = 4 recipe.optim.config.lr = 1e-4 diff --git a/nemo/collections/llm/recipes/phi3_mini_4k_instruct.py b/nemo/collections/llm/recipes/phi3_mini_4k_instruct.py new file mode 100644 index 000000000000..73bbe4735adb --- /dev/null +++ b/nemo/collections/llm/recipes/phi3_mini_4k_instruct.py @@ -0,0 +1,283 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Optional + +import lightning.pytorch as pl +import nemo_run as run +import torch +from lightning.pytorch.callbacks.callback import Callback +from megatron.core.distributed import DistributedDataParallelConfig + +from nemo import lightning as nl +from nemo.collections.llm.api import finetune, pretrain +from nemo.collections.llm.gpt.data.mock import MockDataModule +from nemo.collections.llm.gpt.data.packed_sequence import PackedSequenceSpecs +from nemo.collections.llm.gpt.model.phi3mini import Phi3ConfigMini, Phi3Model +from nemo.collections.llm.peft import PEFT_STR2CLS +from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe +from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger +from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing +from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed +from nemo.utils.exp_manager import TimingCallback + +NAME = "phi3_mini_4k_instruct" + + +@run.cli.factory(name=NAME) +def model() -> run.Config[pl.LightningModule]: + """ + Factory function to create a Phi3 Mini 4k instruct model configuration. + + Returns: + run.Config[pl.LightningModule]: Configuration for the Phi3 mini 4k instruct model. + + Examples: + CLI usage: + $ nemo llm pretrain model=phi3_mini_4k_instruct ... + + Python API usage: + >>> model_config = model() + >>> print(model_config) + """ + return run.Config(Phi3Model, config=run.Config(Phi3ConfigMini)) + + +def trainer( + tensor_parallelism: int = 1, + pipeline_parallelism: int = 1, + pipeline_parallelism_type: Optional[torch.dtype] = None, + virtual_pipeline_parallelism: Optional[int] = None, + context_parallelism: int = 1, + sequence_parallelism: bool = False, + num_nodes: int = 1, + num_gpus_per_node: int = 1, + max_steps: int = 1168251, + callbacks: Optional[list[run.Config[Callback]]] = None, +) -> run.Config[nl.Trainer]: + """ + Configure the NeMo Lightning Trainer for Phi3 mini 4k instruct model. + + This function sets up the distributed training strategy and other training parameters. + + Args: + tensor_parallelism (int): Degree of tensor model parallelism. + pipeline_parallelism (int): Degree of pipeline model parallelism. + pipeline_parallelism_type (Optional[torch.dtype]): Data type for pipeline parallelism. + virtual_pipeline_parallelism (Optional[int]): Size of virtual pipeline parallelism. + context_parallelism (int): Degree of context parallelism. + sequence_parallelism (bool): Whether to use sequence parallelism. + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + max_steps (int): Maximum number of training steps. + callbacks (Optional[list[run.Config[Callback]]]): List of callback configurations. + + Returns: + run.Config[nl.Trainer]: Configuration for the NeMo Lightning Trainer. + + Examples: + CLI usage: + $ nemo llm pretrain trainer=phi3_mini_4k_instruct ... + + Python API usage: + >>> trainer_config = trainer(num_nodes=2, num_gpus_per_node=8) + >>> print(trainer_config) + + Note: + For more information on distributed training strategies, refer to the + NeMo documentation on multi-GPU and multi-node training. + """ + strategy = run.Config( + nl.MegatronStrategy, + tensor_model_parallel_size=tensor_parallelism, + pipeline_model_parallel_size=pipeline_parallelism, + pipeline_dtype=pipeline_parallelism_type, + virtual_pipeline_model_parallel_size=virtual_pipeline_parallelism, + context_parallel_size=context_parallelism, + sequence_parallel=sequence_parallelism, + gradient_as_bucket_view=True, + ckpt_async_save=True, + ckpt_parallel_load=True, + ddp=run.Config( + DistributedDataParallelConfig, + check_for_nan_in_grad=True, + grad_reduce_in_fp32=True, + overlap_grad_reduce=True, + overlap_param_gather=True, + average_in_collective=True, + ), + ) + + trainer = run.Config( + nl.Trainer, + accelerator="gpu", + accumulate_grad_batches=1, + callbacks=callbacks, + devices=num_gpus_per_node, + limit_test_batches=50, + limit_val_batches=32, + log_every_n_steps=10, + max_steps=max_steps, + num_nodes=num_nodes, + plugins=bf16_mixed(), + strategy=strategy, + use_distributed_sampler=False, + val_check_interval=2000, + ) + + return trainer + + +@run.cli.factory(target=pretrain, name=NAME) +def pretrain_recipe( + dir: Optional[str] = None, + name: str = "default", + num_nodes: int = 1, + tensor_parallelism: int = 1, + num_gpus_per_node: int = 1, + max_steps: int = 1168251, + performance_mode: bool = False, + fn: Callable = pretrain, +) -> run.Partial: + """ + Create a pre-training recipe for phi3_mini_4k_instruct model. + + This function sets up a complete configuration for pre-training, including + model, trainer, data, logging, optimization, and resumption settings. + + Args: + dir (Optional[str]): Directory for saving logs and checkpoints. + name (str): Name of the pre-training run. + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + performance_mode (bool): If true, enables optimizations for maximum performance. + fn (Callable): The pre-training function to use. + + Returns: + run.Partial: Partial configuration for pre-training. + + Examples: + CLI usage: + $ nemo llm pretrain --factory phi3_mini_4k_instruct + $ nemo llm pretrain --factory "phi3_mini_4k_instruct(num_nodes=1, name='my_pretrain')" + + Python API usage: + >>> recipe = pretrain_recipe(name="phi3_mini_4k_instruct", num_nodes=1) + >>> print(recipe) + + Note: + For more details on pre-training LLMs with NeMo, see the pre-training + guide in the `examples/llm/pretrain/` directory. + """ + recipe = run.Partial( + fn, + model=model(), + trainer=trainer( + num_nodes=num_nodes, + num_gpus_per_node=num_gpus_per_node, + callbacks=[run.Config(TimingCallback)], + ), + data=run.Config(MockDataModule, seq_length=4096, global_batch_size=512, micro_batch_size=1), + log=default_log(dir=dir, name=name, tensorboard_logger=tensorboard_logger(name=name)), + optim=distributed_fused_adam_with_cosine_annealing(max_lr=3e-4), + resume=default_resume(), + ) + + return recipe + + +@run.cli.factory(target=finetune, name=NAME) +def finetune_recipe( + dir: Optional[str] = None, + name: str = "default", + num_nodes: int = 1, + num_gpus_per_node: int = 1, + tensor_parallelism: int = 1, + max_steps: int = 116825, + peft_scheme: Optional[str] = 'lora', + seq_length: Optional[int] = None, + packed_sequence: Optional[bool] = None, + performance_mode: bool = False, +) -> run.Partial: + """ + Create a fine-tuning recipe for Phi3 mini-4k-instruct model. + + This function sets up a complete configuration for fine-tuning, including + model, trainer, data, logging, optimization, and resumption settings. + The recipe uses LoRA (Low-Rank Adaptation) for efficient fine-tuning, unless peft_scheme is set to None. + + Args: + dir (Optional[str]): Directory for saving logs and checkpoints. + name (str): Name of the fine-tuning run. + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. + Allowed values: 'lora'/'dora'/'none'/None. + seq_length (int): Maximum number of tokens per microbatch. + packed_sequence (Optional[bool]): If true, fine-tuning sequences will be packed into batches up to the given + maximum seq_length for better efficiency. By default, this value equals performance_mode. + performance_mode (bool): If true, enables optimizations for maximum performance. + + Returns: + run.Partial: Partial configuration for fine-tuning. + + Examples: + CLI usage: + $ nemo llm finetune --factory phi3_mini_4k_instruct + + Python API usage: + >>> recipe = finetune_recipe(name="phi3_mini_4k_instruct", num_nodes=1) + >>> print(recipe) + + Note: + This recipe uses the SQuAD dataset for fine-tuning. For more information + on fine-tuning LLMs with NeMo, see the fine-tuning guide in the + `examples/llm/finetune/` directory. + """ + # Default to unpacked data in normal mode and packed data in performance mode + # once packing recipe is well tested, change this default to true + if packed_sequence is None: + packed_sequence = performance_mode + + # For unpacked sequence, most samples in SQuAD dataset are shorter than 2K + if seq_length is None: + seq_length = 4096 if packed_sequence else 2048 + + recipe = default_finetune_recipe( + model(), "microsoft/Phi-3-mini-4k-instruct", dir, name, num_nodes, num_gpus_per_node, packed_sequence + ) + if peft_scheme is None or peft_scheme.lower() == 'none': + recipe.trainer.strategy.tensor_model_parallel_size = 1 + recipe.optim.config.lr = 5e-6 + elif peft_scheme.lower() in ['lora', 'dora']: + recipe.peft = run.Config(PEFT_STR2CLS[peft_scheme.lower()]) + recipe.peft.dim = 8 + recipe.peft.alpha = 16 + recipe.optim.config.use_distributed_optimizer = False + + # some settings currently do not function correctly with LoRA + recipe.model.config.cross_entropy_loss_fusion = False + + recipe.optim.config.lr = 1e-4 + else: + raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") + + # Sequence length settings in the model and dataset must agree + recipe.model.config.seq_length = seq_length + recipe.data.seq_length = seq_length + if packed_sequence: + recipe.data.pad_to_max_length = True + recipe.data.packed_sequence_specs = run.Config(PackedSequenceSpecs, packed_sequence_size=seq_length) + + return recipe diff --git a/nemo/collections/llm/recipes/qwen2.py b/nemo/collections/llm/recipes/qwen2.py index ff0c76a714f1..db9dcfc88865 100644 --- a/nemo/collections/llm/recipes/qwen2.py +++ b/nemo/collections/llm/recipes/qwen2.py @@ -14,10 +14,10 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch -from pytorch_lightning.callbacks.callback import Callback +from lightning.pytorch.callbacks.callback import Callback from nemo import lightning as nl from nemo.collections.llm.gpt.model.qwen2 import ( diff --git a/nemo/collections/llm/recipes/qwen2_1p5b.py b/nemo/collections/llm/recipes/qwen2_1p5b.py index 662f8e98899d..99ba5cd907fc 100644 --- a/nemo/collections/llm/recipes/qwen2_1p5b.py +++ b/nemo/collections/llm/recipes/qwen2_1p5b.py @@ -14,13 +14,13 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch from nemo.collections.llm.api import finetune, pretrain from nemo.collections.llm.gpt.data.mock import MockDataModule -from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.peft import PEFT_STR2CLS from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing @@ -194,8 +194,10 @@ def finetune_recipe( name (str): Name of the fine-tuning run. num_nodes (int): Number of compute nodes to use. num_gpus_per_node (int): Number of GPUs per node. - peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. - packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training efficiency. Default sequence length is 2048. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. + Allowed values: 'lora'/'dora'/'none'/None. + packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training + efficiency. Default sequence length is 2048. Returns: run.Partial: Partial configuration for fine-tuning. @@ -218,8 +220,8 @@ def finetune_recipe( ) if peft_scheme is None or peft_scheme.lower() == 'none': recipe.optim.config.lr = 5e-6 - elif peft_scheme.lower() == 'lora': - recipe.peft = run.Config(LoRA) + elif peft_scheme.lower() in ['lora', 'dora']: + recipe.peft = run.Config(PEFT_STR2CLS[peft_scheme.lower()]) recipe.optim.config.lr = 1e-4 else: raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") diff --git a/nemo/collections/llm/recipes/qwen2_500m.py b/nemo/collections/llm/recipes/qwen2_500m.py index ac6cbfe84464..96d99c271c85 100644 --- a/nemo/collections/llm/recipes/qwen2_500m.py +++ b/nemo/collections/llm/recipes/qwen2_500m.py @@ -14,13 +14,13 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch from nemo.collections.llm.api import finetune, pretrain from nemo.collections.llm.gpt.data.mock import MockDataModule -from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.peft import PEFT_STR2CLS from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing @@ -194,8 +194,10 @@ def finetune_recipe( name (str): Name of the fine-tuning run. num_nodes (int): Number of compute nodes to use. num_gpus_per_node (int): Number of GPUs per node. - peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. - packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training efficiency. Default sequence length is 2048. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. + Allowed values: 'lora'/'dora'/'none'/None. + packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training + efficiency. Default sequence length is 2048. Returns: run.Partial: Partial configuration for fine-tuning. @@ -218,8 +220,8 @@ def finetune_recipe( ) if peft_scheme is None or peft_scheme.lower() == 'none': recipe.optim.config.lr = 5e-6 - elif peft_scheme.lower() == 'lora': - recipe.peft = run.Config(LoRA) + elif peft_scheme.lower() in ['lora', 'dora']: + recipe.peft = run.Config(PEFT_STR2CLS[peft_scheme.lower()]) recipe.optim.config.lr = 1e-4 else: raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") diff --git a/nemo/collections/llm/recipes/qwen2_72b.py b/nemo/collections/llm/recipes/qwen2_72b.py index 0b94761e5749..33bb0dd40835 100644 --- a/nemo/collections/llm/recipes/qwen2_72b.py +++ b/nemo/collections/llm/recipes/qwen2_72b.py @@ -14,13 +14,13 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch from nemo.collections.llm.api import finetune, pretrain from nemo.collections.llm.gpt.data.mock import MockDataModule -from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.peft import PEFT_STR2CLS from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing @@ -194,8 +194,10 @@ def finetune_recipe( name (str): Name of the fine-tuning run. num_nodes (int): Number of compute nodes to use. num_gpus_per_node (int): Number of GPUs per node. - peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. - packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training efficiency. Default sequence length is 2048. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. + Allowed values: 'lora'/'dora'/'none'/None. + packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training + efficiency. Default sequence length is 2048. Returns: run.Partial: Partial configuration for fine-tuning. @@ -221,8 +223,8 @@ def finetune_recipe( recipe.trainer.strategy.tensor_model_parallel_size = 8 recipe.trainer.strategy.pipeline_model_parallel_size = 4 recipe.optim.config.lr = 5e-6 - elif peft_scheme.lower() == 'lora': - recipe.peft = run.Config(LoRA) + elif peft_scheme.lower() in ['lora', 'dora']: + recipe.peft = run.Config(PEFT_STR2CLS[peft_scheme.lower()]) recipe.trainer.strategy.tensor_model_parallel_size = 8 recipe.optim.config.lr = 1e-4 else: diff --git a/nemo/collections/llm/recipes/qwen2_7b.py b/nemo/collections/llm/recipes/qwen2_7b.py index 10c990f15142..2e62176a408e 100644 --- a/nemo/collections/llm/recipes/qwen2_7b.py +++ b/nemo/collections/llm/recipes/qwen2_7b.py @@ -14,13 +14,13 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch from nemo.collections.llm.api import finetune, pretrain from nemo.collections.llm.gpt.data.mock import MockDataModule -from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.peft import PEFT_STR2CLS from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing @@ -194,8 +194,10 @@ def finetune_recipe( name (str): Name of the fine-tuning run. num_nodes (int): Number of compute nodes to use. num_gpus_per_node (int): Number of GPUs per node. - peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. - packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training efficiency. Default sequence length is 2048. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. + Allowed values: 'lora'/'dora'/'none'/None. + packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training + efficiency. Default sequence length is 2048. Returns: run.Partial: Partial configuration for fine-tuning. @@ -219,8 +221,8 @@ def finetune_recipe( if peft_scheme is None or peft_scheme.lower() == 'none': recipe.trainer.strategy.tensor_model_parallel_size = 2 recipe.optim.config.lr = 5e-6 - elif peft_scheme.lower() == 'lora': - recipe.peft = run.Config(LoRA) + elif peft_scheme.lower() in ['lora', 'dora']: + recipe.peft = run.Config(PEFT_STR2CLS[peft_scheme.lower()]) recipe.optim.config.lr = 1e-4 else: raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") diff --git a/nemo/collections/llm/recipes/run/__init__.py b/nemo/collections/llm/recipes/run/__init__.py new file mode 100644 index 000000000000..d9155f923f18 --- /dev/null +++ b/nemo/collections/llm/recipes/run/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/llm/recipes/run/executor.py b/nemo/collections/llm/recipes/run/executor.py new file mode 100644 index 000000000000..fe14a4f55bd2 --- /dev/null +++ b/nemo/collections/llm/recipes/run/executor.py @@ -0,0 +1,31 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import nemo_run as run + + +@run.cli.factory +def torchrun(devices: int = 8) -> run.Config[run.LocalExecutor]: + """Local executor using torchrun.""" + env_vars = { + "TORCH_NCCL_AVOID_RECORD_STREAMS": "1", + } + + executor = run.Config( + run.LocalExecutor, + ntasks_per_node=devices, + launcher="torchrun", + env_vars=env_vars, + ) + + return executor diff --git a/nemo/collections/llm/recipes/starcoder2.py b/nemo/collections/llm/recipes/starcoder2.py index c3a19326585c..b090ce1cf9ef 100644 --- a/nemo/collections/llm/recipes/starcoder2.py +++ b/nemo/collections/llm/recipes/starcoder2.py @@ -14,10 +14,11 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch -from pytorch_lightning.callbacks.callback import Callback +from lightning.pytorch.callbacks.callback import Callback + from nemo import lightning as nl from nemo.collections.llm.gpt.model.starcoder2 import ( Starcoder2Config3B, diff --git a/nemo/collections/llm/recipes/starcoder2_15b.py b/nemo/collections/llm/recipes/starcoder2_15b.py index a59ec272c865..e424cb67dba4 100644 --- a/nemo/collections/llm/recipes/starcoder2_15b.py +++ b/nemo/collections/llm/recipes/starcoder2_15b.py @@ -14,13 +14,13 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch from nemo.collections.llm.api import finetune, pretrain from nemo.collections.llm.gpt.data.mock import MockDataModule -from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.peft import PEFT_STR2CLS from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing @@ -194,8 +194,10 @@ def finetune_recipe( name (str): Name of the fine-tuning run. num_nodes (int): Number of compute nodes to use. num_gpus_per_node (int): Number of GPUs per node. - peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. - packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training efficiency. Default sequence length is 2048. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. + Allowed values: 'lora'/'dora'/'none'/None. + packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training + efficiency. Default sequence length is 2048. Returns: run.Partial: Partial configuration for fine-tuning. @@ -219,8 +221,8 @@ def finetune_recipe( if peft_scheme is None or peft_scheme.lower() == 'none': recipe.trainer.strategy.tensor_model_parallel_size = 4 recipe.optim.config.lr = 5e-6 - elif peft_scheme.lower() == 'lora': - recipe.peft = run.Config(LoRA) + elif peft_scheme.lower() in ['lora', 'dora']: + recipe.peft = run.Config(PEFT_STR2CLS[peft_scheme.lower()]) recipe.optim.config.lr = 1e-4 else: raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") diff --git a/nemo/collections/llm/recipes/starcoder2_3b.py b/nemo/collections/llm/recipes/starcoder2_3b.py index 55884b353d8f..faf0b416c56a 100644 --- a/nemo/collections/llm/recipes/starcoder2_3b.py +++ b/nemo/collections/llm/recipes/starcoder2_3b.py @@ -14,13 +14,13 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch from nemo.collections.llm.api import finetune, pretrain from nemo.collections.llm.gpt.data.mock import MockDataModule -from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.peft import PEFT_STR2CLS from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing @@ -194,8 +194,10 @@ def finetune_recipe( name (str): Name of the fine-tuning run. num_nodes (int): Number of compute nodes to use. num_gpus_per_node (int): Number of GPUs per node. - peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. - packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training efficiency. Default sequence length is 2048. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. + Allowed values: 'lora'/'dora'/'none'/None. + packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training + efficiency. Default sequence length is 2048. Returns: run.Partial: Partial configuration for fine-tuning. @@ -219,8 +221,8 @@ def finetune_recipe( if peft_scheme is None or peft_scheme.lower() == 'none': recipe.trainer.strategy.tensor_model_parallel_size = 2 recipe.optim.config.lr = 5e-6 - elif peft_scheme.lower() == 'lora': - recipe.peft = run.Config(LoRA) + elif peft_scheme.lower() in ['lora', 'dora']: + recipe.peft = run.Config(PEFT_STR2CLS[peft_scheme.lower()]) recipe.optim.config.lr = 1e-4 else: raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") diff --git a/nemo/collections/llm/recipes/starcoder2_7b.py b/nemo/collections/llm/recipes/starcoder2_7b.py index 46e34b8b0c77..091e882cd932 100644 --- a/nemo/collections/llm/recipes/starcoder2_7b.py +++ b/nemo/collections/llm/recipes/starcoder2_7b.py @@ -14,13 +14,13 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch from nemo.collections.llm.api import finetune, pretrain from nemo.collections.llm.gpt.data.mock import MockDataModule -from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.peft import PEFT_STR2CLS from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing @@ -194,8 +194,10 @@ def finetune_recipe( name (str): Name of the fine-tuning run. num_nodes (int): Number of compute nodes to use. num_gpus_per_node (int): Number of GPUs per node. - peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. - packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training efficiency. Default sequence length is 2048. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. + Allowed values: 'lora'/'dora'/'none'/None. + packed_sequence (Optional[bool]): Packing multiple training sequences into one long sequence for training + efficiency. Default sequence length is 2048. Returns: run.Partial: Partial configuration for fine-tuning. @@ -219,8 +221,8 @@ def finetune_recipe( if peft_scheme is None or peft_scheme.lower() == 'none': recipe.trainer.strategy.tensor_model_parallel_size = 2 recipe.optim.config.lr = 5e-6 - elif peft_scheme.lower() == 'lora': - recipe.peft = run.Config(LoRA) + elif peft_scheme.lower() in ['lora', 'dora']: + recipe.peft = run.Config(PEFT_STR2CLS[peft_scheme.lower()]) recipe.optim.config.lr = 1e-4 else: raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") diff --git a/nemo/collections/llm/recipes/starcoder_15b.py b/nemo/collections/llm/recipes/starcoder_15b.py index cb0ba14df868..382d0eb4d8ca 100644 --- a/nemo/collections/llm/recipes/starcoder_15b.py +++ b/nemo/collections/llm/recipes/starcoder_15b.py @@ -14,16 +14,16 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch -from pytorch_lightning.callbacks.callback import Callback +from lightning.pytorch.callbacks.callback import Callback from nemo import lightning as nl from nemo.collections.llm.api import finetune, pretrain from nemo.collections.llm.gpt.data.mock import MockDataModule from nemo.collections.llm.gpt.model.starcoder import StarcoderConfig15B, StarcoderModel -from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.peft import PEFT_STR2CLS from nemo.collections.llm.recipes.finetune_default import default_finetune_recipe from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing @@ -280,7 +280,8 @@ def finetune_recipe( name (str): Name of the fine-tuning run. num_nodes (int): Number of compute nodes to use. num_gpus_per_node (int): Number of GPUs per node. - peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. + Allowed values: 'lora'/'dora'/'none'/None. Returns: run.Partial: Partial configuration for fine-tuning. @@ -302,8 +303,8 @@ def finetune_recipe( if peft_scheme is None or peft_scheme.lower() == 'none': recipe.trainer.strategy.pipeline_model_parallel_size = 8 recipe.optim.config.lr = 5e-6 - elif peft_scheme.lower() == 'lora': - recipe.peft = run.Config(LoRA) + elif peft_scheme.lower() in ['lora', 'dora']: + recipe.peft = run.Config(PEFT_STR2CLS[peft_scheme.lower()]) recipe.optim.config.lr = 1e-4 else: raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") diff --git a/nemo/collections/llm/recipes/t5_11b.py b/nemo/collections/llm/recipes/t5_11b.py index b3806e6f2540..c54bf48b9613 100644 --- a/nemo/collections/llm/recipes/t5_11b.py +++ b/nemo/collections/llm/recipes/t5_11b.py @@ -15,16 +15,16 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch +from lightning.pytorch.callbacks.callback import Callback from megatron.core.distributed import DistributedDataParallelConfig from megatron.core.optimizer import OptimizerConfig -from pytorch_lightning.callbacks.callback import Callback from nemo import lightning as nl from nemo.collections.llm.api import finetune, pretrain -from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.peft import PEFT_STR2CLS from nemo.collections.llm.recipes.finetune_default import default_finetune_trainer, nemo_resume from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed @@ -175,7 +175,8 @@ def pretrain_recipe( guide in the `examples/llm/pretrain/` directory. """ - opt_config = OptimizerConfig( + opt_config = run.Config( + OptimizerConfig, optimizer='adam', lr=0.0001, use_distributed_optimizer=True, @@ -183,7 +184,8 @@ def pretrain_recipe( weight_decay=0.01, ) - lr_scheduler = WarmupAnnealingScheduler( + lr_scheduler = run.Config( + WarmupAnnealingScheduler, warmup_steps=None, warmup_ratio=0.01, max_steps=1000000, @@ -202,7 +204,7 @@ def pretrain_recipe( MockDataModule, seq_length=512, seq_length_dec=128, global_batch_size=1920, micro_batch_size=24 ), log=default_log(dir=dir, name=name, tensorboard_logger=tensorboard_logger(name=name)), - optim=MegatronOptimizerModule(config=opt_config, lr_scheduler=lr_scheduler), + optim=run.Config(MegatronOptimizerModule, config=opt_config, lr_scheduler=lr_scheduler), resume=default_resume(), ) @@ -229,7 +231,8 @@ def finetune_recipe( name (str): Name of the fine-tuning run. num_nodes (int): Number of compute nodes to use. num_gpus_per_node (int): Number of GPUs per node. - peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. + Allowed values: 'lora'/'dora'/'none'/None. Returns: run.Partial: Partial configuration for fine-tuning. @@ -247,15 +250,17 @@ def finetune_recipe( on fine-tuning LLMs with NeMo, see the fine-tuning guide in the `examples/llm/finetune/` directory. """ - opt_config = OptimizerConfig( + opt_config = run.Config( + OptimizerConfig, optimizer='adam', - lr=1e-4, + lr=0.0001, use_distributed_optimizer=True, bf16=True, weight_decay=0.01, ) - lr_scheduler = WarmupAnnealingScheduler( + lr_scheduler = run.Config( + WarmupAnnealingScheduler, warmup_steps=50, max_steps=2000, min_lr=0.00001, @@ -272,15 +277,15 @@ def finetune_recipe( SquadDataModule, seq_length=512, seq_length_dec=128, global_batch_size=128, micro_batch_size=1 ), log=default_log(dir=dir, name=name, tensorboard_logger=tensorboard_logger(name=name)), - optim=MegatronOptimizerModule(config=opt_config, lr_scheduler=lr_scheduler), + optim=run.Config(MegatronOptimizerModule, config=opt_config, lr_scheduler=lr_scheduler), resume=nemo_resume(checkpoint_path), ) if peft_scheme is None or peft_scheme.lower() == 'none': recipe.trainer.strategy.tensor_model_parallel_size = 4 recipe.optim.config.lr = 5e-6 - elif peft_scheme.lower() == 'lora': - recipe.peft = run.Config(LoRA) + elif peft_scheme.lower() in ['lora', 'dora']: + recipe.peft = run.Config(PEFT_STR2CLS[peft_scheme.lower()]) recipe.optim.config.lr = 1e-4 else: raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") diff --git a/nemo/collections/llm/recipes/t5_220m.py b/nemo/collections/llm/recipes/t5_220m.py index e220eb3fb1b0..975ac5519859 100644 --- a/nemo/collections/llm/recipes/t5_220m.py +++ b/nemo/collections/llm/recipes/t5_220m.py @@ -15,16 +15,16 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch +from lightning.pytorch.callbacks.callback import Callback from megatron.core.distributed import DistributedDataParallelConfig from megatron.core.optimizer import OptimizerConfig -from pytorch_lightning.callbacks.callback import Callback from nemo import lightning as nl from nemo.collections.llm.api import finetune, pretrain -from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.peft import PEFT_STR2CLS from nemo.collections.llm.recipes.finetune_default import default_finetune_trainer, nemo_resume from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed @@ -175,7 +175,8 @@ def pretrain_recipe( guide in the `examples/llm/pretrain/` directory. """ - opt_config = OptimizerConfig( + opt_config = run.Config( + OptimizerConfig, optimizer='adam', lr=0.0001, use_distributed_optimizer=True, @@ -183,7 +184,8 @@ def pretrain_recipe( weight_decay=0.01, ) - lr_scheduler = WarmupAnnealingScheduler( + lr_scheduler = run.Config( + WarmupAnnealingScheduler, warmup_steps=None, warmup_ratio=0.01, max_steps=1000000, @@ -200,7 +202,7 @@ def pretrain_recipe( ), data=run.Config(MockDataModule, seq_length=512, seq_length_dec=128, global_batch_size=512, micro_batch_size=1), log=default_log(dir=dir, name=name, tensorboard_logger=tensorboard_logger(name=name)), - optim=MegatronOptimizerModule(config=opt_config, lr_scheduler=lr_scheduler), + optim=run.Config(MegatronOptimizerModule, config=opt_config, lr_scheduler=lr_scheduler), resume=default_resume(), ) @@ -227,7 +229,8 @@ def finetune_recipe( name (str): Name of the fine-tuning run. num_nodes (int): Number of compute nodes to use. num_gpus_per_node (int): Number of GPUs per node. - peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. + Allowed values: 'lora'/'dora'/'none'/None. Returns: run.Partial: Partial configuration for fine-tuning. @@ -245,15 +248,17 @@ def finetune_recipe( on fine-tuning LLMs with NeMo, see the fine-tuning guide in the `examples/llm/finetune/` directory. """ - opt_config = OptimizerConfig( + opt_config = run.Config( + OptimizerConfig, optimizer='adam', - lr=1e-4, + lr=0.0001, use_distributed_optimizer=True, bf16=True, weight_decay=0.01, ) - lr_scheduler = WarmupAnnealingScheduler( + lr_scheduler = run.Config( + WarmupAnnealingScheduler, warmup_steps=50, max_steps=2000, min_lr=0.00001, @@ -270,16 +275,17 @@ def finetune_recipe( SquadDataModule, seq_length=512, seq_length_dec=128, global_batch_size=128, micro_batch_size=1 ), log=default_log(dir=dir, name=name, tensorboard_logger=tensorboard_logger(name=name)), - optim=MegatronOptimizerModule(config=opt_config, lr_scheduler=lr_scheduler), + optim=run.Config(MegatronOptimizerModule, config=opt_config, lr_scheduler=lr_scheduler), resume=nemo_resume(checkpoint_path), ) if peft_scheme is None or peft_scheme.lower() == 'none': recipe.trainer.strategy.tensor_model_parallel_size = 1 recipe.optim.config.lr = 5e-6 - elif peft_scheme.lower() == 'lora': - recipe.peft = run.Config(LoRA) + elif peft_scheme.lower() in ['lora', 'dora']: + recipe.peft = run.Config(PEFT_STR2CLS[peft_scheme.lower()]) recipe.optim.config.lr = 1e-4 else: raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") + return recipe diff --git a/nemo/collections/llm/recipes/t5_3b.py b/nemo/collections/llm/recipes/t5_3b.py index e7f215d57635..b1783594d2f7 100644 --- a/nemo/collections/llm/recipes/t5_3b.py +++ b/nemo/collections/llm/recipes/t5_3b.py @@ -15,16 +15,16 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch +from lightning.pytorch.callbacks.callback import Callback from megatron.core.distributed import DistributedDataParallelConfig from megatron.core.optimizer import OptimizerConfig -from pytorch_lightning.callbacks.callback import Callback from nemo import lightning as nl from nemo.collections.llm.api import finetune, pretrain -from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.peft import PEFT_STR2CLS from nemo.collections.llm.recipes.finetune_default import default_finetune_trainer, nemo_resume from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed @@ -175,7 +175,8 @@ def pretrain_recipe( guide in the `examples/llm/pretrain/` directory. """ - opt_config = OptimizerConfig( + opt_config = run.Config( + OptimizerConfig, optimizer='adam', lr=0.0001, use_distributed_optimizer=True, @@ -183,7 +184,8 @@ def pretrain_recipe( weight_decay=0.01, ) - lr_scheduler = WarmupAnnealingScheduler( + lr_scheduler = run.Config( + WarmupAnnealingScheduler, warmup_steps=None, warmup_ratio=0.01, max_steps=1000000, @@ -202,7 +204,7 @@ def pretrain_recipe( MockDataModule, seq_length=512, seq_length_dec=128, global_batch_size=1920, micro_batch_size=24 ), log=default_log(dir=dir, name=name, tensorboard_logger=tensorboard_logger(name=name)), - optim=MegatronOptimizerModule(config=opt_config, lr_scheduler=lr_scheduler), + optim=run.Config(MegatronOptimizerModule, config=opt_config, lr_scheduler=lr_scheduler), resume=default_resume(), ) @@ -229,7 +231,8 @@ def finetune_recipe( name (str): Name of the fine-tuning run. num_nodes (int): Number of compute nodes to use. num_gpus_per_node (int): Number of GPUs per node. - peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. Allowed values: 'lora', 'none'/None. + peft_scheme (Optional[str]): Name of the peft scheme to use for fine-tuning. + Allowed values: 'lora'/'dora'/'none'/None. Returns: run.Partial: Partial configuration for fine-tuning. @@ -247,15 +250,17 @@ def finetune_recipe( on fine-tuning LLMs with NeMo, see the fine-tuning guide in the `examples/llm/finetune/` directory. """ - opt_config = OptimizerConfig( + opt_config = run.Config( + OptimizerConfig, optimizer='adam', - lr=1e-4, + lr=0.0001, use_distributed_optimizer=True, bf16=True, weight_decay=0.01, ) - lr_scheduler = WarmupAnnealingScheduler( + lr_scheduler = run.Config( + WarmupAnnealingScheduler, warmup_steps=50, max_steps=2000, min_lr=0.00001, @@ -272,15 +277,15 @@ def finetune_recipe( SquadDataModule, seq_length=512, seq_length_dec=128, global_batch_size=128, micro_batch_size=1 ), log=default_log(dir=dir, name=name, tensorboard_logger=tensorboard_logger(name=name)), - optim=MegatronOptimizerModule(config=opt_config, lr_scheduler=lr_scheduler), + optim=run.Config(MegatronOptimizerModule, config=opt_config, lr_scheduler=lr_scheduler), resume=nemo_resume(checkpoint_path), ) if peft_scheme is None or peft_scheme.lower() == 'none': recipe.trainer.strategy.tensor_model_parallel_size = 2 recipe.optim.config.lr = 5e-6 - elif peft_scheme.lower() == 'lora': - recipe.peft = run.Config(LoRA) + elif peft_scheme.lower() in ['lora', 'dora']: + recipe.peft = run.Config(PEFT_STR2CLS[peft_scheme.lower()]) recipe.optim.config.lr = 1e-4 else: raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") diff --git a/nemo/collections/llm/t5/data/__init__.py b/nemo/collections/llm/t5/data/__init__.py index d65f6923033f..e4f879b76a06 100644 --- a/nemo/collections/llm/t5/data/__init__.py +++ b/nemo/collections/llm/t5/data/__init__.py @@ -1,5 +1,6 @@ from nemo.collections.llm.t5.data.fine_tuning import FineTuningDataModule +from nemo.collections.llm.t5.data.mock import MockDataModule from nemo.collections.llm.t5.data.pre_training import PreTrainingDataModule from nemo.collections.llm.t5.data.squad import SquadDataModule -__all__ = ["FineTuningDataModule", "PreTrainingDataModule", "SquadDataModule"] +__all__ = ["FineTuningDataModule", "PreTrainingDataModule", "SquadDataModule", "MockDataModule"] diff --git a/nemo/collections/llm/t5/data/fine_tuning.py b/nemo/collections/llm/t5/data/fine_tuning.py index 4180b4f135cb..ced4ea1a0b37 100644 --- a/nemo/collections/llm/t5/data/fine_tuning.py +++ b/nemo/collections/llm/t5/data/fine_tuning.py @@ -17,7 +17,7 @@ from pathlib import Path from typing import TYPE_CHECKING, List, Optional, Union -import pytorch_lightning as pl +import lightning.pytorch as pl from torch.utils.data import DataLoader from nemo.collections.llm.t5.data.core import create_sft_dataset diff --git a/nemo/collections/llm/t5/data/mock.py b/nemo/collections/llm/t5/data/mock.py index eaf41d290da4..31198a4446e9 100644 --- a/nemo/collections/llm/t5/data/mock.py +++ b/nemo/collections/llm/t5/data/mock.py @@ -14,10 +14,10 @@ from typing import TYPE_CHECKING, Dict, List, Optional +import lightning.pytorch as pl import numpy as np -import pytorch_lightning as pl import torch -from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS +from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS from torch.utils import data from torch.utils.data import DataLoader, Dataset @@ -125,13 +125,11 @@ def __init__( self.seed = seed self.create_attention_mask = create_attention_mask - self.mask_encoder = torch.ones((self.seq_length, self.seq_length), device='cpu') - self.mask_decoder = torch.tril(torch.ones((self.seq_length_dec, self.seq_length_dec), device='cpu')) - self.mask_encoder_decoder = torch.ones((self.seq_length_dec, self.seq_length), device='cpu') + # update for T5 now use FlashFused attention (b11s) + self.mask_encoder = torch.ones(self.seq_length, device='cpu') + self.mask_decoder = torch.ones(self.seq_length_dec, device='cpu') self.mask_encoder = self.mask_encoder < 0.5 self.mask_decoder = self.mask_decoder < 0.5 - self.mask_encoder_decoder = self.mask_encoder_decoder < 0.5 - self.loss_mask = torch.ones(self.seq_length_dec, dtype=torch.float) def __len__(self) -> int: @@ -156,7 +154,6 @@ def __getitem__(self, idx) -> Dict[str, torch.Tensor]: "truncated": 0, "enc_mask": self.mask_encoder, "dec_mask": self.mask_decoder, - "enc_dec_mask": self.mask_encoder_decoder, } return batch diff --git a/nemo/collections/llm/t5/data/pre_training.py b/nemo/collections/llm/t5/data/pre_training.py index 45d485ba2074..4bd6e5ed5e93 100644 --- a/nemo/collections/llm/t5/data/pre_training.py +++ b/nemo/collections/llm/t5/data/pre_training.py @@ -17,8 +17,8 @@ from pathlib import Path from typing import TYPE_CHECKING, Any, Dict, List, Optional -import pytorch_lightning as pl -from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS +import lightning.pytorch as pl +from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS from torch.utils import data from nemo.lightning.data import WrappedDataLoader diff --git a/nemo/collections/llm/t5/data/squad.py b/nemo/collections/llm/t5/data/squad.py index 3e413919211c..4e90b09e622e 100644 --- a/nemo/collections/llm/t5/data/squad.py +++ b/nemo/collections/llm/t5/data/squad.py @@ -42,6 +42,7 @@ class SquadDataModule(FineTuningDataModule, IOMixin): def __init__( self, + dataset_root: str = None, seq_length: int = 512, seq_length_dec: int = 128, tokenizer: Optional["TokenizerSpec"] = None, @@ -60,7 +61,7 @@ def __init__( self.delete_raw = delete_raw super().__init__( - dataset_root=get_dataset_root("squad"), + dataset_root=get_dataset_root("squad") if dataset_root is None else dataset_root, seq_length=seq_length, seq_length_dec=seq_length_dec, tokenizer=tokenizer, diff --git a/nemo/collections/llm/t5/model/t5.py b/nemo/collections/llm/t5/model/t5.py index 058acaaec7b0..940c0e51ee92 100644 --- a/nemo/collections/llm/t5/model/t5.py +++ b/nemo/collections/llm/t5/model/t5.py @@ -16,11 +16,12 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Callable, Dict, Literal, Optional, Union -import pytorch_lightning as L +import lightning.pytorch as L import torch import torch.distributed from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import InferenceWrapperConfig from megatron.core.inference.model_inference_wrappers.t5.t5_inference_wrapper import T5InferenceWrapper +from megatron.core.models.T5.t5_model import T5Model as MCoreT5Model from megatron.core.optimizer import OptimizerConfig from megatron.core.transformer.spec_utils import ModuleSpec from megatron.core.transformer.transformer_config import TransformerConfig @@ -38,8 +39,6 @@ HAVE_TE = False if TYPE_CHECKING: - from megatron.core.models.T5.t5_model import T5Model as MCoreT5Model - from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec @@ -58,22 +57,32 @@ def t5_data_step(dataloader_iter) -> Dict[str, torch.Tensor]: else: _batch = batch - # if Dataset object is NeMo 1.0's T5SFTDataset (e.g. when finetuning with SQUAD) - if 'enc_dec_mask' not in _batch: - encoder_attn_mask_3d = build_attention_mask_3d(_batch['enc_mask'], _batch['enc_mask'], AttnMaskType.padding) - decoder_attn_mask_3d = build_attention_mask_3d(_batch['dec_mask'], _batch['dec_mask'], AttnMaskType.causal) - enc_dec_attn_mask_3d = build_attention_mask_3d(_batch['dec_mask'], _batch['enc_mask'], AttnMaskType.padding) - _batch['enc_mask'] = encoder_attn_mask_3d - _batch['dec_mask'] = decoder_attn_mask_3d - _batch['enc_dec_mask'] = enc_dec_attn_mask_3d - - # if Dataset object is Mcore T5 dataset (e.g. pretraining) - else: - # convert attention mask values from int to True/False - _batch['enc_mask'] = _batch['enc_mask'] < 0.5 - _batch['dec_mask'] = _batch['dec_mask'] < 0.5 - _batch['enc_dec_mask'] = _batch['enc_dec_mask'] < 0.5 - + # work for both mcore's T5 pre-train dataset object, and NeMo's T5SFTDataset dataset + enc_mask = _batch['enc_mask'] < 0.5 + dec_mask = _batch['dec_mask'] < 0.5 + # process for Flash/Fused + enc_mask = enc_mask.unsqueeze(1).unsqueeze(1) + dec_mask = dec_mask.unsqueeze(1).unsqueeze(1) + enc_dec_mask = ( + dec_mask, + enc_mask, + ) + # set dec_mask to None because decoder uses AttnMaskType.causal + dec_mask = None + _batch['enc_mask'] = enc_mask + _batch['dec_mask'] = dec_mask + _batch['enc_dec_mask'] = enc_dec_mask + + # bring to device + for key in _batch.keys(): + if key == "enc_dec_mask": # because enc_dec_mask is a tuple + _batch[key] = (_batch[key][0].cuda(non_blocking=True), _batch[key][1].cuda(non_blocking=True)) + elif key == "dec_mask": # because dec_mask is a None since decoder uses AttnMaskType.causal + continue + else: + _batch[key] = _batch[key].cuda(non_blocking=True) + + # set up forward arguments for pipeline parallelism required_keys = set() required_keys.update(["enc_mask", "dec_mask", "enc_dec_mask"]) if parallel_state.is_pipeline_first_stage(): @@ -81,7 +90,7 @@ def t5_data_step(dataloader_iter) -> Dict[str, torch.Tensor]: if parallel_state.is_pipeline_last_stage(): required_keys.update(("labels", "loss_mask")) - output = {key: val.cuda(non_blocking=True) if key in required_keys else None for key, val in _batch.items()} + output = {key: val if key in required_keys else None for key, val in _batch.items()} return output @@ -139,9 +148,12 @@ class T5Config(TransformerConfig, io.IOMixin): share_embeddings_and_output_weights: bool = True make_vocab_size_divisible_by: int = 128 position_embedding_type: Literal["learned_absolute", "rope"] = "learned_absolute" + apply_rope_fusion: bool = True max_position_embeddings: int = 512 rotary_percent: float = 1.0 seq_len_interpolation_factor: Optional[float] = None + seq_length: int = 512 + seq_length_dec: int = 128 encoder_pipeline_model_parallel_size: int = 0 attention_softmax_in_fp32: float = False bias_activation_fusion: bool = True @@ -168,7 +180,6 @@ def configure_model(self, tokenizer) -> "MCoreT5Model": ) % vp_size == 0, "Make sure the number of model chunks is the same across all pipeline stages." from megatron.core import parallel_state - from megatron.core.models.T5.t5_model import T5Model as MCoreT5Model encoder_config = copy.deepcopy(self) encoder_config.num_layers = self.encoder_num_layers diff --git a/nemo/collections/multimodal/data/dreambooth/dreambooth_dataset.py b/nemo/collections/multimodal/data/dreambooth/dreambooth_dataset.py index 1c39b1a72216..baead0c47962 100644 --- a/nemo/collections/multimodal/data/dreambooth/dreambooth_dataset.py +++ b/nemo/collections/multimodal/data/dreambooth/dreambooth_dataset.py @@ -15,8 +15,8 @@ from pathlib import Path import torch +from lightning.pytorch.utilities import rank_zero_only from PIL import Image -from pytorch_lightning.utilities import rank_zero_only from torch.utils.data import Dataset from tqdm import tqdm diff --git a/nemo/collections/multimodal/data/energon/base.py b/nemo/collections/multimodal/data/energon/base.py index 0a99b1a1baad..4e90dce55c7a 100644 --- a/nemo/collections/multimodal/data/energon/base.py +++ b/nemo/collections/multimodal/data/energon/base.py @@ -16,10 +16,10 @@ from typing import Any, Dict, Literal, Optional import fiddle as fdl -import pytorch_lightning as pl +import lightning.pytorch as pl +from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS from megatron.core import parallel_state from megatron.energon import WorkerConfig, get_savable_loader, get_train_dataset -from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS from torch.utils.data import DataLoader from typing_extensions import Self diff --git a/nemo/collections/multimodal/data/energon/task_encoder.py b/nemo/collections/multimodal/data/energon/task_encoder.py index 23758b3a43db..7a8d0f0ab033 100644 --- a/nemo/collections/multimodal/data/energon/task_encoder.py +++ b/nemo/collections/multimodal/data/energon/task_encoder.py @@ -48,7 +48,8 @@ class MultiModalTaskEncoder( and similarity interleaved samples. This class extends the DefaultTaskEncoder and provides a flexible mechanism to handle and encode - different types of multimodal data. Support for VQA, captioning and interleaved samples is provided by default. It supports registering custom encoders for each sample type + different types of multimodal data. Support for VQA, captioning and interleaved samples is provided by default. + It supports registering custom encoders for each sample type and provides methods for encoding individual samples, batching them, and further processing the batch for model input. """ @@ -59,8 +60,8 @@ def __init__(self, tokenizer, image_processor, multimodal_sample_config): Parameters: tokenizer (Tokenizer): The tokenizer used for processing text across different sample types. - image_processor (ImageProcessor): The image processor used for preprocessing images across different sample types. - multimodal_sample_config (MultiModalSampleConfig): Configuration object for multimodal samples, including tokens and placeholders. + image_processor (ImageProcessor): The image processor used for preprocessing images. + multimodal_sample_config (MultiModalSampleConfig): MultiModalSampleConfig object. """ self.tokenizer = tokenizer self.encoders: Dict[str, SampleEncoder] = { @@ -173,5 +174,6 @@ def encode_batch(self, batch_data: ImageTextRawBatch) -> dict: position_ids = torch.arange(seq_length, dtype=torch.long) position_ids = position_ids.unsqueeze(0).repeat(micro_batch_size, 1) batch_dict['position_ids'] = position_ids - batch_dict['attention_mask'] = None + if 'attention_mask' not in batch_dict: + batch_dict['attention_mask'] = None return batch_dict diff --git a/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py b/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py index 5291497f92c3..5d19b8544305 100644 --- a/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py +++ b/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py @@ -22,8 +22,8 @@ import torch import torch.nn.functional as F from einops import rearrange, reduce, repeat +from lightning.pytorch.trainer.trainer import Trainer from omegaconf import DictConfig, ListConfig, OmegaConf -from pytorch_lightning.trainer.trainer import Trainer from transformers import CLIPVisionModel, SiglipVisionModel from nemo.collections.common.parts.utils import extend_instance diff --git a/nemo/collections/multimodal/models/text_to_image/controlnet/controlnet.py b/nemo/collections/multimodal/models/text_to_image/controlnet/controlnet.py index 158fa7595782..981600fcc3a1 100644 --- a/nemo/collections/multimodal/models/text_to_image/controlnet/controlnet.py +++ b/nemo/collections/multimodal/models/text_to_image/controlnet/controlnet.py @@ -18,9 +18,9 @@ import torch import torch.nn as nn from einops import rearrange, repeat +from lightning.pytorch import Trainer +from lightning.pytorch.utilities.rank_zero import rank_zero_only from omegaconf import DictConfig -from pytorch_lightning import Trainer -from pytorch_lightning.utilities.rank_zero import rank_zero_only from torch._inductor import config as inductor_config from nemo.collections.multimodal.data.controlnet.controlnet_dataset import build_train_valid_datasets diff --git a/nemo/collections/multimodal/models/text_to_image/controlnet/util.py b/nemo/collections/multimodal/models/text_to_image/controlnet/util.py index 3d9a7d16b1c3..f890426c98f4 100644 --- a/nemo/collections/multimodal/models/text_to_image/controlnet/util.py +++ b/nemo/collections/multimodal/models/text_to_image/controlnet/util.py @@ -17,9 +17,9 @@ import numpy as np import torch import torchvision +from lightning.pytorch import Callback +from lightning.pytorch.utilities.rank_zero import rank_zero_only from PIL import Image -from pytorch_lightning import Callback -from pytorch_lightning.utilities.rank_zero import rank_zero_only class ImageLogger(Callback): diff --git a/nemo/collections/multimodal/models/text_to_image/dreambooth/dreambooth.py b/nemo/collections/multimodal/models/text_to_image/dreambooth/dreambooth.py index 47548b02961d..8906263faeba 100644 --- a/nemo/collections/multimodal/models/text_to_image/dreambooth/dreambooth.py +++ b/nemo/collections/multimodal/models/text_to_image/dreambooth/dreambooth.py @@ -15,8 +15,8 @@ from typing import Any, Optional import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig -from pytorch_lightning import Trainer from torch._inductor import config as inductor_config from nemo.collections.multimodal.data.dreambooth.dreambooth_dataset import DreamBoothDataset diff --git a/nemo/collections/multimodal/models/text_to_image/imagen/imagen.py b/nemo/collections/multimodal/models/text_to_image/imagen/imagen.py index ed9be58178c4..1772e465f604 100644 --- a/nemo/collections/multimodal/models/text_to_image/imagen/imagen.py +++ b/nemo/collections/multimodal/models/text_to_image/imagen/imagen.py @@ -17,8 +17,8 @@ from typing import Any import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig, open_dict -from pytorch_lightning import Trainer from nemo.collections.multimodal.data.imagen.imagen_dataset import build_train_valid_datasets from nemo.collections.multimodal.models.text_to_image.imagen.precond import ContinousDDPMPrecond, EDMPrecond diff --git a/nemo/collections/multimodal/models/text_to_image/imagen/imagen_pipeline.py b/nemo/collections/multimodal/models/text_to_image/imagen/imagen_pipeline.py index 43660c9000a1..63963321fcf7 100644 --- a/nemo/collections/multimodal/models/text_to_image/imagen/imagen_pipeline.py +++ b/nemo/collections/multimodal/models/text_to_image/imagen/imagen_pipeline.py @@ -17,8 +17,8 @@ from typing import Callable, List, Optional, Union import torch +from lightning.pytorch import Trainer from omegaconf.omegaconf import OmegaConf -from pytorch_lightning import Trainer from torch.cuda.amp import autocast from nemo.collections.multimodal.models.text_to_image.imagen.imagen import Imagen, MegatronImagen @@ -73,7 +73,9 @@ def _load_model(model_ckpt: str, model_cfg: str, eval_mode: bool = True, trainer model_cfg.micro_batch_size = 1 model_cfg.global_batch_size = 1 model = MegatronImagen.restore_from( - restore_path=model_ckpt, override_config_path=model_cfg, trainer=trainer, + restore_path=model_ckpt, + override_config_path=model_cfg, + trainer=trainer, ) elif model_ckpt.endswith('.ckpt'): model_cfg = OmegaConf.load(model_cfg) @@ -128,7 +130,9 @@ def model_cfg_modifier(model_cfg): models = [] print('Load base model.') model = ImagenPipeline._load_model( - model_ckpt=customized_models.base_ckpt, model_cfg=customized_models.base_cfg, trainer=trainer, + model_ckpt=customized_models.base_ckpt, + model_cfg=customized_models.base_cfg, + trainer=trainer, ) models.append(model) diff --git a/nemo/collections/multimodal/models/text_to_image/stable_diffusion/diffusion_engine.py b/nemo/collections/multimodal/models/text_to_image/stable_diffusion/diffusion_engine.py index 8b18fe2b25fe..c7e8795a749c 100644 --- a/nemo/collections/multimodal/models/text_to_image/stable_diffusion/diffusion_engine.py +++ b/nemo/collections/multimodal/models/text_to_image/stable_diffusion/diffusion_engine.py @@ -17,14 +17,14 @@ from typing import Any, Dict, List, Tuple, Union import hydra -import pytorch_lightning as pl +import lightning.pytorch as pl import torch import torch._dynamo import torch.nn as nn from einops import rearrange +from lightning.pytorch import Trainer +from lightning.pytorch.utilities import rank_zero_only from omegaconf import DictConfig, ListConfig, OmegaConf -from pytorch_lightning import Trainer -from pytorch_lightning.utilities import rank_zero_only from safetensors.torch import load_file as load_safetensors from torch._dynamo import optimize from torch.optim.lr_scheduler import LambdaLR diff --git a/nemo/collections/multimodal/models/text_to_image/stable_diffusion/ldm/autoencoder.py b/nemo/collections/multimodal/models/text_to_image/stable_diffusion/ldm/autoencoder.py index d79d85c2e026..311ebc0f06f5 100644 --- a/nemo/collections/multimodal/models/text_to_image/stable_diffusion/ldm/autoencoder.py +++ b/nemo/collections/multimodal/models/text_to_image/stable_diffusion/ldm/autoencoder.py @@ -13,7 +13,7 @@ # limitations under the License. from contextlib import contextmanager -import pytorch_lightning as pl +import lightning.pytorch as pl import torch import torch.nn.functional as F from nemo.utils import logging diff --git a/nemo/collections/multimodal/models/text_to_image/stable_diffusion/ldm/ddpm.py b/nemo/collections/multimodal/models/text_to_image/stable_diffusion/ldm/ddpm.py index 744dc6945394..163b2fb27e0f 100644 --- a/nemo/collections/multimodal/models/text_to_image/stable_diffusion/ldm/ddpm.py +++ b/nemo/collections/multimodal/models/text_to_image/stable_diffusion/ldm/ddpm.py @@ -17,18 +17,18 @@ from functools import partial from typing import Any, Dict, List, Optional, Union +import lightning.pytorch as pl import numpy as np -import pytorch_lightning as pl import torch import torch.nn as nn from einops import rearrange, repeat -from lightning_fabric.utilities.cloud_io import _load as pl_load +from lightning.fabric.utilities.cloud_io import _load as pl_load +from lightning.pytorch import Trainer +from lightning.pytorch.core.saving import _load_state as ptl_load_state +from lightning.pytorch.core.saving import load_hparams_from_tags_csv, load_hparams_from_yaml +from lightning.pytorch.utilities.migration import pl_legacy_patch +from lightning.pytorch.utilities.rank_zero import rank_zero_only from omegaconf import DictConfig, open_dict -from pytorch_lightning import Trainer -from pytorch_lightning.core.saving import _load_state as ptl_load_state -from pytorch_lightning.core.saving import load_hparams_from_tags_csv, load_hparams_from_yaml -from pytorch_lightning.utilities.migration import pl_legacy_patch -from pytorch_lightning.utilities.rank_zero import rank_zero_only from torch._inductor import config as inductor_config from torchvision.utils import make_grid from tqdm import tqdm diff --git a/nemo/collections/multimodal/models/vision_language_foundation/clip/megatron_clip_models.py b/nemo/collections/multimodal/models/vision_language_foundation/clip/megatron_clip_models.py index a9e51610bedd..84718f99262f 100644 --- a/nemo/collections/multimodal/models/vision_language_foundation/clip/megatron_clip_models.py +++ b/nemo/collections/multimodal/models/vision_language_foundation/clip/megatron_clip_models.py @@ -23,10 +23,10 @@ import numpy as np import torch import torch.nn.functional as F +from lightning.pytorch.accelerators import CPUAccelerator +from lightning.pytorch.trainer.trainer import Trainer from omegaconf import OmegaConf from omegaconf.dictconfig import DictConfig -from pytorch_lightning.accelerators import CPUAccelerator -from pytorch_lightning.trainer.trainer import Trainer from tqdm import tqdm from nemo.collections.multimodal.data.clip.clip_dataset import ( diff --git a/nemo/collections/multimodal/models/vision_language_foundation/megatron_nsfw_clip_models.py b/nemo/collections/multimodal/models/vision_language_foundation/megatron_nsfw_clip_models.py index 79c0f3910be0..37e33f892890 100644 --- a/nemo/collections/multimodal/models/vision_language_foundation/megatron_nsfw_clip_models.py +++ b/nemo/collections/multimodal/models/vision_language_foundation/megatron_nsfw_clip_models.py @@ -19,11 +19,11 @@ import torch import torch.nn as nn import torch.nn.functional as F +from lightning.pytorch.accelerators import CPUAccelerator +from lightning.pytorch.trainer.trainer import Trainer from megatron.core import parallel_state from megatron.core.pipeline_parallel.schedules import get_forward_backward_func from omegaconf.dictconfig import DictConfig -from pytorch_lightning.accelerators import CPUAccelerator -from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.multimodal.data.clip.clip_dataset import tokenize from nemo.collections.multimodal.data.nsfw.nsfw_dataset import build_dataset @@ -38,7 +38,6 @@ from nemo.core.classes.common import PretrainedModelInfo from nemo.utils import logging - try: from megatron.core.num_microbatches_calculator import get_num_microbatches diff --git a/nemo/collections/multimodal/parts/utils.py b/nemo/collections/multimodal/parts/utils.py index 6ba2e8ca91f9..8773b47025bc 100644 --- a/nemo/collections/multimodal/parts/utils.py +++ b/nemo/collections/multimodal/parts/utils.py @@ -17,10 +17,10 @@ import numpy as np import torch +from lightning.pytorch import Trainer +from lightning.pytorch.plugins.environments import TorchElasticEnvironment from omegaconf import DictConfig, OmegaConf, open_dict from PIL import Image -from pytorch_lightning import Trainer -from pytorch_lightning.plugins.environments import TorchElasticEnvironment from transformers import CLIPImageProcessor, SiglipImageProcessor from nemo.collections.multimodal.data.clip.augmentations.augmentations import image_transform diff --git a/nemo/collections/multimodal/speech_cv/models/visual_ctc_models.py b/nemo/collections/multimodal/speech_cv/models/visual_ctc_models.py index 50b4d29c05a4..53ae4a2dfb65 100644 --- a/nemo/collections/multimodal/speech_cv/models/visual_ctc_models.py +++ b/nemo/collections/multimodal/speech_cv/models/visual_ctc_models.py @@ -20,8 +20,8 @@ from typing import Dict, List, Optional, Union import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig, OmegaConf, open_dict -from pytorch_lightning import Trainer from tqdm.auto import tqdm from nemo.collections.asr.data import audio_to_text_dataset diff --git a/nemo/collections/multimodal/speech_cv/models/visual_hybrid_rnnt_ctc_bpe_models.py b/nemo/collections/multimodal/speech_cv/models/visual_hybrid_rnnt_ctc_bpe_models.py index 106fbc432926..8249e5d8a7f8 100644 --- a/nemo/collections/multimodal/speech_cv/models/visual_hybrid_rnnt_ctc_bpe_models.py +++ b/nemo/collections/multimodal/speech_cv/models/visual_hybrid_rnnt_ctc_bpe_models.py @@ -17,8 +17,8 @@ from typing import Dict, Optional, Union import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig, ListConfig, OmegaConf, open_dict -from pytorch_lightning import Trainer from nemo.collections.asr.losses.ctc import CTCLoss from nemo.collections.asr.losses.rnnt import RNNTLoss @@ -90,7 +90,10 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): # Setup decoding object self.decoding = RNNTBPEDecoding( - decoding_cfg=self.cfg.decoding, decoder=self.decoder, joint=self.joint, tokenizer=self.tokenizer, + decoding_cfg=self.cfg.decoding, + decoder=self.decoder, + joint=self.joint, + tokenizer=self.tokenizer, ) # Setup wer object @@ -282,7 +285,10 @@ def change_vocabulary( decoding_cfg = OmegaConf.merge(decoding_cls, decoding_cfg) self.decoding = RNNTBPEDecoding( - decoding_cfg=decoding_cfg, decoder=self.decoder, joint=self.joint, tokenizer=self.tokenizer, + decoding_cfg=decoding_cfg, + decoder=self.decoder, + joint=self.joint, + tokenizer=self.tokenizer, ) self.wer = WER( @@ -388,7 +394,10 @@ def change_decoding_strategy(self, decoding_cfg: DictConfig, decoder_type: str = decoding_cfg = OmegaConf.merge(decoding_cls, decoding_cfg) self.decoding = RNNTBPEDecoding( - decoding_cfg=decoding_cfg, decoder=self.decoder, joint=self.joint, tokenizer=self.tokenizer, + decoding_cfg=decoding_cfg, + decoder=self.decoder, + joint=self.joint, + tokenizer=self.tokenizer, ) self.wer = WER( diff --git a/nemo/collections/multimodal/speech_cv/models/visual_hybrid_rnnt_ctc_models.py b/nemo/collections/multimodal/speech_cv/models/visual_hybrid_rnnt_ctc_models.py index 1b30263985da..158bfaddcc96 100644 --- a/nemo/collections/multimodal/speech_cv/models/visual_hybrid_rnnt_ctc_models.py +++ b/nemo/collections/multimodal/speech_cv/models/visual_hybrid_rnnt_ctc_models.py @@ -19,8 +19,8 @@ from typing import List, Optional import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig, OmegaConf, open_dict -from pytorch_lightning import Trainer from tqdm.auto import tqdm from nemo.collections.asr.losses.ctc import CTCLoss diff --git a/nemo/collections/multimodal/speech_cv/models/visual_rnnt_bpe_models.py b/nemo/collections/multimodal/speech_cv/models/visual_rnnt_bpe_models.py index eeffb906981a..11e9d43e1737 100644 --- a/nemo/collections/multimodal/speech_cv/models/visual_rnnt_bpe_models.py +++ b/nemo/collections/multimodal/speech_cv/models/visual_rnnt_bpe_models.py @@ -17,8 +17,8 @@ from typing import Dict, List, Optional, Union import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig, ListConfig, OmegaConf, open_dict -from pytorch_lightning import Trainer from nemo.collections.asr.losses.rnnt import RNNTLoss from nemo.collections.asr.metrics.wer import WER @@ -68,7 +68,10 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): # Setup decoding object self.decoding = RNNTBPEDecoding( - decoding_cfg=self.cfg.decoding, decoder=self.decoder, joint=self.joint, tokenizer=self.tokenizer, + decoding_cfg=self.cfg.decoding, + decoder=self.decoder, + joint=self.joint, + tokenizer=self.tokenizer, ) # Setup wer object @@ -165,7 +168,10 @@ def change_vocabulary( decoding_cfg = OmegaConf.merge(decoding_cls, decoding_cfg) self.decoding = RNNTBPEDecoding( - decoding_cfg=decoding_cfg, decoder=self.decoder, joint=self.joint, tokenizer=self.tokenizer, + decoding_cfg=decoding_cfg, + decoder=self.decoder, + joint=self.joint, + tokenizer=self.tokenizer, ) self.wer = WER( @@ -214,7 +220,10 @@ def change_decoding_strategy(self, decoding_cfg: DictConfig): decoding_cfg = OmegaConf.merge(decoding_cls, decoding_cfg) self.decoding = RNNTBPEDecoding( - decoding_cfg=decoding_cfg, decoder=self.decoder, joint=self.joint, tokenizer=self.tokenizer, + decoding_cfg=decoding_cfg, + decoder=self.decoder, + joint=self.joint, + tokenizer=self.tokenizer, ) self.wer = WER( diff --git a/nemo/collections/multimodal/speech_cv/models/visual_rnnt_models.py b/nemo/collections/multimodal/speech_cv/models/visual_rnnt_models.py index 5a86eed93019..75202238d2d0 100644 --- a/nemo/collections/multimodal/speech_cv/models/visual_rnnt_models.py +++ b/nemo/collections/multimodal/speech_cv/models/visual_rnnt_models.py @@ -20,8 +20,8 @@ from typing import Dict, List, Optional, Tuple, Union import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig, OmegaConf, open_dict -from pytorch_lightning import Trainer from tqdm.auto import tqdm from nemo.collections.asr.data import audio_to_text_dataset diff --git a/nemo/collections/multimodal/speech_llm/models/modular_models.py b/nemo/collections/multimodal/speech_llm/models/modular_models.py index 46b2ca3e26fd..aab27cf2d908 100644 --- a/nemo/collections/multimodal/speech_llm/models/modular_models.py +++ b/nemo/collections/multimodal/speech_llm/models/modular_models.py @@ -21,11 +21,11 @@ import sacrebleu import torch from hydra.utils import get_class +from lightning.pytorch.trainer.trainer import Trainer +from lightning.pytorch.utilities import rank_zero_only from omegaconf import ListConfig from omegaconf.dictconfig import DictConfig from omegaconf.omegaconf import OmegaConf, open_dict -from pytorch_lightning.trainer.trainer import Trainer -from pytorch_lightning.utilities import rank_zero_only from nemo.collections.asr.models import ASRModel, EncDecSpeakerLabelModel from nemo.collections.asr.parts.utils.eval_utils import remove_punctuations diff --git a/nemo/collections/multimodal/speech_llm/models/modular_t5_models.py b/nemo/collections/multimodal/speech_llm/models/modular_t5_models.py index 79fc0468e819..a99f5c346831 100644 --- a/nemo/collections/multimodal/speech_llm/models/modular_t5_models.py +++ b/nemo/collections/multimodal/speech_llm/models/modular_t5_models.py @@ -21,10 +21,10 @@ import sacrebleu import torch +from lightning.pytorch.trainer.trainer import Trainer from omegaconf import ListConfig from omegaconf.dictconfig import DictConfig from omegaconf.omegaconf import OmegaConf, open_dict -from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.asr.models import ASRModel, SpeechEncDecSelfSupervisedModel from nemo.collections.common.data.utils import move_data_to_device diff --git a/nemo/collections/multimodal_autoregressive/data/README.md b/nemo/collections/multimodal_autoregressive/data/README.md index c4814ad267f8..3f6d5a6c6a81 100644 --- a/nemo/collections/multimodal_autoregressive/data/README.md +++ b/nemo/collections/multimodal_autoregressive/data/README.md @@ -8,27 +8,7 @@ This is an example of how to do autoregressive generation for multiple modalitie ### 1. Vision Understanding using EMU3 Tokenizer #### Download and Extract data -We will be working with coyo dataset which has 700 million images. - -First create credentials for rclone . Create this file at `~/.config/rclone/rclone.conf` -``` -[pbss-team-vfm-share-ro-s3] -type = s3 -env_auth = true -access_key_id = -secret_access_key = -region = us-east-1 -endpoint = https://pdx.s8k.io -``` -To download the images -``` -rclone copy pbss-team-vfm-share-ro-s3:webdataset_images/webdataset_edify_image_v3/coyo_700m/resolution_lt_720/aspect_ratio_16_9/images images --transfers=16 --multi-thread-streams=16 --checkers=8 -P --stats 5s -``` - -To download the captions -``` -rclone copy pbss-team-vfm-share-ro-s3:webdataset_images/webdataset_edify_image_v3/coyo_700m/resolution_lt_720/aspect_ratio_16_9/captions_ai_v3p1 captions_ai_v3p1 --transfers=16 --multi-thread-streams=16 --checkers=8 -P --stats 5s -``` +Download the [COYO700M dataset](https://github.com/kakaobrain/coyo-dataset) Once downloaded extract the data using tar utilities. @@ -70,13 +50,13 @@ Follow usual nemo instructions to train any autoregressive model. ``` #### Inference -To run inference edit the [inference config file](examples/multimodal_autoregressive/conf/megatron_mm_ar_inference.yaml) +To run inference edit the [inference config file](examples/multimodal_autoregressive/conf/megatron_mm_ar_inference_vision_understanding.yaml) *NOTE* Make sure you have a .nemo file (checkpoint). If you just have a regular megatron checkpoint you have to do a conversion as shown in [this doc](https://docs.nvidia.com/nemo-framework/user-guide/latest/llms/gpt/checkpointconversion.html?highlight=convert) Run inference as follows ``` -torchrun --nproc-per-node 2 examples/multimodal_autoregressive/megatron_mm_autoregressive_eval.py +torchrun --nproc-per-node 2 examples/multimodal_autoregressive/megatron_mm_autoregressive_eval_vision_understanding.py ``` @@ -116,13 +96,11 @@ Follow usual nemo instructions to train any autoregressive model. ``` #### Inference -To run inference edit the [inference config file](examples/multimodal_autoregressive/conf/megatron_mm_ar_inference.yaml) +To run inference edit the [inference config file](examples/multimodal_autoregressive/conf/megatron_mm_ar_inference_image_generation.yaml) *NOTE* Make sure you have a .nemo file (checkpoint). If you just have a regular megatron checkpoint you have to do a conversion as shown in [this doc](https://docs.nvidia.com/nemo-framework/user-guide/latest/llms/gpt/checkpointconversion.html?highlight=convert) Run inference as follows ``` -torchrun --nproc-per-node 2 examples/multimodal_autoregressive/megatron_mm_autoregressive_eval.py -``` - -TODO : Instructions to convert visual tokens to images coming soon. \ No newline at end of file +torchrun --nproc-per-node 2 examples/multimodal_autoregressive/megatron_mm_autoregressive_eval_image_generation.py +``` \ No newline at end of file diff --git a/nemo/collections/nlp/data/language_modeling/megatron/base_prompt_learning_dataset.py b/nemo/collections/nlp/data/language_modeling/megatron/base_prompt_learning_dataset.py index bbd14f47a651..ea5f8c5a930b 100644 --- a/nemo/collections/nlp/data/language_modeling/megatron/base_prompt_learning_dataset.py +++ b/nemo/collections/nlp/data/language_modeling/megatron/base_prompt_learning_dataset.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import omegaconf import torch from nemo.collections.nlp.modules.common import VirtualPromptSource @@ -70,8 +71,55 @@ def __init__( # Datasets are a list of file path strings to .json or .jsonl files elif isinstance(datasets[0], str): for path in datasets: - dataset = open(path, 'r', encoding='utf-8') - self.load_data(dataset) + with open(path, 'r', encoding='utf-8') as dataset: + dataset_examples = self.load_data(dataset) + self.examples.extend(dataset_examples) + elif isinstance(datasets[0], omegaconf.ListConfig) or isinstance(datasets[0], list): + # Dataset is a list of tuples with the first element being the probability of sampling from the dataset + # This code repeates the smaller datasets to approximately match the target probabilities + total_examples = 0 + dataset_lengths = [] + target_probs = [] + datasets_examples_list = [] + for prob_and_path in datasets: + prob = prob_and_path[0] + path = prob_and_path[1] + with open(path, 'r', encoding='utf-8') as dataset: + dataset_examples = self.load_data(dataset) + datasets_examples_list.append(dataset_examples) + dataset_lengths.append(len(dataset_examples)) + total_examples += len(dataset_examples) + target_probs.append(prob) + + # Normalize the target probs + target_probs = [prob / sum(target_probs) for prob in target_probs] + current_probs = [dataset_lengths[i] / total_examples for i in range(len(dataset_lengths))] + + # Increase number of examples needed without reducing the larger datasets with low target probs + new_total_examples = total_examples + for dataset_idx in range(len(datasets)): + if target_probs[dataset_idx] < current_probs[dataset_idx]: + target_total_examples = int(dataset_lengths[dataset_idx] / target_probs[dataset_idx]) + new_total_examples = max(new_total_examples, target_total_examples) + + final_total_examples = 0 + final_dataset_lengths = [] + for dataset_idx in range(len(datasets)): + num_samples_required = int(new_total_examples * target_probs[dataset_idx]) + num_repeat = max( + int(round(num_samples_required // dataset_lengths[dataset_idx])), 1 + ) # At least 1 repeat + logging.info("dataset idx {}, num_repeat {}".format(dataset_idx, num_repeat)) + dataset_examples_repeated = datasets_examples_list[dataset_idx] * num_repeat + final_dataset_lengths.append(len(dataset_examples_repeated)) + final_total_examples += len(dataset_examples_repeated) + self.examples.extend(dataset_examples_repeated) + + final_probs = [final_dataset_lengths[i] / final_total_examples for i in range(len(final_dataset_lengths))] + logging.info("Target probs: {}".format(target_probs)) + logging.info("Final probs: {}".format(final_probs)) + logging.info("Initial total examples: {}".format(total_examples)) + logging.info("Final total examples: {}".format(final_total_examples)) else: raise ValueError("Datasets must be a list of dicts or a list of filepath strings") diff --git a/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py b/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py index 898ddb7d716b..9da2419520c2 100644 --- a/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py +++ b/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py @@ -573,15 +573,23 @@ def _build_samples_mapping(self): self.samples_mapping = None def _build_loss_mask(self, processed_example): + seq_boundaries = processed_example['seq_boundaries'] if self.answer_only_loss: - seq_boundaries = processed_example['seq_boundaries'] return np.concatenate( [ processed_example['loss_mask'][seq_boundaries[i] + 1 : seq_boundaries[i + 1]] for i in range(len(seq_boundaries) - 1) ] ) - return [1.0] * (len(processed_example['input_ids']) - len(processed_example['seq_boundaries']) + 1) + return np.concatenate( + [ + [ + 0 if x == self.tokenizer.eos_id else 1.0 + for x in processed_example['input_ids'][seq_boundaries[i] : seq_boundaries[i + 1] - 1] + ] + for i in range(len(seq_boundaries) - 1) + ] + ) def _maybe_cast_to_list(self, x): return [item.tolist() if isinstance(item, np.ndarray) else item for item in x] @@ -622,16 +630,40 @@ def collate_fn(self, batch): position_ids: List[List[int]] = [] cu_seqlens: List[List[int]] = [] + cu_seqlens_unpadded: List[List[int]] = [] for item in batch: position_ids.append([]) cu_seqlens.append([0]) + cu_seqlens_unpadded.append([0]) seqlens = np.array(item['seq_boundaries'][1:]) - np.array(item['seq_boundaries'][:-1]) for l in seqlens: # length minus 1 because input_ids is truncated by 1 for labels position_ids[-1].extend(list(range(l - 1))) cu_seqlens[-1].append(cu_seqlens[-1][-1] + l - 1) - # set last seq to the max seq len because rope and attn kernels expect no padding - cu_seqlens[-1][-1] = max_length + + # the last seq needs to be the max seq len because rope and attn kernels expect no padding + assert cu_seqlens[-1][-1] <= max_length + + # since data is prepadded when cp_size > 1, there may be some extra padding at the end + # of the packed sequence. In this case, we need to add the max seq len to the end. + if cu_seqlens[-1][-1] != max_length: + cu_seqlens[-1].append(max_length) + + for i in range(len(item['seq_boundaries']) - 1): + current_seq = item['input_ids'][item['seq_boundaries'][i] : item['seq_boundaries'][i + 1] - 1] + + # since the data could be prepadded with tokenizer's eos_id, we can find out the index of all the eos_id + eos_idx = np.where(np.array(current_seq) == self.tokenizer.eos_id) + + # The second eos_id index marks the length of the original unpadded sequence if the sequence is + # prepadded for cp_size > 1. Otherwise, there is no extra padding. + seqlen_unpadded = eos_idx[0][0] + 1 if eos_idx[0].any() else len(current_seq) + cu_seqlens_unpadded[-1].append(cu_seqlens_unpadded[-1][-1] + seqlen_unpadded) + + # if extra paddings are added in the packed sequence, they can't be counted as + # actual tokens for training + if len(cu_seqlens[-1]) > len(cu_seqlens_unpadded[-1]): + cu_seqlens_unpadded[-1].append(cu_seqlens_unpadded[-1][-1]) assert len(input_ids[0]) == len( position_ids[0] @@ -652,12 +684,16 @@ def collate_fn(self, batch): if self.return_cu_seqlen: cu_seqlens = self._collate_item(cu_seqlens, max_length=max(len(l) for l in cu_seqlens) + 1, pad_id=-1) - + cu_seqlens_unpadded = self._collate_item( + cu_seqlens_unpadded, max_length=max(len(l) for l in cu_seqlens_unpadded) + 1, pad_id=-1 + ) # Pre-generate `cu_seqlens_argmin` and `max_seqlen` as CPU tensor to avoid device-to-host copies. cu_seqlens = torch.IntTensor(cu_seqlens) cu_seqlens_argmin = torch.argmin(cu_seqlens, dim=1, keepdim=True) seqlens = cu_seqlens[:, 1:] - cu_seqlens[:, :-1] max_seqlen, _ = seqlens.max(dim=1, keepdim=True) + cu_seqlens_unpadded = torch.IntTensor(cu_seqlens_unpadded) + cu_seqlens_unpadded_argmin = torch.argmin(cu_seqlens_unpadded, dim=1, keepdim=True) processed_batch.update( { @@ -667,6 +703,8 @@ def collate_fn(self, batch): 'cu_seqlens': torch.IntTensor(cu_seqlens), # cu_seqlens_q must be in dtype torch.int32 'cu_seqlens_argmin': cu_seqlens_argmin, # only required for perf 'max_seqlen': max_seqlen, # only required for perf + 'cu_seqlens_unpadded': torch.IntTensor(cu_seqlens_unpadded), + 'cu_seqlens_unpadded_argmin': cu_seqlens_unpadded_argmin, } ) else: diff --git a/nemo/collections/nlp/data/machine_translation/preproc_mt_data.py b/nemo/collections/nlp/data/machine_translation/preproc_mt_data.py index b95993ded69e..59181d8cb89f 100644 --- a/nemo/collections/nlp/data/machine_translation/preproc_mt_data.py +++ b/nemo/collections/nlp/data/machine_translation/preproc_mt_data.py @@ -21,8 +21,8 @@ import tempfile from joblib import Parallel, delayed +from lightning.pytorch import Trainer from omegaconf import ListConfig, OmegaConf -from pytorch_lightning import Trainer from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer, create_spt_model from nemo.collections.nlp.data.language_modeling.sentence_dataset import SentenceDataset @@ -33,23 +33,23 @@ class MTDataPreproc: - """ Automatically trains tokenizers and preprocesses machine translation data based on the MTEncDecModelConfig. - For training NMT models with datasets larger than 5M sentence pairs, - it can be inefficient to train them without first creating a tarred dataset. - If the user wants to change the tokenizer, vocab size, or batch size, for example, - they must reprocess the data with the correct configuration. - With MTDataPreproc users can sweep through data configurations and the tarred dataset will - be automatically created according to the model configuration. - To train tokenizer model and create tarred dataset specify in configuration: - model.preproc_out_dir=/path/to/preproc_out - model.encoder_tokenizer.vocab_size=32000 - model.decoder_tokenizer.vocab_size=32000 - model.train_ds.use_tarred_dataset=True - model.train_ds.src_file_name=/path/to/src.txt - model.train_ds.tgt_file_name=/path/to/tgt.txt - model.train_ds.tokens_in_batch=16000 - Once a dataset has been constructed based on this configuration, MTDataPreproc will not process it again. - If a previously trained tokenizer model or tarred dataset is found, MTDataPreproc will not preprocess the data. + """Automatically trains tokenizers and preprocesses machine translation data based on the MTEncDecModelConfig. + For training NMT models with datasets larger than 5M sentence pairs, + it can be inefficient to train them without first creating a tarred dataset. + If the user wants to change the tokenizer, vocab size, or batch size, for example, + they must reprocess the data with the correct configuration. + With MTDataPreproc users can sweep through data configurations and the tarred dataset will + be automatically created according to the model configuration. + To train tokenizer model and create tarred dataset specify in configuration: + model.preproc_out_dir=/path/to/preproc_out + model.encoder_tokenizer.vocab_size=32000 + model.decoder_tokenizer.vocab_size=32000 + model.train_ds.use_tarred_dataset=True + model.train_ds.src_file_name=/path/to/src.txt + model.train_ds.tgt_file_name=/path/to/tgt.txt + model.train_ds.tokens_in_batch=16000 + Once a dataset has been constructed based on this configuration, MTDataPreproc will not process it again. + If a previously trained tokenizer model or tarred dataset is found, MTDataPreproc will not preprocess the data. """ def __init__(self, cfg: MTEncDecModelConfig, trainer: Trainer = None) -> None: @@ -147,12 +147,16 @@ def __init__(self, cfg: MTEncDecModelConfig, trainer: Trainer = None) -> None: global_rank=self.global_rank, encoder_training_sample_size=cfg.encoder_tokenizer.get('training_sample_size', -1), decoder_training_sample_size=cfg.decoder_tokenizer.get('training_sample_size', -1), - encoder_special_tokens=OmegaConf.to_container(cfg.encoder_tokenizer.special_tokens) - if cfg.encoder_tokenizer.special_tokens - else None, - decoder_special_tokens=OmegaConf.to_container(cfg.decoder_tokenizer.special_tokens) - if cfg.decoder_tokenizer.special_tokens - else None, + encoder_special_tokens=( + OmegaConf.to_container(cfg.encoder_tokenizer.special_tokens) + if cfg.encoder_tokenizer.special_tokens + else None + ), + decoder_special_tokens=( + OmegaConf.to_container(cfg.decoder_tokenizer.special_tokens) + if cfg.decoder_tokenizer.special_tokens + else None + ), spt_symbols=spt_symbols, ) # update config @@ -280,10 +284,10 @@ def __init__(self, cfg: MTEncDecModelConfig, trainer: Trainer = None) -> None: ) def tar_files_to_string(self, tar_files): - """ Tar files are generated in the following format: basename.number.tar + """Tar files are generated in the following format: basename.number.tar Where number is an integer from 1 to the number of tar files. We convert this list to a string that can be used in the model config to specify - tarred datasets: basename_OP_1..num_tar_files_CL_.tar + tarred datasets: basename_OP_1..num_tar_files_CL_.tar Args: tar_files (List[str]): List of tar files generated by preprocess_parallel_dataset @@ -337,7 +341,9 @@ def get_enc_dec_tokenizers( @staticmethod def get_monolingual_tokenizer( - tokenizer_name=None, tokenizer_model=None, bpe_dropout=0.0, + tokenizer_name=None, + tokenizer_model=None, + bpe_dropout=0.0, ): if tokenizer_name == 'sentencepiece': tokenizer = SentencePieceTokenizer(model_path=tokenizer_model) @@ -385,14 +391,14 @@ def preprocess_parallel_dataset( src_fname (str): path to source text data tgt_fname (str): path to target text data out_dir (str): path to write tarred dataset - encoder_tokenizer (Any): tokenizer for encoder + encoder_tokenizer (Any): tokenizer for encoder decoder_tokenizer (Any): tokenizer for decoder - max_seq_length (int): maximum sequence length - min_seq_length (int): minimum sequence length - tokens_in_batch (int): tokens per batch per GPU, effectively batch size + max_seq_length (int): maximum sequence length + min_seq_length (int): minimum sequence length + tokens_in_batch (int): tokens per batch per GPU, effectively batch size lines_per_dataset_fragment (int): number of lines to consider for bucketing and padding num_batches_per_tarfile (int): number of batches (pickle files) within each tarfile - tar_file_prefix (str) : add string prefix to tar files + tar_file_prefix (str) : add string prefix to tar files n_jobs (int): number of processes to use for data processing (-2 to use all but 2) """ @@ -471,7 +477,10 @@ def preprocess_parallel_dataset( out_dir, f'remainder-batches.tokens.{tokens_in_batch}.tar_file_{remainder_tar_file_ctr}.tar', ) - remainder_tar_file_ptr = tarfile.open(remainder_tar_file_path, 'w',) + remainder_tar_file_ptr = tarfile.open( + remainder_tar_file_path, + 'w', + ) batch_in_tar_ctr = 0 tar_file_ptr.close() os.remove(tar_file_path) @@ -631,9 +640,9 @@ def preprocess_monolingual_dataset( fname (str): Path to source text data out_dir (str): Path to write tarred dataset tokenizer (Any): Path to tokenizer model - max_seq_length (int): maximum sequence length - min_seq_length (int): minimum sequence length - tokens_in_batch (int): tokens per batch per GPU, effectively batch size + max_seq_length (int): maximum sequence length + min_seq_length (int): minimum sequence length + tokens_in_batch (int): tokens per batch per GPU, effectively batch size lines_per_dataset_fragment (int): number of lines to consider for bucketing and padding num_batches_per_tarfile (int): number of batches (pickle files) within each tarfile global_rank (int): if set to zero, data will be processed on this node @@ -808,7 +817,8 @@ def train_tokenizers( split_by_whitespace=split_by_whitespace, ) os.rename( - os.path.join(out_dir, 'tokenizer.model'), encoder_tokenizer_model, + os.path.join(out_dir, 'tokenizer.model'), + encoder_tokenizer_model, ) else: if encoder_tokenizer_name in supported_train_tokenizers: @@ -1007,7 +1017,10 @@ def write_parallel_batches_to_tarfiles( tar_file_path = os.path.join( out_dir, 'fragment-%s-batches.tokens.%d.%d.tar' % (fragment_index, num_tokens, tar_file_ctr) ) - tar_file_ptr = tarfile.open(tar_file_path, 'w',) + tar_file_ptr = tarfile.open( + tar_file_path, + 'w', + ) batch_ctr = 0 # return tar files paths that have batches remaining diff --git a/nemo/collections/nlp/models/dialogue/dialogue_gpt_classification_model.py b/nemo/collections/nlp/models/dialogue/dialogue_gpt_classification_model.py index 07ca790866c7..6c7472b95c42 100644 --- a/nemo/collections/nlp/models/dialogue/dialogue_gpt_classification_model.py +++ b/nemo/collections/nlp/models/dialogue/dialogue_gpt_classification_model.py @@ -21,8 +21,8 @@ import numpy as np import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig -from pytorch_lightning import Trainer from torch.utils.data import DataLoader from transformers import AutoModelWithLMHead diff --git a/nemo/collections/nlp/models/dialogue/dialogue_gpt_generation_model.py b/nemo/collections/nlp/models/dialogue/dialogue_gpt_generation_model.py index 116605b65d52..7fb0ba770189 100644 --- a/nemo/collections/nlp/models/dialogue/dialogue_gpt_generation_model.py +++ b/nemo/collections/nlp/models/dialogue/dialogue_gpt_generation_model.py @@ -19,8 +19,8 @@ import numpy as np import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig -from pytorch_lightning import Trainer from torch.utils.data import DataLoader from transformers import AutoModelWithLMHead diff --git a/nemo/collections/nlp/models/dialogue/dialogue_nearest_neighbour_model.py b/nemo/collections/nlp/models/dialogue/dialogue_nearest_neighbour_model.py index 29e2627fa038..9bf7ae2a9116 100644 --- a/nemo/collections/nlp/models/dialogue/dialogue_nearest_neighbour_model.py +++ b/nemo/collections/nlp/models/dialogue/dialogue_nearest_neighbour_model.py @@ -19,8 +19,8 @@ import numpy as np import torch import torch.nn.functional as F +from lightning.pytorch import Trainer from omegaconf import DictConfig -from pytorch_lightning import Trainer from transformers import AutoModel from nemo.collections.nlp.data.dialogue import DialogueSGDDataProcessor diff --git a/nemo/collections/nlp/models/dialogue/dialogue_s2s_generation_model.py b/nemo/collections/nlp/models/dialogue/dialogue_s2s_generation_model.py index 48f3e5127a88..3f0d09d7dc66 100644 --- a/nemo/collections/nlp/models/dialogue/dialogue_s2s_generation_model.py +++ b/nemo/collections/nlp/models/dialogue/dialogue_s2s_generation_model.py @@ -18,8 +18,8 @@ import numpy as np import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig, OmegaConf, open_dict -from pytorch_lightning import Trainer from torch.utils.data import DataLoader from transformers import AutoModelForSeq2SeqLM diff --git a/nemo/collections/nlp/models/dialogue/dialogue_zero_shot_intent_model.py b/nemo/collections/nlp/models/dialogue/dialogue_zero_shot_intent_model.py index 5298c060df08..1df19cf8a556 100644 --- a/nemo/collections/nlp/models/dialogue/dialogue_zero_shot_intent_model.py +++ b/nemo/collections/nlp/models/dialogue/dialogue_zero_shot_intent_model.py @@ -19,8 +19,8 @@ import numpy as np import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig -from pytorch_lightning import Trainer from transformers import AutoModelForSequenceClassification, AutoTokenizer from nemo.collections.nlp.data.dialogue import DialogueSGDDataProcessor diff --git a/nemo/collections/nlp/models/dialogue/intent_slot_classification_model.py b/nemo/collections/nlp/models/dialogue/intent_slot_classification_model.py index 777d468084e2..09a81b33c973 100644 --- a/nemo/collections/nlp/models/dialogue/intent_slot_classification_model.py +++ b/nemo/collections/nlp/models/dialogue/intent_slot_classification_model.py @@ -16,8 +16,8 @@ from typing import Dict, List, Optional import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig, OmegaConf -from pytorch_lightning import Trainer from torch.utils.data import DataLoader from nemo.collections.common.losses import AggregatorLoss, CrossEntropyLoss diff --git a/nemo/collections/nlp/models/dialogue/sgdqa_model.py b/nemo/collections/nlp/models/dialogue/sgdqa_model.py index 3b30dfccd9ce..6cd2243423a4 100644 --- a/nemo/collections/nlp/models/dialogue/sgdqa_model.py +++ b/nemo/collections/nlp/models/dialogue/sgdqa_model.py @@ -22,8 +22,8 @@ from typing import List, Optional import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig -from pytorch_lightning import Trainer from torch.utils.data import DataLoader from nemo.collections.nlp.data.dialogue import DialogueSGDBERTDataset, DialogueSGDDataProcessor diff --git a/nemo/collections/nlp/models/duplex_text_normalization/duplex_decoder.py b/nemo/collections/nlp/models/duplex_text_normalization/duplex_decoder.py index 7d4cac46cc28..253962e55621 100644 --- a/nemo/collections/nlp/models/duplex_text_normalization/duplex_decoder.py +++ b/nemo/collections/nlp/models/duplex_text_normalization/duplex_decoder.py @@ -19,8 +19,8 @@ from typing import Dict, List, Optional, Union import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig -from pytorch_lightning import Trainer from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, DataCollatorForSeq2Seq import nemo.collections.nlp.data.text_normalization.constants as constants @@ -307,7 +307,7 @@ def _infer( span_ends: List[List[int]], inst_directions: List[str], ): - """ Main function for Inference + """Main function for Inference Args: sents: A list of inputs tokenized by a basic tokenizer. nb_spans: A list of ints where each int indicates the number of semiotic spans in each input. @@ -521,9 +521,9 @@ def _setup_dataloader_from_config(self, cfg: DictConfig, data_split: str): tokenizer_name=self.transformer_name, mode=self.mode, max_len=self.max_sequence_len, - decoder_data_augmentation=cfg.get('decoder_data_augmentation', False) - if data_split == "train" - else False, + decoder_data_augmentation=( + cfg.get('decoder_data_augmentation', False) if data_split == "train" else False + ), lang=self.lang, use_cache=cfg.get('use_cache', False), max_insts=cfg.get('max_insts', -1), diff --git a/nemo/collections/nlp/models/duplex_text_normalization/duplex_tagger.py b/nemo/collections/nlp/models/duplex_text_normalization/duplex_tagger.py index feeda99bdbe5..1ce005403999 100644 --- a/nemo/collections/nlp/models/duplex_text_normalization/duplex_tagger.py +++ b/nemo/collections/nlp/models/duplex_text_normalization/duplex_tagger.py @@ -16,8 +16,8 @@ from typing import Dict, List, Optional import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig -from pytorch_lightning import Trainer from torch import nn from transformers import AutoModelForTokenClassification, AutoTokenizer, DataCollatorForTokenClassification from transformers.tokenization_utils_base import BatchEncoding @@ -151,7 +151,7 @@ def on_test_epoch_end(self): # Functions for inference @torch.no_grad() def _infer(self, sents: List[List[str]], inst_directions: List[str]): - """ Main function for Inference + """Main function for Inference Args: sents: A list of inputs tokenized by a basic tokenizer. @@ -248,7 +248,7 @@ def _infer(self, sents: List[List[str]], inst_directions: List[str]): return all_tag_preds, nb_spans, span_starts, span_ends def _postprocess_tag_preds(self, words: List[str], inst_dir: str, preds: List[str]): - """ Function for postprocessing the raw tag predictions of the model. It + """Function for postprocessing the raw tag predictions of the model. It corrects obvious mistakes in the tag predictions such as a TRANSFORM span starts with I_TRANSFORM_TAG (instead of B_TRANSFORM_TAG). @@ -280,7 +280,7 @@ def _postprocess_tag_preds(self, words: List[str], inst_dir: str, preds: List[st return final_preds def decode_tag_preds(self, tag_preds: List[List[str]]): - """ Decoding the raw tag predictions to locate the semiotic spans in the + """Decoding the raw tag predictions to locate the semiotic spans in the input texts. Args: diff --git a/nemo/collections/nlp/models/enc_dec_nlp_model.py b/nemo/collections/nlp/models/enc_dec_nlp_model.py index d9aa3c017bae..60c6b616c20a 100644 --- a/nemo/collections/nlp/models/enc_dec_nlp_model.py +++ b/nemo/collections/nlp/models/enc_dec_nlp_model.py @@ -15,8 +15,8 @@ from dataclasses import dataclass from typing import Any +from lightning.pytorch.trainer.trainer import Trainer from omegaconf.omegaconf import MISSING -from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.nlp.models.nlp_model import NLPModel from nemo.collections.nlp.modules.common.decoder_module import DecoderModule @@ -35,8 +35,7 @@ class EncDecNLPModelConfig(ModelConfig): class EncDecNLPModel(NLPModel): - """Base class for encoder-decoder NLP models. - """ + """Base class for encoder-decoder NLP models.""" def __init__(self, cfg: EncDecNLPModelConfig, trainer: Trainer = None): super().__init__(cfg=cfg, trainer=trainer) diff --git a/nemo/collections/nlp/models/entity_linking/entity_linking_model.py b/nemo/collections/nlp/models/entity_linking/entity_linking_model.py index 4afae81e3893..640520cdaaa7 100644 --- a/nemo/collections/nlp/models/entity_linking/entity_linking_model.py +++ b/nemo/collections/nlp/models/entity_linking/entity_linking_model.py @@ -15,8 +15,8 @@ from typing import Dict, Optional import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig -from pytorch_lightning import Trainer from transformers import AutoTokenizer from nemo.collections.common.losses import MultiSimilarityLoss diff --git a/nemo/collections/nlp/models/glue_benchmark/glue_benchmark_model.py b/nemo/collections/nlp/models/glue_benchmark/glue_benchmark_model.py index 4447ebb89386..e90cf9d88c30 100644 --- a/nemo/collections/nlp/models/glue_benchmark/glue_benchmark_model.py +++ b/nemo/collections/nlp/models/glue_benchmark/glue_benchmark_model.py @@ -19,8 +19,8 @@ import numpy as np import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig -from pytorch_lightning import Trainer from nemo.collections.common.losses import CrossEntropyLoss, MSELoss from nemo.collections.nlp.data.glue_benchmark.glue_benchmark_dataset import GLUE_TASKS_NUM_LABELS, GLUEDataset diff --git a/nemo/collections/nlp/models/information_retrieval/base_ir_model.py b/nemo/collections/nlp/models/information_retrieval/base_ir_model.py index 67424320d185..91d86fef1851 100644 --- a/nemo/collections/nlp/models/information_retrieval/base_ir_model.py +++ b/nemo/collections/nlp/models/information_retrieval/base_ir_model.py @@ -17,8 +17,8 @@ import numpy as np import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig, OmegaConf -from pytorch_lightning import Trainer from nemo.collections.nlp.data import BertInformationRetrievalDataset from nemo.collections.nlp.models.nlp_model import NLPModel diff --git a/nemo/collections/nlp/models/information_retrieval/bert_dpr_model.py b/nemo/collections/nlp/models/information_retrieval/bert_dpr_model.py index 03b62d91170c..bfbec123d13e 100644 --- a/nemo/collections/nlp/models/information_retrieval/bert_dpr_model.py +++ b/nemo/collections/nlp/models/information_retrieval/bert_dpr_model.py @@ -15,8 +15,8 @@ from typing import Dict, Optional import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig -from pytorch_lightning import Trainer from nemo.collections.common.losses import SmoothedCrossEntropyLoss from nemo.collections.nlp.data import BertInformationRetrievalDataset @@ -63,29 +63,50 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): @typecheck() def forward( - self, q_input_ids, q_token_type_ids, q_attention_mask, p_input_ids, p_token_type_ids, p_attention_mask, + self, + q_input_ids, + q_token_type_ids, + q_attention_mask, + p_input_ids, + p_token_type_ids, + p_attention_mask, ): q_vectors = self.q_encoder( - input_ids=q_input_ids, token_type_ids=q_token_type_ids, attention_mask=q_attention_mask, + input_ids=q_input_ids, + token_type_ids=q_token_type_ids, + attention_mask=q_attention_mask, ) q_vectors = q_vectors[:, 0] batch_size, hidden_size = q_vectors.size() p_vectors = self.p_encoder( - input_ids=p_input_ids, token_type_ids=p_token_type_ids, attention_mask=p_attention_mask, + input_ids=p_input_ids, + token_type_ids=p_token_type_ids, + attention_mask=p_attention_mask, ) num_passages = p_vectors.shape[0] // batch_size p_vectors = p_vectors[:, 0].view(-1, num_passages, hidden_size) p_positives, p_negatives = p_vectors[:, 0], p_vectors[:, 1:] scores = torch.cat( - (torch.matmul(q_vectors, p_positives.T), torch.einsum("ij,ipj->ip", q_vectors, p_negatives),), dim=1, + ( + torch.matmul(q_vectors, p_positives.T), + torch.einsum("ij,ipj->ip", q_vectors, p_negatives), + ), + dim=1, ) return scores def compute_scores_and_loss(self, inputs): - (q_input_ids, q_input_mask, q_input_type_ids, p_input_ids, p_input_mask, p_input_type_ids,) = inputs + ( + q_input_ids, + q_input_mask, + q_input_type_ids, + p_input_ids, + p_input_mask, + p_input_type_ids, + ) = inputs batch_size, num_passages, p_seq_length = p_input_ids.size() q_seq_length = q_input_ids.size()[-1] @@ -100,10 +121,17 @@ def compute_scores_and_loss(self, inputs): normalized_scores = torch.log_softmax(scores, dim=-1) labels = torch.arange(batch_size)[:, None].long().to(normalized_scores.device) - loss = self.loss(log_probs=normalized_scores, labels=labels, output_mask=torch.ones_like(labels),) + loss = self.loss( + log_probs=normalized_scores, + labels=labels, + output_mask=torch.ones_like(labels), + ) scores = scores[:, 0] - scores = torch.cat((torch.diag(scores)[:, None], scores[:, batch_size:]), dim=1,) + scores = torch.cat( + (torch.diag(scores)[:, None], scores[:, batch_size:]), + dim=1, + ) return scores, loss diff --git a/nemo/collections/nlp/models/information_retrieval/bert_joint_ir_model.py b/nemo/collections/nlp/models/information_retrieval/bert_joint_ir_model.py index a4dc4356342a..33885e6b50c6 100644 --- a/nemo/collections/nlp/models/information_retrieval/bert_joint_ir_model.py +++ b/nemo/collections/nlp/models/information_retrieval/bert_joint_ir_model.py @@ -15,8 +15,8 @@ from typing import Dict, Optional import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig -from pytorch_lightning import Trainer from nemo.collections.common.losses import SmoothedCrossEntropyLoss from nemo.collections.nlp.models.information_retrieval.base_ir_model import BaseIRModel @@ -53,7 +53,9 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): self.bert_model = self.get_lm_model_with_padded_embedding(cfg) hidden_size = self.bert_model.config.hidden_size self.sim_score_regressor = SequenceRegression( - hidden_size=hidden_size, num_layers=1, dropout=cfg.language_model.sim_score_dropout, + hidden_size=hidden_size, + num_layers=1, + dropout=cfg.language_model.sim_score_dropout, ) self.loss = SmoothedCrossEntropyLoss(pad_id=self.tokenizer.pad_id) @@ -61,7 +63,9 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): def forward(self, input_ids, attention_mask, token_type_ids): hidden_states = self.bert_model( - input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, + input_ids=input_ids, + token_type_ids=token_type_ids, + attention_mask=attention_mask, ) if isinstance(hidden_states, tuple): hidden_states = hidden_states[0] diff --git a/nemo/collections/nlp/models/information_retrieval/megatron_bert_embedding_model.py b/nemo/collections/nlp/models/information_retrieval/megatron_bert_embedding_model.py index 5e38b61938c9..a5b71d5bcb69 100644 --- a/nemo/collections/nlp/models/information_retrieval/megatron_bert_embedding_model.py +++ b/nemo/collections/nlp/models/information_retrieval/megatron_bert_embedding_model.py @@ -16,13 +16,11 @@ import os import numpy as np - - import torch +from lightning.pytorch.trainer.trainer import Trainer from megatron.core.models.bert.bert_layer_specs import bert_layer_with_transformer_engine_spec from omegaconf import DictConfig, OmegaConf, open_dict from omegaconf.dictconfig import DictConfig -from pytorch_lightning.trainer.trainer import Trainer from torch.distributed import all_gather as all_gather_no_backprop from torch.distributed.nn.functional import all_gather as all_gather_with_backprop @@ -46,7 +44,6 @@ from nemo.collections.nlp.parts.utils_funcs import get_last_rank from nemo.utils import logging - try: from megatron.core import parallel_state from megatron.core.pipeline_parallel.schedules import get_forward_backward_func diff --git a/nemo/collections/nlp/models/information_retrieval/megatron_gpt_embedding_model.py b/nemo/collections/nlp/models/information_retrieval/megatron_gpt_embedding_model.py index c7565f45358e..b5240ec2e170 100644 --- a/nemo/collections/nlp/models/information_retrieval/megatron_gpt_embedding_model.py +++ b/nemo/collections/nlp/models/information_retrieval/megatron_gpt_embedding_model.py @@ -17,8 +17,8 @@ import numpy as np import torch +from lightning.pytorch.trainer.trainer import Trainer from omegaconf import DictConfig, ListConfig -from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.nlp.data.information_retrieval.gpt_embedding_dataset import GPTEmbeddingDataset from nemo.collections.nlp.data.language_modeling.megatron.base_dataset_utils import ( diff --git a/nemo/collections/nlp/models/information_retrieval/megatron_gpt_reranker_model.py b/nemo/collections/nlp/models/information_retrieval/megatron_gpt_reranker_model.py index e316871fe607..fa593adf5c8f 100644 --- a/nemo/collections/nlp/models/information_retrieval/megatron_gpt_reranker_model.py +++ b/nemo/collections/nlp/models/information_retrieval/megatron_gpt_reranker_model.py @@ -17,8 +17,8 @@ import numpy as np import torch +from lightning.pytorch.trainer.trainer import Trainer from omegaconf import DictConfig, ListConfig -from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.nlp.data.information_retrieval.gpt_embedding_dataset import GPTRerankerDataset from nemo.collections.nlp.data.language_modeling.megatron.base_dataset_utils import ( diff --git a/nemo/collections/nlp/models/intent_slot_classification/intent_slot_classification_model.py b/nemo/collections/nlp/models/intent_slot_classification/intent_slot_classification_model.py index 0cd1d07af5dd..a49bc699ab24 100644 --- a/nemo/collections/nlp/models/intent_slot_classification/intent_slot_classification_model.py +++ b/nemo/collections/nlp/models/intent_slot_classification/intent_slot_classification_model.py @@ -17,8 +17,8 @@ from typing import Dict, List, Optional import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig, OmegaConf -from pytorch_lightning import Trainer from torch.utils.data import DataLoader from nemo.collections.common.losses import AggregatorLoss, CrossEntropyLoss @@ -38,8 +38,7 @@ class IntentSlotClassificationModel(NLPModel): def __init__(self, cfg: DictConfig, trainer: Trainer = None): - """ Initializes BERT Joint Intent and Slot model. - """ + """Initializes BERT Joint Intent and Slot model.""" self.max_seq_length = cfg.language_model.max_seq_length # init superclass # Check the presence of data_dir. @@ -75,7 +74,7 @@ def _set_defaults_data_desc(self, cfg): OmegaConf.set_struct(cfg, True) def _set_data_desc_to_cfg(self, cfg, data_dir, train_ds, validation_ds): - """ Method creates IntentSlotDataDesc and copies generated values to cfg.data_desc. """ + """Method creates IntentSlotDataDesc and copies generated values to cfg.data_desc.""" # Save data from data desc to config - so it can be reused later, e.g. in inference. data_desc = IntentSlotDataDesc(data_dir=data_dir, modes=[train_ds.prefix, validation_ds.prefix]) OmegaConf.set_struct(cfg, False) @@ -109,7 +108,7 @@ def _set_data_desc_to_cfg(self, cfg, data_dir, train_ds, validation_ds): OmegaConf.set_struct(cfg, True) def _save_label_ids(self, label_ids: Dict[str, int], filename: str) -> None: - """ Saves label ids map to a file """ + """Saves label ids map to a file""" with open(filename, 'w') as out: labels, _ = zip(*sorted(label_ids.items(), key=lambda x: x[1])) out.write('\n'.join(labels)) @@ -117,7 +116,7 @@ def _save_label_ids(self, label_ids: Dict[str, int], filename: str) -> None: logging.info(f'Labels mapping saved to : {out.name}') def _reconfigure_classifier(self): - """ Method reconfigures the classifier depending on the settings of model cfg.data_desc """ + """Method reconfigures the classifier depending on the settings of model cfg.data_desc""" self.classifier = SequenceTokenClassifier( hidden_size=self.hidden_size, diff --git a/nemo/collections/nlp/models/intent_slot_classification/multi_label_intent_slot_classification_model.py b/nemo/collections/nlp/models/intent_slot_classification/multi_label_intent_slot_classification_model.py index c689b97ab0a5..7a2bec1f2cc0 100644 --- a/nemo/collections/nlp/models/intent_slot_classification/multi_label_intent_slot_classification_model.py +++ b/nemo/collections/nlp/models/intent_slot_classification/multi_label_intent_slot_classification_model.py @@ -18,8 +18,8 @@ import numpy as np import numpy.typing as npt import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig, OmegaConf -from pytorch_lightning import Trainer from sklearn.metrics import f1_score, precision_score, recall_score from torch.utils.data import DataLoader @@ -38,10 +38,10 @@ class MultiLabelIntentSlotClassificationModel(IntentSlotClassificationModel): def __init__(self, cfg: DictConfig, trainer: Trainer = None): - """ + """ Initializes BERT Joint Intent and Slot model. - Args: + Args: cfg: configuration object trainer: trainer for Pytorch Lightning """ @@ -69,12 +69,12 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): def _set_data_desc_to_cfg( self, cfg: DictConfig, data_dir: str, train_ds: DictConfig, validation_ds: DictConfig ) -> None: - """ - Creates MultiLabelIntentSlotDataDesc and copies generated values to Configuration object's data descriptor. - - Args: + """ + Creates MultiLabelIntentSlotDataDesc and copies generated values to Configuration object's data descriptor. + + Args: cfg: configuration object - data_dir: data directory + data_dir: data directory train_ds: training dataset file name validation_ds: validation dataset file name @@ -101,7 +101,10 @@ def _set_data_desc_to_cfg( if not hasattr(cfg, "class_labels") or cfg.class_labels is None: cfg.class_labels = {} cfg.class_labels = OmegaConf.create( - {"intent_labels_file": "intent_labels.csv", "slot_labels_file": "slot_labels.csv",} + { + "intent_labels_file": "intent_labels.csv", + "slot_labels_file": "slot_labels.csv", + } ) slot_labels_file = os.path.join(data_dir, cfg.class_labels.slot_labels_file) @@ -114,7 +117,7 @@ def _set_data_desc_to_cfg( OmegaConf.set_struct(cfg, True) def _reconfigure_classifier(self) -> None: - """ Method reconfigures the classifier depending on the settings of model cfg.data_desc """ + """Method reconfigures the classifier depending on the settings of model cfg.data_desc""" self.classifier = SequenceTokenClassifier( hidden_size=self.bert_model.config.hidden_size, @@ -135,7 +138,8 @@ def _reconfigure_classifier(self) -> None: self.slot_loss = CrossEntropyLoss(logits_ndim=3) self.total_loss = AggregatorLoss( - num_inputs=2, weights=[self.cfg.intent_loss_weight, 1.0 - self.cfg.intent_loss_weight], + num_inputs=2, + weights=[self.cfg.intent_loss_weight, 1.0 - self.cfg.intent_loss_weight], ) # setup to track metrics @@ -161,12 +165,22 @@ def validation_step(self, batch, batch_idx) -> None: batch: batches of data from DataLoader batch_idx: batch idx from DataLoader - Returns: + Returns: None """ - (input_ids, input_type_ids, input_mask, loss_mask, subtokens_mask, intent_labels, slot_labels,) = batch + ( + input_ids, + input_type_ids, + input_mask, + loss_mask, + subtokens_mask, + intent_labels, + slot_labels, + ) = batch intent_logits, slot_logits = self( - input_ids=input_ids, token_type_ids=input_type_ids, attention_mask=input_mask, + input_ids=input_ids, + token_type_ids=input_type_ids, + attention_mask=input_mask, ) # calculate combined loss for intents and slots @@ -201,7 +215,7 @@ def _setup_dataloader_from_config(self, cfg: DictConfig) -> DataLoader: Args: cfg: configuration object - + Returns: DataLoader for model's data """ @@ -289,8 +303,8 @@ def prediction_probabilities(self, queries: List[str], test_ds: DictConfig) -> n def optimize_threshold(self, test_ds: DictConfig, file_name: str) -> None: """ - Set the optimal threshold of the model from performance on validation set. This threshold is used to round the - logits to 0 or 1. + Set the optimal threshold of the model from performance on validation set. This threshold is used to round the + logits to 0 or 1. Args: test_ds: location of test dataset @@ -361,16 +375,16 @@ def predict_from_examples( queries: text sequences test_ds: Dataset configuration section. threshold: Threshold for rounding prediction logits - + Returns: predicted_intents: model intent predictions with their probabilities - Example: [[('flight', 0.84)], [('airfare', 0.54), + Example: [[('flight', 0.84)], [('airfare', 0.54), ('flight', 0.73), ('meal', 0.24)]] predicted_slots: model slot predictions Example: ['O B-depart_date.month_name B-depart_date.day_number', 'O O B-flight_stop O O O'] - predicted_vector: model intent predictions for each individual query. Binary values within each list + predicted_vector: model intent predictions for each individual query. Binary values within each list indicate whether a class is prediced for the given query (1 for True, 0 for False) Example: [[1,0,0,0,0,0], [0,0,1,0,0,0]] """ diff --git a/nemo/collections/nlp/models/language_modeling/bert_lm_model.py b/nemo/collections/nlp/models/language_modeling/bert_lm_model.py index 6b03d86982b0..dc7103b67aa6 100644 --- a/nemo/collections/nlp/models/language_modeling/bert_lm_model.py +++ b/nemo/collections/nlp/models/language_modeling/bert_lm_model.py @@ -16,8 +16,8 @@ from typing import Dict, Optional import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig, OmegaConf -from pytorch_lightning import Trainer from torch.utils.data import DataLoader from nemo.collections.common.losses import AggregatorLoss, CrossEntropyLoss, SmoothedCrossEntropyLoss @@ -75,7 +75,11 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): config_file = self.register_artifact('language_model.config_file', cfg.language_model.config_file) self.bert_model = get_lm_model( - config_file=config_file, config_dict=config_dict, vocab_file=vocab_file, trainer=trainer, cfg=cfg, + config_file=config_file, + config_dict=config_dict, + vocab_file=vocab_file, + trainer=trainer, + cfg=cfg, ) self.hidden_size = self.bert_model.config.hidden_size @@ -127,7 +131,9 @@ def forward(self, input_ids, attention_mask, token_type_ids): in the `nn.Module` in vanilla PyTorch. """ hidden_states = self.bert_model( - input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, + input_ids=input_ids, + token_type_ids=token_type_ids, + attention_mask=attention_mask, ) if isinstance(hidden_states, tuple): hidden_states = hidden_states[0] @@ -225,7 +231,9 @@ def _setup_preprocessed_dataloader(self, cfg: Optional[DictConfig]): files = [dataset] files.sort() dl = BertPretrainingPreprocessedDataloader( - data_files=files, max_predictions_per_seq=max_predictions_per_seq, batch_size=batch_size, + data_files=files, + max_predictions_per_seq=max_predictions_per_seq, + batch_size=batch_size, ) return dl diff --git a/nemo/collections/nlp/models/language_modeling/megatron/bert/bert_model.py b/nemo/collections/nlp/models/language_modeling/megatron/bert/bert_model.py index 0d75ab7cc706..c629db5af3c3 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron/bert/bert_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron/bert/bert_model.py @@ -208,6 +208,7 @@ def forward( rotary_pos_emb=None, rotary_pos_cos=None, rotary_pos_sin=None, + attention_bias=None, inference_params=None, packed_seq_params=None, ): diff --git a/nemo/collections/nlp/models/language_modeling/megatron/falcon/falcon_decoder_layer.py b/nemo/collections/nlp/models/language_modeling/megatron/falcon/falcon_decoder_layer.py index 131f154d6709..7c3f3c194f14 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron/falcon/falcon_decoder_layer.py +++ b/nemo/collections/nlp/models/language_modeling/megatron/falcon/falcon_decoder_layer.py @@ -108,6 +108,7 @@ def forward( rotary_pos_emb=None, rotary_pos_cos=None, rotary_pos_sin=None, + attention_bias=None, inference_params=None, packed_seq_params=None, ): diff --git a/nemo/collections/nlp/models/language_modeling/megatron/gemma2/gemma2_modules.py b/nemo/collections/nlp/models/language_modeling/megatron/gemma2/gemma2_modules.py index 5113ee745895..9ea1b4afe318 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron/gemma2/gemma2_modules.py +++ b/nemo/collections/nlp/models/language_modeling/megatron/gemma2/gemma2_modules.py @@ -49,7 +49,8 @@ class Gemma2DotProductAttention(MegatronModule): Region where selective activation recomputation is applied. This region is memory intensive but less compute intensive which makes activation checkpointing more efficient for LLMs (20B+). - See Reducing Activation Recomputation in Large Transformer Models: https://arxiv.org/abs/2205.05198 for more details. + See Reducing Activation Recomputation in Large Transformer Models: + https://arxiv.org/abs/2205.05198 for more details. We use the following notation: h: hidden size @@ -126,7 +127,12 @@ def forward( attention_mask: Tensor, attn_mask_type: AttnMaskType = None, packed_seq_params: PackedSeqParams = None, + **kwargs, ): + """Forward. + Modified from mcore.transformer.dot_product_attention to support Gemma2-specific + final_logit_softcapping. + """ assert packed_seq_params is None, ( "Packed sequence is not supported by DotProductAttention." "Please use TEDotProductAttention instead." ) @@ -243,6 +249,8 @@ def forward( class TERowParallelLinearLayerNorm(TERowParallelLinear): + """Modified From TERowParallelLinear with an additional Post-LN.""" + def __init__( self, input_size: int, @@ -270,12 +278,16 @@ def __init__( self.post_layernorm = TENorm(config, output_size) def forward(self, x): + """Forward with additional Post LN on output""" output, bias = super().forward(x) return self.post_layernorm(output), bias class Gemma2OutputLayer(ColumnParallelLinear): + """Extends from ColumnParallelLinear with logit soft capping.""" + def forward(self, *args, **kwargs): + """Forward with logit soft capping.""" output, bias = super().forward(*args, **kwargs) output = logit_softcapping(output, self.config.final_logit_softcapping) return output, bias diff --git a/nemo/collections/nlp/models/language_modeling/megatron/gpt_full_te_layer_autocast_spec.py b/nemo/collections/nlp/models/language_modeling/megatron/gpt_full_te_layer_autocast_spec.py index d1945139dee9..1def214113ee 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron/gpt_full_te_layer_autocast_spec.py +++ b/nemo/collections/nlp/models/language_modeling/megatron/gpt_full_te_layer_autocast_spec.py @@ -252,6 +252,7 @@ def forward( rotary_pos_emb=None, rotary_pos_cos=None, rotary_pos_sin=None, + attention_bias=None, inference_params=None, packed_seq_params=None, # TODO: handle this ): diff --git a/nemo/collections/nlp/models/language_modeling/megatron_bart_model.py b/nemo/collections/nlp/models/language_modeling/megatron_bart_model.py index 1c768829e3e2..4a53edacb566 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_bart_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_bart_model.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from lightning.pytorch.trainer.trainer import Trainer from omegaconf.dictconfig import DictConfig -from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.nlp.models.language_modeling.megatron_t5_model import MegatronT5Model @@ -48,7 +48,9 @@ def _validate_cfg(self): @property def _build_train_valid_test_datasets_kwargs(self): """allows child classes to add kwargs to dataset building""" - return dict(delete_mask_prob=self._cfg.data.get('delete_mask_prob', 0.0),) + return dict( + delete_mask_prob=self._cfg.data.get('delete_mask_prob', 0.0), + ) def list_available_models(self): pass diff --git a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py index d2a21e50e486..37ec8a82cef1 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_base_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_base_model.py @@ -23,12 +23,12 @@ import omegaconf import torch import torch.nn as nn +from lightning.pytorch.plugins.precision import MixedPrecisionPlugin +from lightning.pytorch.trainer.connectors.logger_connector.fx_validator import _FxValidator +from lightning.pytorch.trainer.trainer import Trainer +from lightning.pytorch.utilities.exceptions import MisconfigurationException from omegaconf import OmegaConf, open_dict from omegaconf.dictconfig import DictConfig -from pytorch_lightning.plugins.precision import MixedPrecisionPlugin -from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import _FxValidator -from pytorch_lightning.trainer.trainer import Trainer -from pytorch_lightning.utilities.exceptions import MisconfigurationException from nemo.collections.nlp.models.nlp_model import NLPModel from nemo.collections.nlp.modules.common.megatron.attention import HAVE_FLASH_ATTENTION @@ -200,7 +200,8 @@ def __init__(self, cfg: DictConfig, trainer: Trainer, no_lm_init=True): global_batch_size=cfg.get('global_batch_size'), rampup_batch_size=cfg.get('rampup_batch_size', None), use_fp8=cfg.get('fp8', False), - init_mpi_proc_group=cfg.get('ub_tp_comm_overlap', False), + init_mpi_proc_group=cfg.get('ub_tp_comm_overlap', False) + and cfg.get('ub_tp_comm_bootstrap_backend', 'nccl') == 'mpi', seed=self.cfg.get('seed', 1234), apex_transformer_log_level=self.cfg.get('apex_transformer_log_level', 30), use_te_rng_tracker=self.cfg.get('use_te_rng_tracker', False), @@ -1173,6 +1174,7 @@ def build_model_parallel_config(self) -> ModelParallelConfig: "grad_sync_func": None, # set dynamically during training "param_sync_func": None, # set dynamically during training "tp_comm_overlap": self.cfg.get('ub_tp_comm_overlap', False), + "tp_comm_bootstrap_backend": self.cfg.get('ub_tp_comm_bootstrap_backend', 'nccl'), } # instantitate ModelParallelConfig from this dict diff --git a/nemo/collections/nlp/models/language_modeling/megatron_base_prompt_learning_model.py b/nemo/collections/nlp/models/language_modeling/megatron_base_prompt_learning_model.py index 2a356012c728..b00b6fcf0302 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_base_prompt_learning_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_base_prompt_learning_model.py @@ -18,9 +18,9 @@ from typing import Any, Optional import torch +from lightning.pytorch.trainer.trainer import Trainer from omegaconf.dictconfig import DictConfig from omegaconf.omegaconf import open_dict -from pytorch_lightning.trainer.trainer import Trainer from torch import Tensor from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer diff --git a/nemo/collections/nlp/models/language_modeling/megatron_bert_model.py b/nemo/collections/nlp/models/language_modeling/megatron_bert_model.py index 0eb5ea1c0048..e6945d1ada56 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_bert_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_bert_model.py @@ -18,8 +18,8 @@ import torch import torch.nn.functional as F +from lightning.pytorch.trainer.trainer import Trainer from omegaconf.dictconfig import DictConfig -from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.nlp.data.language_modeling.megatron import dataset_utils from nemo.collections.nlp.data.language_modeling.megatron.data_samplers import ( diff --git a/nemo/collections/nlp/models/language_modeling/megatron_glue_model.py b/nemo/collections/nlp/models/language_modeling/megatron_glue_model.py index c0a4b6351530..d3829c3e8de1 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_glue_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_glue_model.py @@ -11,8 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from lightning.pytorch.trainer.trainer import Trainer from omegaconf.dictconfig import DictConfig -from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.nlp.data.glue_benchmark.glue_benchmark_dataset import ( TextToTextGLUEDataset, diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_adapter_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_adapter_model.py index c6b4d055ef6e..44860c3178f6 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_adapter_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_adapter_model.py @@ -21,8 +21,8 @@ import os import torch +from lightning.pytorch.trainer.trainer import Trainer from omegaconf.dictconfig import DictConfig -from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.common.parts.adapter_modules import LinearAdapterConfig from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel @@ -162,7 +162,7 @@ def state_dict(self, destination=None, prefix=None, keep_vars=False): def load_state_dict(self, state_dict, strict: bool = True): """ - Loads a state_dict expecting the state_dict to contain key,values + Loads a state_dict expecting the state_dict to contain key,values only for the adapter parameters. """ for name, module in self.frozen_model.named_modules(): @@ -176,13 +176,13 @@ def load_state_dict(self, state_dict, strict: bool = True): def setup_optimizer_param_groups(self): """ - ModelPT override. Optimizer will get self._optimizer_param_groups. + ModelPT override. Optimizer will get self._optimizer_param_groups. Makes two optimizer param groups, one for the frozen model params - and one for the prompt-table/prompt-encoder params. The learning + and one for the prompt-table/prompt-encoder params. The learning rate for the frozen model's params will always be zero effectively freezing the model's params but still allowing for the needed gradients - to be passed around in pipeline parallel models. The prompt-encoder - and/or prompt table will use the learning rate set by the user. + to be passed around in pipeline parallel models. The prompt-encoder + and/or prompt table will use the learning rate set by the user. """ self.frozen_model.freeze() # Freeze the entire model opt_params = [] @@ -246,8 +246,8 @@ class MegatronGPTAdapterLearningModel(MegatronGPTBaseAdapterModel): Two adapter's are inserted into each Transformer layer in the base GPT Model. It is assumed that these set of adapters will then be trained for a specific task. - Once trained, the adapter weights will be saved and can be re-loaded - and infused into the same GPT Model for inference. + Once trained, the adapter weights will be saved and can be re-loaded + and infused into the same GPT Model for inference. """ def __init__(self, cfg: DictConfig, trainer: Trainer): @@ -295,7 +295,8 @@ def __init__(self, cfg: DictConfig, trainer: Trainer): for adapter_key in self.adapter_name_keys: if model_utils.import_class_by_path(adapter_cfg._target_) in module.get_accepted_adapter_types(): module.add_adapter( - name=adapter_key, cfg=adapter_cfg, + name=adapter_key, + cfg=adapter_cfg, ) logging.info(f'After adding adapters:\n{self.frozen_model.summarize()}') @@ -313,8 +314,8 @@ class MegatronGPTInfusedAdapterModel(MegatronGPTBaseAdapterModel): Three adapter's are inserted into each Transformer layer in the base GPT Model. Each adapter is basically a vector that simply scales the key, value or ffn hidden representations. It is assumed that these set of adapters will then be trained for a specific task. - Once trained, the adapter weights will be saved and can be re-loaded - and infused into the same GPT Model for inference. + Once trained, the adapter weights will be saved and can be re-loaded + and infused into the same GPT Model for inference. """ def __init__(self, cfg: DictConfig, trainer: Trainer): diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index 8f541e5703e6..a4b8242e0185 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -24,11 +24,11 @@ import packaging import torch +from lightning.pytorch.accelerators import CPUAccelerator +from lightning.pytorch.loops.fetchers import _DataFetcherWrapper +from lightning.pytorch.trainer.trainer import Trainer from omegaconf import OmegaConf from omegaconf.dictconfig import DictConfig -from pytorch_lightning.accelerators import CPUAccelerator -from pytorch_lightning.loops.fetchers import _DataFetcherWrapper -from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.common.parts.utils import apply_rope_scaling, extend_instance from nemo.collections.nlp.data.language_modeling.megatron.data_samplers import ( @@ -803,6 +803,7 @@ def initialize_ub_func(self): tp_size=self.cfg.get('tensor_model_parallel_size'), use_fp8=self.cfg.get('fp8'), ub_cfgs=ub_cfgs, + bootstrap_backend=self.cfg.get('ub_tp_comm_bootstrap_backend', 'nccl'), ) self.initialize_ub = False @@ -1230,22 +1231,23 @@ def get_batch_on_this_context_parallel_rank(self, batch): cp_size = parallel_state.get_context_parallel_world_size() if cp_size > 1: cp_rank = parallel_state.get_context_parallel_rank() - for key, val in batch.items(): - if val is not None and key != "context_lengths": - seq_dim = 1 if key != 'attention_mask' else 2 - val = val.view( - *val.shape[0:seq_dim], - 2 * cp_size, - val.shape[seq_dim] // (2 * cp_size), - *val.shape[(seq_dim + 1) :], - ) - index = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], device="cpu", pin_memory=True).cuda( - non_blocking=True - ) - val = val.index_select(seq_dim, index) - val = val.view(*val.shape[0:seq_dim], -1, *val.shape[(seq_dim + 2) :]) - batch[key] = val - + # check if the batch is not in THD format + if 'cu_seqlens' not in batch: + for key, val in batch.items(): + if val is not None and key != "context_lengths": + seq_dim = 1 if key != 'attention_mask' else 2 + val = val.view( + *val.shape[0:seq_dim], + 2 * cp_size, + val.shape[seq_dim] // (2 * cp_size), + *val.shape[(seq_dim + 1) :], + ) + index = torch.tensor( + [cp_rank, (2 * cp_size - cp_rank - 1)], device="cpu", pin_memory=True + ).cuda(non_blocking=True) + val = val.index_select(seq_dim, index) + val = val.view(*val.shape[0:seq_dim], -1, *val.shape[(seq_dim + 2) :]) + batch[key] = val batch['num_valid_tokens_in_ub'] = num_valid_tokens_in_ub return batch @@ -1260,12 +1262,17 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_ required_keys = set() max_seqlen = batch['max_seqlen'].squeeze() if 'max_seqlen' in batch else None cu_seqlens_argmin = batch['cu_seqlens_argmin'] if 'cu_seqlens_argmin' in batch else None + cu_seqlens_unpadded_argmin = ( + batch['cu_seqlens_unpadded_argmin'] if 'cu_seqlens_unpadded_argmin' in batch else None + ) if parallel_state.get_pipeline_model_parallel_world_size() == 1: required_keys.update(batch.keys()) else: required_keys.add('attention_mask') if 'cu_seqlens' in batch: required_keys.add('cu_seqlens') + if 'cu_seqlens_unpadded' in batch: + required_keys.add('cu_seqlens_unpadded') if parallel_state.is_pipeline_first_stage(): required_keys.update(('tokens', 'position_ids')) if parallel_state.is_pipeline_last_stage(): @@ -1300,12 +1307,16 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_ if 'cu_seqlens' in batch: # packed sequence from GPTSFTPackedDataset # these args are passed eventually into TEDotProductAttention.forward() cu_seqlens = batch['cu_seqlens'].squeeze() # remove batch size dimension (mbs=1) + cu_seqlens_unpadded = batch['cu_seqlens_unpadded'].squeeze() # remove -1 "paddings" added in collate_fn if cu_seqlens_argmin is not None: cu_seqlens = cu_seqlens[: cu_seqlens_argmin.item()] else: cu_seqlens = cu_seqlens[: torch.argmin(cu_seqlens)] - + if cu_seqlens_unpadded_argmin is not None: + cu_seqlens_unpadded = cu_seqlens_unpadded[: cu_seqlens_unpadded_argmin.item()] + else: + cu_seqlens_unpadded = cu_seqlens_unpadded[: torch.argmin(cu_seqlens_unpadded)] try: from megatron.core.packed_seq_params import PackedSeqParams except (ImportError, ModuleNotFoundError) as e: @@ -1316,9 +1327,42 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_ ) raise e + # get packed sequences for this context parallel rank + cp_size = parallel_state.get_context_parallel_world_size() + if cp_size > 1: + try: + import transformer_engine_torch as tex + except ModuleNotFoundError as e: + logging.error( + "Please update Transformer Engine to >= 1.10 to use Context Parallel with THD format data" + ) + raise e + cp_rank = parallel_state.get_context_parallel_rank() + for key in required_keys: + val = batch[key] + if key not in { + "cu_seqlens", + "cu_seqlens_unpadded", + "cu_seqlens_argmin", + "cu_seqlens_unpadded_argmin", + "max_seqlen", + "token_count", + }: + index = tex.thd_get_partitioned_indices(cu_seqlens, val.size(1), cp_size, cp_rank) + val = val.index_select(1, index) + batch[key] = val + forward_args = { + 'input_ids': batch['tokens'], + 'position_ids': batch['position_ids'], + 'attention_mask': None if self.get_attention_mask_from_fusion else batch['attention_mask'], + 'labels': batch['labels'] if 'labels' in batch else None, + } + forward_args['packed_seq_params'] = PackedSeqParams( - cu_seqlens_q=cu_seqlens, - cu_seqlens_kv=cu_seqlens, + cu_seqlens_q=cu_seqlens_unpadded, + cu_seqlens_kv=cu_seqlens_unpadded, + cu_seqlens_q_padded=cu_seqlens, + cu_seqlens_kv_padded=cu_seqlens, max_seqlen_q=max_seqlen, max_seqlen_kv=max_seqlen, qkv_format='thd', diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_prompt_learning_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_prompt_learning_model.py index 78f671142c1b..7d39459ae654 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_prompt_learning_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_prompt_learning_model.py @@ -18,9 +18,9 @@ from typing import Any, List, Optional, Union import torch +from lightning.pytorch.trainer.trainer import Trainer from omegaconf.dictconfig import DictConfig from omegaconf.omegaconf import open_dict -from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer from nemo.collections.nlp.data.language_modeling.megatron.gpt_prompt_learning_dataset import GPTPromptLearningDataset diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py index 08bc5501363c..2d3f43b2f2a8 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py @@ -17,9 +17,9 @@ from typing import Any, Optional import torch +from lightning.pytorch.loops.fetchers import _DataFetcherWrapper +from lightning.pytorch.trainer.trainer import Trainer from omegaconf import DictConfig, ListConfig -from pytorch_lightning.loops.fetchers import _DataFetcherWrapper -from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.common.metrics import MetricStringToTorchMetric from nemo.collections.nlp.data.language_modeling.megatron.base_dataset_utils import ( diff --git a/nemo/collections/nlp/models/language_modeling/megatron_griffin_model.py b/nemo/collections/nlp/models/language_modeling/megatron_griffin_model.py index 1e5a2f0c15c0..40e147b90903 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_griffin_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_griffin_model.py @@ -13,8 +13,8 @@ # limitations under the License. import torch +from lightning.pytorch.trainer.trainer import Trainer from omegaconf.dictconfig import DictConfig -from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.nlp.models.language_modeling.megatron.griffin.griffin_model import GriffinModel from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel diff --git a/nemo/collections/nlp/models/language_modeling/megatron_griffin_sft_model.py b/nemo/collections/nlp/models/language_modeling/megatron_griffin_sft_model.py index c53d231b2719..584a4b0572f7 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_griffin_sft_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_griffin_sft_model.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from lightning.pytorch.trainer.trainer import Trainer from omegaconf import DictConfig from omegaconf.dictconfig import DictConfig -from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.nlp.models.language_modeling.megatron_base_model import MegatronBaseModel from nemo.collections.nlp.models.language_modeling.megatron_gpt_sft_model import MegatronGPTSFTModel @@ -48,8 +48,8 @@ def _reset_activation_checkpointing_args(self): def on_validation_model_zero_grad(self) -> None: """ - Skip gradient zeroing at the beginning of validation routine. - This is needed when overlapping the AllGather of the updated parameters with the following valdation step. - """ + Skip gradient zeroing at the beginning of validation routine. + This is needed when overlapping the AllGather of the updated parameters with the following valdation step. + """ if not self.validation_param_sync_overlap: MegatronBaseModel.on_validation_model_zero_grad(self) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py b/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py index 7b92b9e25d69..e530a40d8aaa 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_lm_encoder_decoder_model.py @@ -18,11 +18,11 @@ from typing import Any, Dict, List, Optional import torch +from lightning.pytorch.accelerators import CPUAccelerator +from lightning.pytorch.loops.fetchers import _DataFetcherWrapper +from lightning.pytorch.trainer.trainer import Trainer from omegaconf import OmegaConf, open_dict from omegaconf.dictconfig import DictConfig -from pytorch_lightning.accelerators import CPUAccelerator -from pytorch_lightning.loops.fetchers import _DataFetcherWrapper -from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.nlp.data.language_modeling.megatron.data_samplers import ( MegatronPretrainingRandomSampler, @@ -32,12 +32,10 @@ from nemo.collections.nlp.modules.common.megatron.build_model import build_model from nemo.collections.nlp.modules.common.megatron.module import Float16Module from nemo.collections.nlp.modules.common.megatron.token_level_encoder_decoder import ( - AttnMaskType, MegatronTokenLevelEncoderDecoderModule, ) from nemo.collections.nlp.modules.common.megatron.utils import ( average_losses_across_data_parallel_group, - build_attention_mask_3d, get_params_for_weight_decay_optimization, ) from nemo.collections.nlp.modules.common.text_generation_utils import ( @@ -683,14 +681,13 @@ def fwd_output_and_loss_func(dataloader_iter, model): if self.mcore_t5: # attn mask logic follows megatron.data.t5_dataset.py in Megatron-LM - encoder_attn_mask_3d = build_attention_mask_3d( - encoder_attn_mask, encoder_attn_mask, AttnMaskType.padding - ) - decoder_attn_mask_3d = build_attention_mask_3d( - decoder_attn_mask, decoder_attn_mask, AttnMaskType.causal - ) - enc_dec_attn_mask_3d = build_attention_mask_3d( - decoder_attn_mask, encoder_attn_mask, AttnMaskType.padding + encoder_attn_mask = encoder_attn_mask < 0.5 + decoder_attn_mask = decoder_attn_mask < 0.5 + encoder_attn_mask_3d = encoder_attn_mask.unsqueeze(1).unsqueeze(1) + decoder_attn_mask_3d = decoder_attn_mask.unsqueeze(1).unsqueeze(1) + enc_dec_attn_mask_3d = ( + decoder_attn_mask_3d, + encoder_attn_mask_3d, ) output = model( # model is MCoreT5Model @@ -816,10 +813,8 @@ def fwd_output_only_func(dataloader_iter, model): encoder_attn_mask, ) = batch - # attn mask logic follows megatron.data.t5_dataset.py in Megatron-LM - encoder_attn_mask_3d = build_attention_mask_3d( - encoder_attn_mask, encoder_attn_mask, AttnMaskType.padding - ) + encoder_attn_mask = encoder_attn_mask < 0.5 + encoder_attn_mask_3d = encoder_attn_mask.unsqueeze(1).unsqueeze(1) output = model( encoder_input_ids=encoder_input_ids, @@ -841,15 +836,13 @@ def fwd_output_only_func(dataloader_iter, model): decoder_attn_mask, ) = batch - # attn mask logic follows megatron.data.t5_dataset.py in Megatron-LM - encoder_attn_mask_3d = build_attention_mask_3d( - encoder_attn_mask, encoder_attn_mask, AttnMaskType.padding - ) - decoder_attn_mask_3d = build_attention_mask_3d( - decoder_attn_mask, decoder_attn_mask, AttnMaskType.causal - ) - enc_dec_attn_mask_3d = build_attention_mask_3d( - decoder_attn_mask, encoder_attn_mask, AttnMaskType.padding + encoder_attn_mask = encoder_attn_mask < 0.5 + decoder_attn_mask = decoder_attn_mask < 0.5 + encoder_attn_mask_3d = encoder_attn_mask.unsqueeze(1).unsqueeze(1) + decoder_attn_mask_3d = decoder_attn_mask.unsqueeze(1).unsqueeze(1) + enc_dec_attn_mask_3d = ( + decoder_attn_mask_3d, + encoder_attn_mask_3d, ) # re-transpose encoder_hidden_states from [batch, seq_len, hidden] to [seq_len, batch, hidden] diff --git a/nemo/collections/nlp/models/language_modeling/megatron_mamba_model.py b/nemo/collections/nlp/models/language_modeling/megatron_mamba_model.py index 4f0000dafaa2..ad92421ee607 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_mamba_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_mamba_model.py @@ -13,8 +13,8 @@ # limitations under the License. import torch +from lightning.pytorch.trainer.trainer import Trainer from omegaconf.dictconfig import DictConfig -from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel from nemo.utils import logging diff --git a/nemo/collections/nlp/models/language_modeling/megatron_mamba_sft_model.py b/nemo/collections/nlp/models/language_modeling/megatron_mamba_sft_model.py index ebcc47004711..cacdb1c190e7 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_mamba_sft_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_mamba_sft_model.py @@ -12,15 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +from lightning.pytorch.trainer.trainer import Trainer from omegaconf import DictConfig from omegaconf.dictconfig import DictConfig -from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.nlp.models.language_modeling.megatron_base_model import MegatronBaseModel from nemo.collections.nlp.models.language_modeling.megatron_gpt_sft_model import MegatronGPTSFTModel from nemo.collections.nlp.models.language_modeling.megatron_mamba_model import MegatronMambaModel - __all__ = ['MegatronMambaSFTModel'] diff --git a/nemo/collections/nlp/models/language_modeling/megatron_retrieval_model.py b/nemo/collections/nlp/models/language_modeling/megatron_retrieval_model.py index 42323e503f7d..147c832f4b9a 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_retrieval_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_retrieval_model.py @@ -16,8 +16,8 @@ from typing import Any, List, Optional, Union import torch +from lightning.pytorch.trainer.trainer import Trainer from omegaconf import DictConfig -from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.nlp.data.language_modeling.megatron.data_samplers import ( MegatronPretrainingRandomSampler, @@ -294,7 +294,10 @@ def training_step(self, batch, batch_idx): self.log('lr', lr, batch_size=1) self.log('global_step', self.trainer.global_step, prog_bar=True, batch_size=1) self.log( - 'consumed_samples', self._compute_consumed_samples_after_training_step(), prog_bar=True, batch_size=1, + 'consumed_samples', + self._compute_consumed_samples_after_training_step(), + prog_bar=True, + batch_size=1, ) self._reduced_loss_buffer = [] return lm_loss @@ -427,7 +430,10 @@ def build_pretraining_data_loader(self, dataset, consumed_samples): # Torch dataloader. return torch.utils.data.DataLoader( - dataset, batch_sampler=batch_sampler, num_workers=self.cfg.data.num_workers, pin_memory=True, + dataset, + batch_sampler=batch_sampler, + num_workers=self.cfg.data.num_workers, + pin_memory=True, ) def setup(self, stage=None): diff --git a/nemo/collections/nlp/models/language_modeling/megatron_retro_fine_tune_model.py b/nemo/collections/nlp/models/language_modeling/megatron_retro_fine_tune_model.py index 1eaec4238648..924da5825024 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_retro_fine_tune_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_retro_fine_tune_model.py @@ -15,8 +15,8 @@ from functools import partial import torch +from lightning.pytorch.trainer.trainer import Trainer from omegaconf import DictConfig, ListConfig -from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.common.data import ConcatMapDataset from nemo.collections.common.metrics import MetricStringToTorchMetric @@ -50,11 +50,13 @@ def build_all_datasets( - cfg, tokenizer, train_valid_test_num_samples, + cfg, + tokenizer, + train_valid_test_num_samples, ): """Build train, valid, and test RETRO datasets. - There is one to one mapping between data_prefix and knn_map_path. - Currently only supports one retrieval dataset. + There is one to one mapping between data_prefix and knn_map_path. + Currently only supports one retrieval dataset. """ train_dataset = RetroQAFineTuneDataset( cfg.train_ds.get('file_name'), @@ -97,7 +99,7 @@ def build_all_datasets( class MegatronRetroFinetuneModel(MegatronRetrievalModel): - """Finetune RETRO Model """ + """Finetune RETRO Model""" def build_train_valid_test_datasets(self): logging.info('Building RETRO datasets.') @@ -114,7 +116,9 @@ def build_train_valid_test_datasets(self): ] self._train_ds, self._validation_ds, self._test_ds = build_all_datasets( - cfg=self.cfg.data, tokenizer=self.tokenizer, train_valid_test_num_samples=train_valid_test_num_samples, + cfg=self.cfg.data, + tokenizer=self.tokenizer, + train_valid_test_num_samples=train_valid_test_num_samples, ) if self._train_ds is not None: logging.info(f'Length of train dataset: {len(self._train_ds)}') @@ -143,5 +147,9 @@ def build_pretraining_data_loader(self, dataset, consumed_samples): drop_last=True, ) return torch.utils.data.DataLoader( - dataset, batch_sampler=batch_sampler, collate_fn=collate_fn, num_workers=0, pin_memory=True, + dataset, + batch_sampler=batch_sampler, + collate_fn=collate_fn, + num_workers=0, + pin_memory=True, ) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_retro_model.py b/nemo/collections/nlp/models/language_modeling/megatron_retro_model.py index a6bf75fb9444..493d512fd30e 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_retro_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_retro_model.py @@ -23,10 +23,10 @@ from typing import Any, Dict, Iterator, List, Optional, Union import torch +from lightning.pytorch.accelerators import CPUAccelerator +from lightning.pytorch.trainer.trainer import Trainer from omegaconf import OmegaConf, open_dict from omegaconf.dictconfig import DictConfig -from pytorch_lightning.accelerators import CPUAccelerator -from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.nlp.data.language_modeling.megatron.data_samplers import ( MegatronPretrainingRandomSampler, diff --git a/nemo/collections/nlp/models/language_modeling/megatron_t0_model.py b/nemo/collections/nlp/models/language_modeling/megatron_t0_model.py index cee1b11a160b..92827b31a259 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_t0_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_t0_model.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. import torch +from lightning.pytorch.trainer.trainer import Trainer from omegaconf import DictConfig, ListConfig -from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.nlp.data.language_modeling.megatron.base_dataset_utils import ( get_datasets_weights_and_num_samples, diff --git a/nemo/collections/nlp/models/language_modeling/megatron_t5_adapter_model.py b/nemo/collections/nlp/models/language_modeling/megatron_t5_adapter_model.py index 31eb4519ded2..a6e6afc8b7eb 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_t5_adapter_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_t5_adapter_model.py @@ -21,9 +21,9 @@ from typing import Any import torch +from lightning.pytorch.trainer.trainer import Trainer from omegaconf.dictconfig import DictConfig from omegaconf.omegaconf import open_dict -from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.common.parts.adapter_modules import LinearAdapterConfig from nemo.collections.nlp.models.language_modeling.megatron_t5_model import MegatronT5Model @@ -60,7 +60,15 @@ def __init__(self, cfg: DictConfig, trainer: Trainer): self.adapter_name_keys = [] def forward( - self, input_ids, dec_input, enc_mask, dec_mask, position_ids, taskname_ids, labels=None, inference=False, + self, + input_ids, + dec_input, + enc_mask, + dec_mask, + position_ids, + taskname_ids, + labels=None, + inference=False, ): # Call forward on T5 model with preprocessed embeddings if self.autocast_dtype == torch.float32: @@ -195,13 +203,13 @@ def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> A def setup_optimizer_param_groups(self): """ - ModelPT override. Optimizer will get self._optimizer_param_groups. + ModelPT override. Optimizer will get self._optimizer_param_groups. Makes two optimizer param groups, one for the frozen model params - and one for the prompt-table/prompt-encoder params. The learning + and one for the prompt-table/prompt-encoder params. The learning rate for the frozen model's params will always be zero effectively freezing the model's params but still allowing for the needed gradients - to be passed around in pipeline parallel models. The prompt-encoder - and/or prompt table will use the learning rate set by the user. + to be passed around in pipeline parallel models. The prompt-encoder + and/or prompt table will use the learning rate set by the user. """ self.frozen_model.freeze() # Freeze the entire model opt_params = [] @@ -266,7 +274,7 @@ def state_dict(self, destination=None, prefix=None, keep_vars=False): def load_state_dict(self, state_dict, strict: bool = True): """ - Loads a state_dict expecting the state_dict to contain key,values + Loads a state_dict expecting the state_dict to contain key,values only for the adapter parameters. """ for name, module in self.frozen_model.named_modules(): @@ -319,7 +327,7 @@ def on_validation_epoch_end(self): gather_results_dedup = list(set(itertools.chain(*gather_results))) correct = 0 - for (input, pred, label) in gather_results_dedup: + for input, pred, label in gather_results_dedup: if pred == label: correct += 1 @@ -559,8 +567,8 @@ class MegatronT5InfusedAdapterModel(MegatronT5BaseAdapterModel): Three adapter's are inserted into each Transformer layer in the base GPT Model. Each adapter is basically a vector that simply scales the key, value or ffn hidden representations. It is assumed that these set of adapters will then be trained for a specific task. - Once trained, the adapter weights will be saved and can be re-loaded - and infused into the same GPT Model for inference. + Once trained, the adapter weights will be saved and can be re-loaded + and infused into the same GPT Model for inference. """ def __init__(self, cfg: DictConfig, trainer: Trainer): @@ -670,7 +678,7 @@ def state_dict(self, destination=None, prefix=None, keep_vars=False): def load_state_dict(self, state_dict, strict: bool = True): """ - Loads a state_dict expecting the state_dict to contain key,values + Loads a state_dict expecting the state_dict to contain key,values only for the adapter parameters. """ encoder = self.frozen_model.enc_dec_model.enc_dec_model.encoder diff --git a/nemo/collections/nlp/models/language_modeling/megatron_t5_model.py b/nemo/collections/nlp/models/language_modeling/megatron_t5_model.py index 0f5022795446..1df10403a9e7 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_t5_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_t5_model.py @@ -15,8 +15,8 @@ import enum import math +from lightning.pytorch.trainer.trainer import Trainer from omegaconf.dictconfig import DictConfig -from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.nlp.data.language_modeling.megatron.dataset_utils import build_train_valid_test_datasets from nemo.collections.nlp.models.language_modeling.megatron_lm_encoder_decoder_model import ( @@ -79,7 +79,9 @@ def _validate_cfg(self): @property def _build_train_valid_test_datasets_kwargs(self): """allows child classes to add kwargs to dataset building""" - return dict(max_seq_length_dec=self._cfg.data.seq_length_dec,) + return dict( + max_seq_length_dec=self._cfg.data.seq_length_dec, + ) def _build_vocab(self): self.num_sentinel_tokens = self._cfg.tokenizer.num_sentinel_tokens @@ -210,9 +212,9 @@ def build_train_valid_test_datasets(self): ] if self.trainer.limit_val_batches <= 1.0 and isinstance(self.trainer.limit_val_batches, float): - train_valid_test_num_samples[ - 1 - ] = 1 # This is to make sure we only have one epoch on every validation iteration + train_valid_test_num_samples[1] = ( + 1 # This is to make sure we only have one epoch on every validation iteration + ) self._train_ds, self._validation_ds, self._test_ds = build_train_valid_test_datasets( cfg=self._cfg, diff --git a/nemo/collections/nlp/models/language_modeling/megatron_t5_prompt_learning_model.py b/nemo/collections/nlp/models/language_modeling/megatron_t5_prompt_learning_model.py index 1f54cb87428e..187f24c884b7 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_t5_prompt_learning_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_t5_prompt_learning_model.py @@ -16,10 +16,10 @@ from typing import Any, List import torch +from lightning.pytorch.trainer.trainer import Trainer from omegaconf import OmegaConf from omegaconf.dictconfig import DictConfig from omegaconf.omegaconf import open_dict -from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.nlp.data.language_modeling.megatron.t5_prompt_learning_dataset import T5PromptLearningDataset from nemo.collections.nlp.models.language_modeling.megatron_base_prompt_learning_model import ( diff --git a/nemo/collections/nlp/models/language_modeling/megatron_t5_sft_model.py b/nemo/collections/nlp/models/language_modeling/megatron_t5_sft_model.py index c70f44925d33..6f9a69f27529 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_t5_sft_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_t5_sft_model.py @@ -16,9 +16,9 @@ from typing import Dict, List import torch +from lightning.pytorch.loops.fetchers import _DataFetcherWrapper +from lightning.pytorch.trainer.trainer import Trainer from omegaconf import DictConfig, ListConfig -from pytorch_lightning.loops.fetchers import _DataFetcherWrapper -from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.common.data import ConcatMapDataset from nemo.collections.common.metrics import MetricStringToTorchMetric diff --git a/nemo/collections/nlp/models/language_modeling/transformer_lm_model.py b/nemo/collections/nlp/models/language_modeling/transformer_lm_model.py index 69db0d46e75e..3b8e1f819ea1 100644 --- a/nemo/collections/nlp/models/language_modeling/transformer_lm_model.py +++ b/nemo/collections/nlp/models/language_modeling/transformer_lm_model.py @@ -19,8 +19,8 @@ import numpy as np import torch import torch.utils.data as pt_data +from lightning.pytorch import Trainer from omegaconf import DictConfig, OmegaConf -from pytorch_lightning import Trainer from nemo.collections.common.losses import SmoothedCrossEntropyLoss from nemo.collections.common.metrics import GlobalAverageLossMetric @@ -59,9 +59,11 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): tokenizer_model=cfg.tokenizer.get("tokenizer_model", None), vocab_file=cfg.tokenizer.get("vocab_file", None), bpe_dropout=cfg.tokenizer.get("bpe_dropout", 0.0), - special_tokens=OmegaConf.to_container(cfg.tokenizer.special_tokens) - if cfg.tokenizer.get("special_tokens", None) - else None, + special_tokens=( + OmegaConf.to_container(cfg.tokenizer.special_tokens) + if cfg.tokenizer.get("special_tokens", None) + else None + ), ) # init superclass @@ -99,7 +101,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): # tie weights of embedding and softmax matrices self.log_softmax.mlp.layer0.weight = self.encoder.embedding.token_embedding.weight - std_init_range = 1 / self.encoder.hidden_size ** 0.5 + std_init_range = 1 / self.encoder.hidden_size**0.5 # initialize weights if not using pretrained encoder if not self._cfg.encoder.get('pretrained', False): @@ -199,7 +201,12 @@ def on_test_epoch_end(self): self.test_step_outputs.clear() # free memory def setup_tokenizer( - self, tokenizer_name=None, tokenizer_model=None, vocab_file=None, bpe_dropout=0.0, special_tokens=None, + self, + tokenizer_name=None, + tokenizer_model=None, + vocab_file=None, + bpe_dropout=0.0, + special_tokens=None, ): supported_tokenizers = ['huggingface', 'sentencepiece', 'word'] diff --git a/nemo/collections/nlp/models/machine_translation/megatron_nmt_model.py b/nemo/collections/nlp/models/machine_translation/megatron_nmt_model.py index 4461b417f311..b5f228f21e1a 100644 --- a/nemo/collections/nlp/models/machine_translation/megatron_nmt_model.py +++ b/nemo/collections/nlp/models/machine_translation/megatron_nmt_model.py @@ -19,10 +19,10 @@ import numpy as np import torch +from lightning.pytorch.loops.fetchers import _DataFetcherWrapper +from lightning.pytorch.trainer.trainer import Trainer from omegaconf.dictconfig import DictConfig from omegaconf.listconfig import ListConfig -from pytorch_lightning.loops.fetchers import _DataFetcherWrapper -from pytorch_lightning.trainer.trainer import Trainer from sacrebleu import corpus_bleu from nemo.collections.nlp.data.common.sequence_to_sequence_dataset import ( diff --git a/nemo/collections/nlp/models/machine_translation/mt_enc_dec_bottleneck_model.py b/nemo/collections/nlp/models/machine_translation/mt_enc_dec_bottleneck_model.py index 41c6125ba05f..96077c4da82e 100644 --- a/nemo/collections/nlp/models/machine_translation/mt_enc_dec_bottleneck_model.py +++ b/nemo/collections/nlp/models/machine_translation/mt_enc_dec_bottleneck_model.py @@ -16,7 +16,7 @@ import numpy as np import torch -from pytorch_lightning import Trainer +from lightning.pytorch import Trainer from nemo.collections.common.losses import NLLLoss from nemo.collections.nlp.models.machine_translation.mt_enc_dec_config import MTBottleneckModelConfig @@ -184,7 +184,11 @@ def loss( output_mask = (tgt_labels != self.decoder_tokenizer.pad_id).type_as(tgt_log_probs) log_p_x_given_z_per_token = ( - -recon_loss_fn(log_probs=tgt_log_probs, labels=tgt_labels,).view(tgt_log_probs.shape[:2]) * output_mask + -recon_loss_fn( + log_probs=tgt_log_probs, + labels=tgt_labels, + ).view(tgt_log_probs.shape[:2]) + * output_mask ) # probability per sample @@ -216,7 +220,10 @@ def loss( if self.model_type in ["mim", "vae"]: # tokens = tgt_mask.sum() - q_z_given_x = torch.distributions.Normal(loc=z_mean, scale=torch.exp(0.5 * z_logv),) + q_z_given_x = torch.distributions.Normal( + loc=z_mean, + scale=torch.exp(0.5 * z_logv), + ) # average latent distribution to match averaging of observations if self.recon_per_token: # average latent per dimension - to heuristically match per-token reconstruction @@ -225,7 +232,10 @@ def loss( log_q_z_given_x = q_z_given_x.log_prob(z).sum(-1).sum(-1).mean() # build prior distribution - p_z = torch.distributions.Normal(loc=torch.zeros_like(z), scale=torch.ones_like(z),) + p_z = torch.distributions.Normal( + loc=torch.zeros_like(z), + scale=torch.ones_like(z), + ) if self.recon_per_token: # average latent distribution similar to averaging of observations log_p_z = p_z.log_prob(z).mean(-1).mean(-1).mean() @@ -267,7 +277,11 @@ def forward(self, src, src_mask, tgt, tgt_mask, timer=None): if timer is not None: timer.start("encoder") - enc_hiddens, enc_mask = self.encoder(input_ids=src, encoder_mask=src_mask, return_mask=True,) + enc_hiddens, enc_mask = self.encoder( + input_ids=src, + encoder_mask=src_mask, + return_mask=True, + ) # build posterior distribution q(x|z) z, z_mean, z_logv = self.encode_latent(hidden=enc_hiddens) @@ -283,7 +297,10 @@ def forward(self, src, src_mask, tgt, tgt_mask, timer=None): context_hiddens = self.latent2hidden(z) tgt_hiddens = self.decoder( - input_ids=tgt, decoder_mask=tgt_mask, encoder_embeddings=context_hiddens, encoder_mask=enc_mask, + input_ids=tgt, + decoder_mask=tgt_mask, + encoder_embeddings=context_hiddens, + encoder_mask=enc_mask, ) # build decoding distribution @@ -426,18 +443,25 @@ def eval_step(self, batch, batch_idx, mode, dataloader_idx=0): return_info=True, ) # pass cache to sampler in order to reuse encoder's output - cache = dict(z=z, z_mean=z_mean, z_mask=z_mask, timer=timer,) + cache = dict( + z=z, + z_mean=z_mean, + z_mask=z_mask, + timer=timer, + ) inputs, translations = self.batch_translate(src=src_ids, src_mask=src_mask, cache=cache) num_measurements = labels.shape[0] * labels.shape[1] if dataloader_idx == 0: getattr(self, f'{mode}_loss')( - loss=eval_loss, num_measurements=num_measurements, + loss=eval_loss, + num_measurements=num_measurements, ) else: getattr(self, f'{mode}_loss_{dataloader_idx}')( - loss=eval_loss, num_measurements=num_measurements, + loss=eval_loss, + num_measurements=num_measurements, ) np_tgt = tgt_ids.detach().cpu().numpy() ground_truths = [self.decoder_tokenizer.ids_to_text(tgt) for tgt in np_tgt] diff --git a/nemo/collections/nlp/models/machine_translation/mt_enc_dec_model.py b/nemo/collections/nlp/models/machine_translation/mt_enc_dec_model.py index 708d4236be7f..78b701699259 100644 --- a/nemo/collections/nlp/models/machine_translation/mt_enc_dec_model.py +++ b/nemo/collections/nlp/models/machine_translation/mt_enc_dec_model.py @@ -25,9 +25,9 @@ import torch import torch.distributed as dist import torch.utils.data as pt_data +from lightning.pytorch import Trainer +from lightning.pytorch.utilities import rank_zero_only from omegaconf import DictConfig, ListConfig, OmegaConf -from pytorch_lightning import Trainer -from pytorch_lightning.utilities import rank_zero_only from sacrebleu import corpus_bleu from nemo.collections.common.data import ConcatDataset @@ -120,17 +120,21 @@ def __init__(self, cfg: MTEncDecModelConfig, trainer: Trainer = None): encoder_tokenizer, decoder_tokenizer = MTEncDecModel.setup_enc_dec_tokenizers( encoder_tokenizer_library=self.encoder_tokenizer_library, encoder_tokenizer_model=encoder_tokenizer_model, - encoder_bpe_dropout=cfg.encoder_tokenizer.get('bpe_dropout', 0.0) - if cfg.encoder_tokenizer.get('bpe_dropout', 0.0) is not None - else 0.0, + encoder_bpe_dropout=( + cfg.encoder_tokenizer.get('bpe_dropout', 0.0) + if cfg.encoder_tokenizer.get('bpe_dropout', 0.0) is not None + else 0.0 + ), encoder_model_name=cfg.encoder.get('model_name') if hasattr(cfg.encoder, 'model_name') else None, encoder_r2l=cfg.encoder_tokenizer.get('r2l', False), decoder_tokenizer_library=self.decoder_tokenizer_library, encoder_tokenizer_vocab_file=encoder_vocab_file, decoder_tokenizer_model=decoder_tokenizer_model, - decoder_bpe_dropout=cfg.decoder_tokenizer.get('bpe_dropout', 0.0) - if cfg.decoder_tokenizer.get('bpe_dropout', 0.0) is not None - else 0.0, + decoder_bpe_dropout=( + cfg.decoder_tokenizer.get('bpe_dropout', 0.0) + if cfg.decoder_tokenizer.get('bpe_dropout', 0.0) is not None + else 0.0 + ), decoder_model_name=cfg.decoder.get('model_name') if hasattr(cfg.decoder, 'model_name') else None, decoder_r2l=cfg.decoder_tokenizer.get('r2l', False), special_tokens=self.special_tokens, @@ -254,7 +258,7 @@ def __init__(self, cfg: MTEncDecModelConfig, trainer: Trainer = None): self.log_softmax.mlp.layer0.weight = self.decoder.embedding.token_embedding.weight # TODO: encoder and decoder with different hidden size? - std_init_range = 1 / self.encoder.hidden_size ** 0.5 + std_init_range = 1 / self.encoder.hidden_size**0.5 # initialize weights if not using pretrained encoder/decoder if not self._cfg.encoder.get('pretrained', False): @@ -341,7 +345,10 @@ def filter_predicted_ids(cls, ids, decoder_tokenizer): return ids def test_encoder_ids(self, ids, raise_error=False): - invalid_ids = torch.logical_or((ids >= self.encoder_tokenizer.vocab_size).any(), (ids < 0).any(),) + invalid_ids = torch.logical_or( + (ids >= self.encoder_tokenizer.vocab_size).any(), + (ids < 0).any(), + ) if raise_error and invalid_ids: raise ValueError("Encoder ids are out of range (tip: check encoder tokenizer)") @@ -349,7 +356,10 @@ def test_encoder_ids(self, ids, raise_error=False): return not invalid_ids def test_decoder_ids(self, ids, raise_error=False): - invalid_ids = torch.logical_or((ids >= self.decoder_tokenizer.vocab_size).any(), (ids < 0).any(),) + invalid_ids = torch.logical_or( + (ids >= self.decoder_tokenizer.vocab_size).any(), + (ids < 0).any(), + ) if raise_error and invalid_ids: raise ValueError("Decoder ids are out of range (tip: check decoder tokenizer)") @@ -655,7 +665,10 @@ def setup_training_data(self, train_data_config: Optional[DictConfig]): multilingual=self.multilingual, multilingual_ids=self.multilingual_ids, ) - self._train_dl = MTEncDecModel._setup_dataloader_from_config(cfg=train_data_config, dataset=self._train_ds,) + self._train_dl = MTEncDecModel._setup_dataloader_from_config( + cfg=train_data_config, + dataset=self._train_ds, + ) # Need to set this because if using an IterableDataset, the length of the dataloader is the total number # of samples rather than the number of batches, and this messes up the tqdm progress bar. @@ -714,7 +727,9 @@ def setup_validation_data(self, val_data_config: Optional[DictConfig]): for dataloader_idx in range(len(self._validation_dl)): if dataloader_idx == 0: setattr( - self, f'val_loss', GlobalAverageLossMetric(dist_sync_on_step=False, take_avg_loss=True), + self, + f'val_loss', + GlobalAverageLossMetric(dist_sync_on_step=False, take_avg_loss=True), ) else: setattr( @@ -737,7 +752,9 @@ def setup_test_data(self, test_data_config: Optional[DictConfig]): for dataloader_idx in range(len(self._test_dl)): if dataloader_idx == 0: setattr( - self, f'test_loss', GlobalAverageLossMetric(dist_sync_on_step=False, take_avg_loss=True), + self, + f'test_loss', + GlobalAverageLossMetric(dist_sync_on_step=False, take_avg_loss=True), ) else: setattr( @@ -886,13 +903,15 @@ def _setup_dataloader_from_config(cls, cfg, dataset): return torch.utils.data.DataLoader( dataset=dataset, batch_size=1, - sampler=None - if ( - cfg.get("use_tarred_dataset", False) - or cfg.get("dataset_type", "") == "tarred" - or isinstance(dataset, ConcatDataset) - ) - else sampler, + sampler=( + None + if ( + cfg.get("use_tarred_dataset", False) + or cfg.get("dataset_type", "") == "tarred" + or isinstance(dataset, ConcatDataset) + ) + else sampler + ), num_workers=cfg.get("num_workers", 2), pin_memory=cfg.get("pin_memory", False), drop_last=cfg.get("drop_last", False), @@ -983,9 +1002,11 @@ def _setup_eval_dataloader_from_config(cls, cfg, datasets): torch.utils.data.DataLoader( dataset=dataset, batch_size=1, - sampler=None - if (cfg.get("use_tarred_dataset", False) or isinstance(datasets[0], ConcatDataset)) - else sampler, + sampler=( + None + if (cfg.get("use_tarred_dataset", False) or isinstance(datasets[0], ConcatDataset)) + else sampler + ), num_workers=cfg.get("num_workers", 2), pin_memory=cfg.get("pin_memory", False), drop_last=cfg.get("drop_last", False), @@ -1188,7 +1209,10 @@ def translate( ) if return_beam_scores: _, all_translations, scores, best_translations = self.batch_translate( - src, src_mask, return_beam_scores=True, cache=cache, + src, + src_mask, + return_beam_scores=True, + cache=cache, ) return_val = all_translations, scores, best_translations else: diff --git a/nemo/collections/nlp/models/nlp_model.py b/nemo/collections/nlp/models/nlp_model.py index b27c00c5d7c3..6a87eb28723c 100644 --- a/nemo/collections/nlp/models/nlp_model.py +++ b/nemo/collections/nlp/models/nlp_model.py @@ -19,13 +19,13 @@ from typing import Any, Mapping, Optional, Union import torch -from lightning_fabric.utilities.cloud_io import _load as pl_load +from lightning.fabric.utilities.cloud_io import _load as pl_load +from lightning.pytorch import Trainer +from lightning.pytorch.core.saving import _load_state as ptl_load_state +from lightning.pytorch.core.saving import load_hparams_from_tags_csv, load_hparams_from_yaml +from lightning.pytorch.utilities import rank_zero_only +from lightning.pytorch.utilities.migration import pl_legacy_patch from omegaconf import DictConfig, OmegaConf -from pytorch_lightning import Trainer -from pytorch_lightning.core.saving import _load_state as ptl_load_state -from pytorch_lightning.core.saving import load_hparams_from_tags_csv, load_hparams_from_yaml -from pytorch_lightning.utilities import rank_zero_only -from pytorch_lightning.utilities.migration import pl_legacy_patch from transformers import TRANSFORMERS_CACHE from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer @@ -397,7 +397,22 @@ def dummy(): model.trainer.strategy.launcher.launch(dummy, trainer=model.trainer) model.trainer.strategy.setup_environment() sharded_state_dict = model.sharded_state_dict() - checkpoint['state_dict'] = sharded_state_dict + if kwargs.get("load_mlm", False): + mlm_sharded_state_dict = {} + for k, v in sharded_state_dict.items(): + # Remove 'model.' from the sharded_state_dict keys + new_key = k.replace('model.', '', 1) + + # Update the key attribute of the ShardedTensor value + new_value = v + if hasattr(v, 'key'): + new_value.key = v.key.replace('model.', '', 1) + + # Add the updated key-value pair to the new dictionary + mlm_sharded_state_dict[new_key] = new_value + checkpoint['state_dict'] = mlm_sharded_state_dict + else: + checkpoint['state_dict'] = sharded_state_dict # load the checkpoint from disk checkpoint = dist_checkpointing.load(sharded_state_dict=checkpoint, checkpoint_dir=checkpoint_dir) # restore the weights diff --git a/nemo/collections/nlp/models/question_answering/qa_base_model.py b/nemo/collections/nlp/models/question_answering/qa_base_model.py index 7ca78f2e136e..cb07e43c3dc1 100644 --- a/nemo/collections/nlp/models/question_answering/qa_base_model.py +++ b/nemo/collections/nlp/models/question_answering/qa_base_model.py @@ -15,8 +15,8 @@ from typing import Optional import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig, OmegaConf -from pytorch_lightning import Trainer from nemo.collections.nlp.data.question_answering.data_processor.qa_processing import ( EVALUATION_MODE, diff --git a/nemo/collections/nlp/models/question_answering/qa_bert_model.py b/nemo/collections/nlp/models/question_answering/qa_bert_model.py index d4bdef6d871d..4036b23999d8 100644 --- a/nemo/collections/nlp/models/question_answering/qa_bert_model.py +++ b/nemo/collections/nlp/models/question_answering/qa_bert_model.py @@ -17,8 +17,8 @@ import numpy as np import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig -from pytorch_lightning import Trainer from transformers.models.bert.tokenization_bert import BasicTokenizer from nemo.collections.common.losses import SpanningLoss diff --git a/nemo/collections/nlp/models/question_answering/qa_gpt_model.py b/nemo/collections/nlp/models/question_answering/qa_gpt_model.py index 059cf5625f15..f8c883643fe0 100644 --- a/nemo/collections/nlp/models/question_answering/qa_gpt_model.py +++ b/nemo/collections/nlp/models/question_answering/qa_gpt_model.py @@ -16,8 +16,8 @@ from typing import List, Optional import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig -from pytorch_lightning import Trainer from transformers import AutoModelForCausalLM from nemo.collections.nlp.data.question_answering.data_processor.qa_processing import QAProcessor diff --git a/nemo/collections/nlp/models/question_answering/qa_model.py b/nemo/collections/nlp/models/question_answering/qa_model.py index 2147d7d6a5bf..01b07bb8b3b0 100644 --- a/nemo/collections/nlp/models/question_answering/qa_model.py +++ b/nemo/collections/nlp/models/question_answering/qa_model.py @@ -16,8 +16,8 @@ from typing import Optional import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig, OmegaConf -from pytorch_lightning import Trainer from torch.cuda.amp import autocast from nemo.collections.common.losses import SpanningLoss diff --git a/nemo/collections/nlp/models/question_answering/qa_s2s_model.py b/nemo/collections/nlp/models/question_answering/qa_s2s_model.py index 5ad959fd1b6f..a703e23bc837 100644 --- a/nemo/collections/nlp/models/question_answering/qa_s2s_model.py +++ b/nemo/collections/nlp/models/question_answering/qa_s2s_model.py @@ -16,8 +16,8 @@ from typing import List, Optional import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig, OmegaConf, open_dict -from pytorch_lightning import Trainer from torch.cuda.amp import autocast from transformers import AutoModelForSeq2SeqLM diff --git a/nemo/collections/nlp/models/rag/custom_bert_embedder.py b/nemo/collections/nlp/models/rag/custom_bert_embedder.py index d27ee98a14ef..84361e2728b5 100644 --- a/nemo/collections/nlp/models/rag/custom_bert_embedder.py +++ b/nemo/collections/nlp/models/rag/custom_bert_embedder.py @@ -15,10 +15,10 @@ from typing import Any, List import torch +from lightning.pytorch.trainer.trainer import Trainer from llama_index.core.bridge.pydantic import PrivateAttr from llama_index.core.embeddings import BaseEmbedding from omegaconf import DictConfig -from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.nlp.models.information_retrieval.megatron_bert_embedding_model import MegatronBertEmbeddingModel from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy diff --git a/nemo/collections/nlp/models/rag/custom_gpt_llm.py b/nemo/collections/nlp/models/rag/custom_gpt_llm.py index f26a86cfaaf7..1bbeed38991b 100644 --- a/nemo/collections/nlp/models/rag/custom_gpt_llm.py +++ b/nemo/collections/nlp/models/rag/custom_gpt_llm.py @@ -14,10 +14,10 @@ from typing import Any +from lightning.pytorch.trainer.trainer import Trainer from llama_index.core.bridge.pydantic import PrivateAttr from llama_index.core.llms import CompletionResponse, CompletionResponseGen, CustomLLM, LLMMetadata from llama_index.core.llms.callbacks import llm_completion_callback -from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel from nemo.collections.nlp.modules.common.transformer.text_generation import LengthParam, SamplingParam diff --git a/nemo/collections/nlp/models/spellchecking_asr_customization/spellchecking_model.py b/nemo/collections/nlp/models/spellchecking_asr_customization/spellchecking_model.py index d9e08f6764fc..6d4974993bcb 100644 --- a/nemo/collections/nlp/models/spellchecking_asr_customization/spellchecking_model.py +++ b/nemo/collections/nlp/models/spellchecking_asr_customization/spellchecking_model.py @@ -17,8 +17,8 @@ from typing import Dict, Optional import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig -from pytorch_lightning import Trainer from nemo.collections.common.losses import CrossEntropyLoss from nemo.collections.nlp.data.spellchecking_asr_customization import ( diff --git a/nemo/collections/nlp/models/text2sparql/text2sparql_model.py b/nemo/collections/nlp/models/text2sparql/text2sparql_model.py index 6503364fc07e..df7eefa310bb 100644 --- a/nemo/collections/nlp/models/text2sparql/text2sparql_model.py +++ b/nemo/collections/nlp/models/text2sparql/text2sparql_model.py @@ -19,8 +19,8 @@ from typing import Dict, List, Optional, Tuple import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig, OmegaConf -from pytorch_lightning import Trainer from transformers import AutoModel, BartForConditionalGeneration, EncoderDecoderModel from nemo.collections.common.metrics import Perplexity @@ -145,7 +145,10 @@ def training_step(self, batch: Tuple, batch_idx: int) -> Dict: """ input_ids, input_mask, decoder_input_ids, labels = batch loss = self.forward( - input_ids=input_ids, attention_mask=input_mask, decoder_input_ids=decoder_input_ids, labels=labels, + input_ids=input_ids, + attention_mask=input_mask, + decoder_input_ids=decoder_input_ids, + labels=labels, )[0] tensorboard_logs = {"train_loss": loss, "lr": self._optimizer.param_groups[0]["lr"]} @@ -159,7 +162,10 @@ def validation_step(self, batch: Tuple, batch_idx: int) -> Dict: """ input_ids, input_mask, decoder_input_ids, labels = batch loss, logits = self.forward( - input_ids=input_ids, attention_mask=input_mask, decoder_input_ids=decoder_input_ids, labels=labels, + input_ids=input_ids, + attention_mask=input_mask, + decoder_input_ids=decoder_input_ids, + labels=labels, )[:2] self.validation_perplexity(logits=logits) diff --git a/nemo/collections/nlp/models/text_classification/text_classification_model.py b/nemo/collections/nlp/models/text_classification/text_classification_model.py index 033447304bbf..b2da2fe21701 100644 --- a/nemo/collections/nlp/models/text_classification/text_classification_model.py +++ b/nemo/collections/nlp/models/text_classification/text_classification_model.py @@ -17,8 +17,8 @@ from typing import Dict, List, Optional import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig -from pytorch_lightning import Trainer from nemo.collections.common.losses import CrossEntropyLoss from nemo.collections.nlp.data.text_classification import TextClassificationDataset, calc_class_weights diff --git a/nemo/collections/nlp/models/text_normalization_as_tagging/thutmose_tagger.py b/nemo/collections/nlp/models/text_normalization_as_tagging/thutmose_tagger.py index 4c11dc157b2b..ddcb3a774055 100644 --- a/nemo/collections/nlp/models/text_normalization_as_tagging/thutmose_tagger.py +++ b/nemo/collections/nlp/models/text_normalization_as_tagging/thutmose_tagger.py @@ -17,8 +17,8 @@ from typing import Dict, List, Optional import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig -from pytorch_lightning import Trainer from nemo.collections.common.losses import CrossEntropyLoss from nemo.collections.nlp.data.text_normalization_as_tagging import ( @@ -289,7 +289,7 @@ def on_test_epoch_end(self): # Functions for inference @torch.no_grad() def _infer(self, sents: List[str]) -> List[List[int]]: - """ Main function for Inference + """Main function for Inference Args: sents: A list of input sentences (lowercase spoken-domain words separated by space). diff --git a/nemo/collections/nlp/models/token_classification/punctuation_capitalization_lexical_audio_model.py b/nemo/collections/nlp/models/token_classification/punctuation_capitalization_lexical_audio_model.py index 69df9b6ac009..bd42517a5720 100644 --- a/nemo/collections/nlp/models/token_classification/punctuation_capitalization_lexical_audio_model.py +++ b/nemo/collections/nlp/models/token_classification/punctuation_capitalization_lexical_audio_model.py @@ -17,8 +17,8 @@ import numpy as np import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig, open_dict -from pytorch_lightning import Trainer from torch.nn import Linear from tqdm import tqdm @@ -53,27 +53,27 @@ def update_model_config_to_support_adapter(model_cfg): class PunctuationCapitalizationLexicalAudioModel(PunctuationCapitalizationModel): """ - A model for restoring punctuation and capitalization in text using lexical and audio features. - - The model consists of a language model and two multilayer perceptrons (MLP) on top the fusion of LM and AM. The first - MLP serves for punctuation prediction and the second is for capitalization prediction. You can use only BERT-like - HuggingFace language models (model ``forward`` method accepts ``input_ids``, ``token_types_ids``, - ``attention_mask`` arguments). See more about model config options :ref:`here`. - And any :class:`~nemo.collections.asr.models.EncDecCTCModel` which has encoder module which is used as an AM. - - For training and testing use dataset - :class:`~nemo.collections.nlp.data.token_classification.punctuation_capitalization_dataset.BertPunctuationCapitalizationDataset` with parameter ``use_audio`` set to ``True``, - for training on huge amounts of data which cannot be loaded into memory simultaneously use - :class:`~nemo.collections.nlp.data.token_classification.punctuation_capitalization_tarred_dataset.BertPunctuationCapitalizationTarredDataset` with parameter ``use_audio`` set to ``True``. - - Args: - cfg: a model configuration. It should follow dataclass - :class:`~nemo.collections.nlp.models.token_classification.punctuation_capitalization_config.PunctuationCapitalizationLexicalAudioModelConfig` - See an example of full config in - `nemo/examples/nlp/token_classification/conf/punctuation_capitalization_lexical_audio_config.yaml - `_ - trainer: an instance of a PyTorch Lightning trainer - """ + A model for restoring punctuation and capitalization in text using lexical and audio features. + + The model consists of a language model and two multilayer perceptrons (MLP) on top the fusion of LM and AM. The first + MLP serves for punctuation prediction and the second is for capitalization prediction. You can use only BERT-like + HuggingFace language models (model ``forward`` method accepts ``input_ids``, ``token_types_ids``, + ``attention_mask`` arguments). See more about model config options :ref:`here`. + And any :class:`~nemo.collections.asr.models.EncDecCTCModel` which has encoder module which is used as an AM. + + For training and testing use dataset + :class:`~nemo.collections.nlp.data.token_classification.punctuation_capitalization_dataset.BertPunctuationCapitalizationDataset` with parameter ``use_audio`` set to ``True``, + for training on huge amounts of data which cannot be loaded into memory simultaneously use + :class:`~nemo.collections.nlp.data.token_classification.punctuation_capitalization_tarred_dataset.BertPunctuationCapitalizationTarredDataset` with parameter ``use_audio`` set to ``True``. + + Args: + cfg: a model configuration. It should follow dataclass + :class:`~nemo.collections.nlp.models.token_classification.punctuation_capitalization_config.PunctuationCapitalizationLexicalAudioModelConfig` + See an example of full config in + `nemo/examples/nlp/token_classification/conf/punctuation_capitalization_lexical_audio_config.yaml + `_ + trainer: an instance of a PyTorch Lightning trainer + """ def __init__(self, cfg: DictConfig, trainer: Trainer = None) -> None: super().__init__(cfg, trainer) @@ -199,31 +199,31 @@ def forward( features_length: torch.Tensor = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """ - Executes a forward pass through the model. For more details see ``forward`` method of :class:`~nemo.collections.nlp.models.token_classification.punctuation_capitalization_config.PunctuationCapitalizationLexicalAudioModelConfig` - and ``forward`` method of :class:'~nemo.collections.asr.models.EncDecCTCModel' - - Args: - input_ids (:obj:`torch.Tensor`): an integer torch tensor of shape ``[Batch, Time]``. Contains encoded - source tokens. - attention_mask (:obj:`torch.Tensor`): a boolean torch tensor of shape ``[Batch, Time]``. Contains an - attention mask for excluding paddings. - token_type_ids (:obj:`torch.Tensor`): an integer torch Tensor of shape ``[Batch, Time]``. Contains an index - of segment to which a token belongs. If ``token_type_ids`` is not ``None``, then it should be a zeros - tensor. - features (:obj:`torch.Tensor`): tensor that represents a batch of raw audio signals, - of shape [B, T]. T here represents timesteps, with 1 second of audio represented as - sample_rate number of floating point values. - features_length (:obj:`torch.Tensor`): Vector of length B, that contains the individual lengths of the audio - sequences. - - Returns: - :obj:`Tuple[torch.Tensor, torch.Tensor]`: a tuple containing - - - ``punct_logits`` (:obj:`torch.Tensor`): a float torch tensor of shape - ``[Batch, Time, NumPunctuationLabels]`` containing punctuation logits - - ``capit_logits`` (:obj:`torch.Tensor`): a float torch tensor of shape - ``[Batch, Time, NumCapitalizationLabels]`` containing capitalization logits - """ + Executes a forward pass through the model. For more details see ``forward`` method of :class:`~nemo.collections.nlp.models.token_classification.punctuation_capitalization_config.PunctuationCapitalizationLexicalAudioModelConfig` + and ``forward`` method of :class:'~nemo.collections.asr.models.EncDecCTCModel' + + Args: + input_ids (:obj:`torch.Tensor`): an integer torch tensor of shape ``[Batch, Time]``. Contains encoded + source tokens. + attention_mask (:obj:`torch.Tensor`): a boolean torch tensor of shape ``[Batch, Time]``. Contains an + attention mask for excluding paddings. + token_type_ids (:obj:`torch.Tensor`): an integer torch Tensor of shape ``[Batch, Time]``. Contains an index + of segment to which a token belongs. If ``token_type_ids`` is not ``None``, then it should be a zeros + tensor. + features (:obj:`torch.Tensor`): tensor that represents a batch of raw audio signals, + of shape [B, T]. T here represents timesteps, with 1 second of audio represented as + sample_rate number of floating point values. + features_length (:obj:`torch.Tensor`): Vector of length B, that contains the individual lengths of the audio + sequences. + + Returns: + :obj:`Tuple[torch.Tensor, torch.Tensor]`: a tuple containing + + - ``punct_logits`` (:obj:`torch.Tensor`): a float torch tensor of shape + ``[Batch, Time, NumPunctuationLabels]`` containing punctuation logits + - ``capit_logits`` (:obj:`torch.Tensor`): a float torch tensor of shape + ``[Batch, Time, NumCapitalizationLabels]`` containing capitalization logits + """ self.update_max_seq_length(seq_length=features.size(1), device=features.device) lexical_hidden_states = self.bert_model( input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask @@ -232,7 +232,8 @@ def forward( lexical_hidden_states = lexical_hidden_states[0] processed_signal, processed_signal_length = self.audio_encoder.preprocessor( - input_signal=features, length=features_length, + input_signal=features, + length=features_length, ) if self.audio_encoder.spec_augmentation is not None and self.training: @@ -301,49 +302,49 @@ def add_punctuation_capitalization( target_sr: Optional[int] = None, ) -> List[str]: """ - Adds punctuation and capitalization to the queries. Use this method for inference. - - Parameters ``max_seq_length``, ``step``, ``margin`` are for controlling the way queries are split into segments - which are processed by the model. Parameter ``max_seq_length`` is a length of a segment after tokenization - including special tokens [CLS] in the beginning and [SEP] in the end of a segment. Parameter ``step`` is a - shift between consequent segments. Parameter ``margin`` is used to exclude negative effect of subtokens near - borders of segments which have only one side context. - - If segments overlap, probabilities of overlapping predictions are multiplied and then the label with - corresponding to the maximum probability is selected. - - Args: - queries (:obj:`List[str]`): lower cased text without punctuation. - batch_size (:obj:`List[str]`, `optional`): batch size to use during inference. If ``batch_size`` parameter - is not provided, then it will be equal to length of ``queries`` list. - max_seq_length (:obj:`int`, `optional`, defaults to :obj:`64`): maximum sequence length of a segment after - tokenization including :code:`[CLS]` and :code:`[SEP]` tokens. - step (:obj:`int`, `optional`, defaults to :obj:`8`): relative shift of consequent segments into which long - queries are split. Long queries are split into segments which can overlap. Parameter ``step`` controls - such overlapping. Imagine that queries are tokenized into characters, ``max_seq_length=5``, and - ``step=2``. In such case, query ``"hello"`` is tokenized into segments - ``[['[CLS]', 'h', 'e', 'l', '[SEP]'], ['[CLS]', 'l', 'l', 'o', '[SEP]']]``. - margin (:obj:`int`, `optional`, defaults to :obj:`16`): number of subtokens in the beginning and the end of - segments which are not used for prediction computation. The first segment does not have left margin and - the last segment does not have right margin. For example, if an input sequence is tokenized into - characters, ``max_seq_length=5``, ``step=1``, and ``margin=1``, then query ``"hello"`` will be - tokenized into segments ``[['[CLS]', 'h', 'e', 'l', '[SEP]'], ['[CLS]', 'e', 'l', 'l', '[SEP]'], - ['[CLS]', 'l', 'l', 'o', '[SEP]']]``. These segments are passed to the model. Before final predictions - computation, margins are removed. In the next list, subtokens which logits are not used for final - predictions computation are marked with asterisk: ``[['[CLS]'*, 'h', 'e', 'l'*, '[SEP]'*], - ['[CLS]'*, 'e'*, 'l', 'l'*, '[SEP]'*], ['[CLS]'*, 'l'*, 'l', 'o', '[SEP]'*]]``. - return_labels (:obj:`bool`, `optional`, defaults to :obj:`False`): whether to return labels in NeMo format - (see :ref:`nlp/punctuation_and_capitalization/NeMo Data Format`) instead of queries with restored - punctuation and capitalization. - dataloader_kwargs (:obj:`Dict[str, Any]`, `optional`): an optional dictionary with parameters of PyTorch - data loader. May include keys: ``'num_workers'``, ``'pin_memory'``, ``'worker_init_fn'``, - ``'prefetch_factor'``, ``'persistent_workers'``. - audio_queries (:obj:`List[str]`, `optional`): paths to audio files. - target_sr (:obj:`int`, `optional`): target sample rate for audios. - Returns: - :obj:`List[str]`: a list of queries with restored capitalization and punctuation if - ``return_labels=False``, else a list of punctuation and capitalization labels strings for all queries - """ + Adds punctuation and capitalization to the queries. Use this method for inference. + + Parameters ``max_seq_length``, ``step``, ``margin`` are for controlling the way queries are split into segments + which are processed by the model. Parameter ``max_seq_length`` is a length of a segment after tokenization + including special tokens [CLS] in the beginning and [SEP] in the end of a segment. Parameter ``step`` is a + shift between consequent segments. Parameter ``margin`` is used to exclude negative effect of subtokens near + borders of segments which have only one side context. + + If segments overlap, probabilities of overlapping predictions are multiplied and then the label with + corresponding to the maximum probability is selected. + + Args: + queries (:obj:`List[str]`): lower cased text without punctuation. + batch_size (:obj:`List[str]`, `optional`): batch size to use during inference. If ``batch_size`` parameter + is not provided, then it will be equal to length of ``queries`` list. + max_seq_length (:obj:`int`, `optional`, defaults to :obj:`64`): maximum sequence length of a segment after + tokenization including :code:`[CLS]` and :code:`[SEP]` tokens. + step (:obj:`int`, `optional`, defaults to :obj:`8`): relative shift of consequent segments into which long + queries are split. Long queries are split into segments which can overlap. Parameter ``step`` controls + such overlapping. Imagine that queries are tokenized into characters, ``max_seq_length=5``, and + ``step=2``. In such case, query ``"hello"`` is tokenized into segments + ``[['[CLS]', 'h', 'e', 'l', '[SEP]'], ['[CLS]', 'l', 'l', 'o', '[SEP]']]``. + margin (:obj:`int`, `optional`, defaults to :obj:`16`): number of subtokens in the beginning and the end of + segments which are not used for prediction computation. The first segment does not have left margin and + the last segment does not have right margin. For example, if an input sequence is tokenized into + characters, ``max_seq_length=5``, ``step=1``, and ``margin=1``, then query ``"hello"`` will be + tokenized into segments ``[['[CLS]', 'h', 'e', 'l', '[SEP]'], ['[CLS]', 'e', 'l', 'l', '[SEP]'], + ['[CLS]', 'l', 'l', 'o', '[SEP]']]``. These segments are passed to the model. Before final predictions + computation, margins are removed. In the next list, subtokens which logits are not used for final + predictions computation are marked with asterisk: ``[['[CLS]'*, 'h', 'e', 'l'*, '[SEP]'*], + ['[CLS]'*, 'e'*, 'l', 'l'*, '[SEP]'*], ['[CLS]'*, 'l'*, 'l', 'o', '[SEP]'*]]``. + return_labels (:obj:`bool`, `optional`, defaults to :obj:`False`): whether to return labels in NeMo format + (see :ref:`nlp/punctuation_and_capitalization/NeMo Data Format`) instead of queries with restored + punctuation and capitalization. + dataloader_kwargs (:obj:`Dict[str, Any]`, `optional`): an optional dictionary with parameters of PyTorch + data loader. May include keys: ``'num_workers'``, ``'pin_memory'``, ``'worker_init_fn'``, + ``'prefetch_factor'``, ``'persistent_workers'``. + audio_queries (:obj:`List[str]`, `optional`): paths to audio files. + target_sr (:obj:`int`, `optional`): target sample rate for audios. + Returns: + :obj:`List[str]`: a list of queries with restored capitalization and punctuation if + ``return_labels=False``, else a list of punctuation and capitalization labels strings for all queries + """ if len(queries) == 0: return [] @@ -408,7 +409,9 @@ def add_punctuation_capitalization( acc_probs[q_i] = b_probs_i else: all_preds[q_i], acc_probs[q_i] = self._move_acc_probs_to_token_preds( - all_preds[q_i], acc_probs[q_i], start_word_id - len(all_preds[q_i]), + all_preds[q_i], + acc_probs[q_i], + start_word_id - len(all_preds[q_i]), ) acc_probs[q_i] = self._update_accumulated_probabilities(acc_probs[q_i], b_probs_i) for all_preds, acc_probs in [(all_punct_preds, acc_punct_probs), (all_capit_preds, acc_capit_probs)]: diff --git a/nemo/collections/nlp/models/token_classification/punctuation_capitalization_model.py b/nemo/collections/nlp/models/token_classification/punctuation_capitalization_model.py index 6e2d1f5762ec..8cf153dfdf76 100644 --- a/nemo/collections/nlp/models/token_classification/punctuation_capitalization_model.py +++ b/nemo/collections/nlp/models/token_classification/punctuation_capitalization_model.py @@ -20,8 +20,8 @@ import numpy as np import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig, OmegaConf -from pytorch_lightning import Trainer from tqdm import tqdm from nemo.collections.common.losses import AggregatorLoss, CrossEntropyLoss @@ -812,7 +812,13 @@ def _setup_dataloader_from_config(self, cfg: DictConfig, train: bool) -> torch.u raise ValueError( f"If `use_tarred_dataset` is `False`, then you need to provide `tokens_in_batch` parameter." ) - text_file, labels_file, = Path(cfg.ds_item) / cfg.text_file, Path(cfg.ds_item) / cfg.labels_file + ( + text_file, + labels_file, + ) = ( + Path(cfg.ds_item) / cfg.text_file, + Path(cfg.ds_item) / cfg.labels_file, + ) if cfg.audio_file: audio_file = Path(cfg.ds_item) / cfg.audio_file if self.label_ids_are_set: @@ -1010,7 +1016,8 @@ def _transform_logit_to_prob_and_remove_margins_and_extract_word_probs( stm = self._remove_margins(stm, margin, keep_left=first, keep_right=last) for b_probs, logits in [(b_punct_probs, pl), (b_capit_probs, cl)]: p = torch.nn.functional.softmax( - self._remove_margins(logits, margin, keep_left=first, keep_right=last)[stm], dim=-1, + self._remove_margins(logits, margin, keep_left=first, keep_right=last)[stm], + dim=-1, ) b_probs.append(p.detach().cpu().numpy()) return b_punct_probs, b_capit_probs, new_start_word_ids @@ -1191,7 +1198,9 @@ def add_punctuation_capitalization( ): inp_ids, inp_type_ids, inp_mask, subtokens_mask, start_word_ids, query_ids, is_first, is_last = batch punct_logits, capit_logits = self.forward( - input_ids=inp_ids.to(d), token_type_ids=inp_type_ids.to(d), attention_mask=inp_mask.to(d), + input_ids=inp_ids.to(d), + token_type_ids=inp_type_ids.to(d), + attention_mask=inp_mask.to(d), ) _res = self._transform_logit_to_prob_and_remove_margins_and_extract_word_probs( punct_logits, capit_logits, subtokens_mask, start_word_ids, margin, is_first, is_last @@ -1208,7 +1217,9 @@ def add_punctuation_capitalization( acc_probs[q_i] = b_probs_i else: all_preds[q_i], acc_probs[q_i] = self._move_acc_probs_to_token_preds( - all_preds[q_i], acc_probs[q_i], start_word_id - len(all_preds[q_i]), + all_preds[q_i], + acc_probs[q_i], + start_word_id - len(all_preds[q_i]), ) acc_probs[q_i] = self._update_accumulated_probabilities(acc_probs[q_i], b_probs_i) for all_preds, acc_probs in [(all_punct_preds, acc_punct_probs), (all_capit_preds, acc_capit_probs)]: diff --git a/nemo/collections/nlp/models/token_classification/token_classification_model.py b/nemo/collections/nlp/models/token_classification/token_classification_model.py index 0b465bae663c..99bb2328b956 100644 --- a/nemo/collections/nlp/models/token_classification/token_classification_model.py +++ b/nemo/collections/nlp/models/token_classification/token_classification_model.py @@ -16,8 +16,8 @@ from typing import List, Optional, Union import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig, OmegaConf -from pytorch_lightning import Trainer from torch.utils.data import DataLoader from nemo.collections.common.losses import CrossEntropyLoss diff --git a/nemo/collections/nlp/models/zero_shot_intent_recognition/zero_shot_intent_model.py b/nemo/collections/nlp/models/zero_shot_intent_recognition/zero_shot_intent_model.py index e65f3d7749eb..07e0826c712c 100644 --- a/nemo/collections/nlp/models/zero_shot_intent_recognition/zero_shot_intent_model.py +++ b/nemo/collections/nlp/models/zero_shot_intent_recognition/zero_shot_intent_model.py @@ -18,8 +18,8 @@ import numpy as np import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig -from pytorch_lightning import Trainer from nemo.collections.nlp.data.zero_shot_intent_recognition.zero_shot_intent_dataset import ( ZeroShotIntentDataset, @@ -155,7 +155,6 @@ def predict( entailment_idx=1, contradiction_idx=0, ) -> List[Dict]: - """ Given a list of queries and a list of candidate labels, return a ranked list of labels and scores for each query. diff --git a/nemo/collections/nlp/modules/common/lm_utils.py b/nemo/collections/nlp/modules/common/lm_utils.py index af6fc9ecb0a7..86792059b28f 100644 --- a/nemo/collections/nlp/modules/common/lm_utils.py +++ b/nemo/collections/nlp/modules/common/lm_utils.py @@ -17,8 +17,8 @@ from typing import List, Optional, Union from attr import asdict +from lightning.pytorch import Trainer from omegaconf import DictConfig -from pytorch_lightning import Trainer from nemo.collections.nlp.modules.common.bert_module import BertModule from nemo.collections.nlp.modules.common.decoder_module import DecoderModule diff --git a/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py b/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py index da9c98fd94ea..e306a0a9b6b7 100644 --- a/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py +++ b/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py @@ -82,19 +82,20 @@ def forward( rotary_pos_emb: Tensor = None, rotary_pos_cos: Tensor = None, rotary_pos_sin: Tensor = None, + attention_bias: Tensor = None, inference_params: InferenceParams = None, packed_seq_params: PackedSeqParams = None, ): hidden_states = super().forward( - hidden_states, - attention_mask, - context, - context_mask, - rotary_pos_emb, - rotary_pos_cos, - rotary_pos_sin, - inference_params, - packed_seq_params, + hidden_states=hidden_states, + attention_mask=attention_mask, + context=context, + context_mask=context_mask, + rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=rotary_pos_cos, + rotary_pos_sin=rotary_pos_sin, + inference_params=inference_params, + packed_seq_params=packed_seq_params, ) mlp_head_adapter = self.get_adapter_module(AdapterName.MLP_HEAD_ADAPTER) @@ -232,6 +233,7 @@ def forward( packed_seq_params=None, rotary_pos_cos=None, rotary_pos_sin=None, + attention_bias=None, ): # hidden_states: [sq, b, h] diff --git a/nemo/collections/nlp/modules/common/megatron/attention.py b/nemo/collections/nlp/modules/common/megatron/attention.py index c1b4e3023e42..d5784081f6f0 100644 --- a/nemo/collections/nlp/modules/common/megatron/attention.py +++ b/nemo/collections/nlp/modules/common/megatron/attention.py @@ -380,6 +380,7 @@ def forward( rotary_pos_emb=None, # rotary positional embedding relative_position_bias=None, checkpoint_core_attention=False, + return_scores=False, ): # hidden_states: [sq, b, h] @@ -398,7 +399,9 @@ def forward( # Some consistency check. if inference_max_sequence_len: - assert self.inference_current_sequence_len < self.inference_key_memory.size(0) + # Added equals to as inference key_memory size refers to cross-attention key size + # which is already equal to the current "sequence length" + assert self.inference_current_sequence_len <= self.inference_key_memory.size(0) assert inference_max_sequence_len == self.inference_key_memory.size(0) # This is added for safety. In case inference_max_sequence_len # is not provided, make sure there is no potential memory left @@ -433,28 +436,40 @@ def forward( (query_layer, key_layer, value_layer) = tensor_parallel.split_tensor_along_last_dim( mixed_x_layer, 3, contiguous_split_chunks=True ) - else: + else: # Else in cross_attention # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)] - mixed_kv_layer, _ = self.key_value(encoder_output) - if self.is_adapter_available(): - lora_kv_adapter = self.get_adapter_module(AdapterName.LORA_KV_ADAPTER) - if lora_kv_adapter and self.adapter_cfg[AdapterName.LORA_KV_ADAPTER]['enabled']: - lora_mixed_kv_layer = lora_kv_adapter(encoder_output) - mixed_kv_layer = mixed_kv_layer + lora_mixed_kv_layer - - # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn] - new_tensor_shape = mixed_kv_layer.size()[:-1] + ( - self.num_attention_heads_per_partition, - 2 * self.hidden_size_per_attention_head, - ) - if self.megatron_legacy: - mixed_kv_layer = self._transpose_last_dim(mixed_kv_layer, 2, True) - mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape) + if ( + inference_max_sequence_len is None + ) or self.inference_current_sequence_len < inference_max_sequence_len: + # If we are in traning and inference_max_sequence_len is None + # Or we haven't cached the key and value part of cross attention in the decoder on step 0, + # Do the caching + mixed_kv_layer, _ = self.key_value(encoder_output) + if self.is_adapter_available(): + lora_kv_adapter = self.get_adapter_module(AdapterName.LORA_KV_ADAPTER) + if lora_kv_adapter and self.adapter_cfg[AdapterName.LORA_KV_ADAPTER]['enabled']: + lora_mixed_kv_layer = lora_kv_adapter(encoder_output) + mixed_kv_layer = mixed_kv_layer + lora_mixed_kv_layer + + # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn] + new_tensor_shape = mixed_kv_layer.size()[:-1] + ( + self.num_attention_heads_per_partition, + 2 * self.hidden_size_per_attention_head, + ) + if self.megatron_legacy: + mixed_kv_layer = self._transpose_last_dim(mixed_kv_layer, 2, True) + mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape) - # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn] - (key_layer, value_layer) = tensor_parallel.split_tensor_along_last_dim( - mixed_kv_layer, 2, contiguous_split_chunks=True - ) + # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn] + (key_layer, value_layer) = tensor_parallel.split_tensor_along_last_dim( + mixed_kv_layer, 2, contiguous_split_chunks=True + ) + else: + # else if we are in inference and have already cached key, value, can just read cache + key_layer = self.inference_key_memory[: self.inference_current_sequence_len, ...] + value_layer = self.inference_value_memory[: self.inference_current_sequence_len, ...] + if attention_mask is not None: + attention_mask = attention_mask[..., -1, :].unsqueeze(-2) # Attention head [sq, b, h] --> [sq, b, hp] query_layer, _ = self.query(hidden_states) @@ -490,7 +505,9 @@ def forward( if rotary_pos_emb is not None: rotary_pos_emb = rotary_pos_emb if isinstance(rotary_pos_emb, tuple) else ((rotary_pos_emb,) * 2) - if inference_max_sequence_len: + # If we are in cross attention (inference_current_sequence_len == inference_max_sequence_len == inference_key_memory.size(0)) + # We only need to cache this once + if inference_max_sequence_len and self.inference_current_sequence_len < inference_max_sequence_len: # Adjust the range variables. start = self.inference_current_sequence_len self.inference_current_sequence_len += key_layer.size(0) @@ -501,7 +518,7 @@ def forward( key_layer = self.inference_key_memory[:end, ...] value_layer = self.inference_value_memory[:end, ...] # Adjust attention mask - if attention_mask is not None: + if attention_mask is not None and self.attention_type == AttnType.self_attn: attention_mask = attention_mask[..., start:end, :end] # adjust the key rotary positional embedding if rotary_pos_emb is not None: @@ -569,7 +586,10 @@ def forward( relative_position_bias=relative_position_bias, headscale_tensor=self.head_scale_tensor if self.headscale else None, inference_mode=inference_max_sequence_len is not None and query_layer.shape[0] == 1, + return_scores=return_scores, ) + if return_scores: + context_layer, attention_probs = context_layer # ================= # Output. [sq, b, h] @@ -585,6 +605,9 @@ def forward( if get_key_value: output = [output, present] + if return_scores: + output = [output, attention_probs] + return output, bias @@ -857,6 +880,7 @@ def forward( relative_position_bias=None, headscale_tensor=None, inference_mode=None, + return_scores=None, ): b, np, sq, sk, hn = ( query_layer.size(1), @@ -914,9 +938,27 @@ def forward( # relative_position_bias [b, np, sq, sk] # context_layer [b, np, sq, hn] # ================================================== - context_layer = self.attn_fn( - query_layer, key_layer, value_layer, attention_mask, relative_position_bias, inference_mode - ) + if not return_scores: + context_layer = self.attn_fn( + query_layer, + key_layer, + value_layer, + attention_mask, + relative_position_bias, + inference_mode, + ) + else: + # SpeechLLM TTS modifications + context_layer = self.torch_attention_with_prior( + query_layer, + key_layer, + value_layer, + attention_mask, + relative_position_bias, + inference_mode, + return_scores=return_scores, + ) + context_layer, attention_probs = context_layer if headscale_tensor is not None: context_layer = context_layer * headscale_tensor @@ -928,7 +970,10 @@ def forward( new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) context_layer = context_layer.view(*new_context_layer_shape) - return context_layer + if return_scores: + return context_layer, attention_probs + else: + return context_layer def torch_attention(self, query_layer, key_layer, value_layer, attention_mask, attention_bias, inference_mode): sq, b, np, hn = query_layer.shape @@ -986,6 +1031,69 @@ def torch_attention(self, query_layer, key_layer, value_layer, attention_mask, a return context_layer + def torch_attention_with_prior( + self, query_layer, key_layer, value_layer, attention_mask, attention_bias, inference_mode, return_scores=False + ): + sq, b, np, hn = query_layer.shape + sk = key_layer.shape[0] + + if self.multi_query_attention: + query_layer = rearrange(query_layer, 'sq b np hn -> b (np sq) hn') + key_layer = rearrange(key_layer, 'sk b 1 hn -> b hn sk') + value_layer = rearrange(value_layer, 'sv b np hn -> (b np) sv hn') + else: + query_layer = rearrange(query_layer, 'sq b np hn -> (b np) sq hn') + key_layer = rearrange(key_layer, 'sk b np hn -> (b np) hn sk') + value_layer = rearrange(value_layer, 'sv b np hn -> (b np) sv hn') + + matmul_input_buffer = torch.empty( + query_layer.shape[0], + query_layer.shape[1], + key_layer.shape[2], + dtype=query_layer.dtype, + device=query_layer.device, + ) + + matmul_result = torch.baddbmm( + matmul_input_buffer, + query_layer, + key_layer, + beta=0.0, + alpha=(1.0 / self.norm_factor) if self.normalize_attention_scores else 1.0, + ) + + # change view to [b, np, sq, sk] + attention_scores = matmul_result.view(b, np, sq, sk) + + if attention_bias is not None: + # attention_bias is not None only for cross attention layers right now in T5 + attention_scores = torch.log_softmax(attention_scores, dim=-1) + attention_bias + + _attention_probs = self.scale_mask_softmax(attention_scores, attention_mask) + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + + if not self.sequence_parallel: + with tensor_parallel.random.get_cuda_rng_tracker().fork(): + attention_probs = self.attention_dropout(_attention_probs) + else: + attention_probs = self.attention_dropout(_attention_probs) + + # change view [b * np, sq, sk] + attention_probs = rearrange(attention_probs, 'b np sq sk -> (b np) sq sk') + + # matmul: [b * np, sq, hn] + context_layer = torch.bmm(attention_probs, value_layer) + + # change view [b, np, sq, hn] + context_layer = rearrange(context_layer, '(b np) sq hn -> b np sq hn', np=np) + + if return_scores: + # return context_layer, _attention_probs + return context_layer, attention_scores + else: + return context_layer + def flash_attention(self, query_layer, key_layer, value_layer, attention_mask, attention_bias, inference_mode): query_layer = rearrange(query_layer, 'sq b np hn -> b sq np hn') key_layer = rearrange(key_layer, 'sk b np hn -> b sk np hn') diff --git a/nemo/collections/nlp/modules/common/megatron/megatron_decoders.py b/nemo/collections/nlp/modules/common/megatron/megatron_decoders.py index 712ce10b81b5..d2945a061584 100644 --- a/nemo/collections/nlp/modules/common/megatron/megatron_decoders.py +++ b/nemo/collections/nlp/modules/common/megatron/megatron_decoders.py @@ -13,7 +13,7 @@ # limitations under the License. """Transformer based language model.""" -from ast import Mod +from nemo.collections.nlp.modules.common.megatron.layer_type import LayerType from nemo.collections.nlp.modules.common.megatron.megatron_transformer_decoder import MegatronTransformerDecoderModule from nemo.collections.nlp.modules.common.megatron.retrieval_transformer import ( MegatronRetrievalTransformerDecoderModule, @@ -87,7 +87,7 @@ def get_decoder_model( transformer_block_type="pre_ln", hidden_steps=-1, parent_model_type=ModelType.encoder_or_decoder, - layer_type=None, + layer_type=LayerType.decoder, chunk_size=64, layer_number_offset=0, # this is use only for attention norm_factor scaling megatron_legacy=False, @@ -158,6 +158,7 @@ def get_decoder_model( moe_dropout=moe_dropout, position_embedding_type=position_embedding_type, use_flash_attention=use_flash_attention, + layer_type=layer_type, ) elif arch == "retro": decoder = MegatronRetrievalTransformerDecoderModule( diff --git a/nemo/collections/nlp/modules/common/megatron/megatron_encoder_decoder.py b/nemo/collections/nlp/modules/common/megatron/megatron_encoder_decoder.py index c4192dacb45a..744a6e18c8b1 100644 --- a/nemo/collections/nlp/modules/common/megatron/megatron_encoder_decoder.py +++ b/nemo/collections/nlp/modules/common/megatron/megatron_encoder_decoder.py @@ -13,7 +13,6 @@ # limitations under the License. """Transformer based language model.""" -from ast import Mod import torch @@ -46,8 +45,7 @@ class MegatronTransformerEncoderDecoderModule(MegatronModule): - """Transformer encoder-decoder model. - """ + """Transformer encoder-decoder model.""" def __init__( self, @@ -85,6 +83,8 @@ def __init__( encoder_attn_mask_type = AttnMaskType.padding elif hasattr(encoder.model, 'self_attn_mask_type'): encoder_attn_mask_type = encoder.model.self_attn_mask_type + elif isinstance(encoder.model, torch.nn.ModuleList) and hasattr(encoder.model[0], 'self_attn_mask_type'): + encoder_attn_mask_type = encoder.model[0].self_attn_mask_type else: raise AttributeError( "Could not find an attribute for encoder self_attn_mask_type, make sure it is set when instatiating the encoder or pass it to the constructor of this class." @@ -142,7 +142,11 @@ def encode( # apply hidden transformations if needed if self.hiddens_module is not None: enc_output = self.hiddens_module.apply_hidden_transforms( - {"hiddens": enc_output, "hiddens_mask": self.get_hiddens_mask(enc_attn_mask),}, batch_data=batch_data, + { + "hiddens": enc_output, + "hiddens_mask": self.get_hiddens_mask(enc_attn_mask), + }, + batch_data=batch_data, ) return enc_output @@ -157,6 +161,11 @@ def decode( dec_get_key_value=False, dec_self_attention_relative_position_bias=None, dec_cross_attention_relative_position_bias=None, + return_all_crossattention_probs=False, + set_inference_key_value_memory=False, + decoder_max_sequence_len=None, + encoder_max_sequence_len=None, + enc_output_to_layers=None, ): if self.decoder is None: raise ValueError(f"Cannot call .decode(...) when self.decoder is None.") @@ -170,6 +179,11 @@ def decode( enc_attn_mask=enc_attn_mask, dec_self_attention_relative_position_bias=dec_self_attention_relative_position_bias, dec_cross_attention_relative_position_bias=dec_cross_attention_relative_position_bias, + return_all_crossattention_probs=return_all_crossattention_probs, + set_inference_key_value_memory=set_inference_key_value_memory, + decoder_max_sequence_len=decoder_max_sequence_len, + encoder_max_sequence_len=encoder_max_sequence_len, + enc_output_to_layers=enc_output_to_layers, ) return dec_output @@ -191,6 +205,11 @@ def forward( dec_self_attention_relative_position_bias=None, dec_cross_attention_relative_position_bias=None, batch_data=None, + return_all_crossattention_probs=False, + set_inference_key_value_memory=False, + decoder_max_sequence_len=None, + encoder_max_sequence_len=None, + enc_output_to_layers=None, ): # encoder if enc_output is None: @@ -207,7 +226,10 @@ def forward( assert self.encoder_hidden_state is not None enc_output = self.encoder_hidden_state else: - enc_attn_mask = enc_output_attn_mask.to(enc_attn_mask) + if isinstance(enc_output_attn_mask, list): + enc_attn_mask = [mask.to(enc_attn_mask[midx]) for midx, mask in enumerate(enc_output_attn_mask)] + else: + enc_attn_mask = enc_output_attn_mask.to(enc_attn_mask) if self.decoder is None or output_enc_hidden_only: return enc_output @@ -216,15 +238,22 @@ def forward( dec_output = self.decode( dec_input=dec_input, dec_attn_mask=dec_attn_mask, - enc_output=enc_output["enc_output"] # enc_output is a dict if we used hidden transformations - if self.hiddens_module is not None - else enc_output, + enc_output=( + enc_output["enc_output"] # enc_output is a dict if we used hidden transformations + if self.hiddens_module is not None + else enc_output + ), # Adjust encoder attention mask if encoder is a perceiver. enc_attn_mask=self.get_hiddens_mask(enc_attn_mask), dec_layer_past=dec_layer_past, dec_get_key_value=dec_get_key_value, dec_self_attention_relative_position_bias=dec_self_attention_relative_position_bias, dec_cross_attention_relative_position_bias=dec_cross_attention_relative_position_bias, + return_all_crossattention_probs=return_all_crossattention_probs, + set_inference_key_value_memory=set_inference_key_value_memory, + decoder_max_sequence_len=decoder_max_sequence_len, + encoder_max_sequence_len=encoder_max_sequence_len, + enc_output_to_layers=enc_output_to_layers, ) # if self.hiddens_module is not None enc_output is a dict, else it is a torch.tensor @@ -246,7 +275,10 @@ def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars= def load_state_dict(self, state_dict, strict=True): """Customized load.""" - self.encoder.load_state_dict(state_dict[self._encoder_key], strict=strict) - self.decoder.load_state_dict(state_dict[self._decoder_key], strict=strict) - if self.hiddens_module is not None: - self.hiddens_module.load_state_dict(state_dict[self._hiddens_module], strict=strict) + try: + self.encoder.load_state_dict(state_dict[self._encoder_key], strict=strict) + self.decoder.load_state_dict(state_dict[self._decoder_key], strict=strict) + if self.hiddens_module is not None: + self.hiddens_module.load_state_dict(state_dict[self._hiddens_module], strict=strict) + except KeyError as e: + super().load_state_dict(state_dict, strict=strict) diff --git a/nemo/collections/nlp/modules/common/megatron/megatron_encoders.py b/nemo/collections/nlp/modules/common/megatron/megatron_encoders.py index 601eb320e8fc..3d2b2c1ecc13 100644 --- a/nemo/collections/nlp/modules/common/megatron/megatron_encoders.py +++ b/nemo/collections/nlp/modules/common/megatron/megatron_encoders.py @@ -14,7 +14,10 @@ """Transformer based language model.""" from nemo.collections.nlp.modules.common.megatron.megatron_perceiver_encoders import MegatronPerceiverEncoderModule -from nemo.collections.nlp.modules.common.megatron.megatron_transformer_encoder import MegatronTransformerEncoderModule +from nemo.collections.nlp.modules.common.megatron.megatron_transformer_encoder import ( + MegatronTransformerEncoderModule, + MultiMegatronTransformerEncoderModule, +) from nemo.collections.nlp.modules.common.megatron.retrieval_transformer import ( MegatronRetrievalTransformerEncoderModule, ) @@ -108,6 +111,7 @@ def get_encoder_model( version=1, # model version position_embedding_type='learned_absolute', use_flash_attention=False, + n_transformers=1, ): """Build language model and return along with the key to save.""" @@ -167,6 +171,51 @@ def get_encoder_model( position_embedding_type=position_embedding_type, use_flash_attention=use_flash_attention, ) + elif arch == "multi_transformer": + encoder = MultiMegatronTransformerEncoderModule( + config=config, + n_transformers=n_transformers, + init_method=init_method, + output_layer_init_method=scaled_init_method, + hidden_size=hidden_size, + num_layers=num_layers, + num_attention_heads=num_attention_heads, + apply_query_key_layer_scaling=apply_query_key_layer_scaling, + kv_channels=kv_channels, + ffn_hidden_size=ffn_hidden_size, + encoder_attn_mask_type=encoder_attn_mask_type, + pre_process=pre_process, + post_process=post_process, + hidden_dropout=hidden_dropout, + attention_dropout=attention_dropout, + ffn_dropout=ffn_dropout, + precision=precision, + fp32_residual_connection=fp32_residual_connection, + activations_checkpoint_method=activations_checkpoint_method, + activations_checkpoint_num_layers=activations_checkpoint_num_layers, + activations_checkpoint_granularity=activations_checkpoint_granularity, + layernorm_epsilon=layernorm_epsilon, + bias_activation_fusion=bias_activation_fusion, + bias_dropout_add_fusion=bias_dropout_add_fusion, + masked_softmax_fusion=masked_softmax_fusion, + persist_layer_norm=persist_layer_norm, + openai_gelu=openai_gelu, + onnx_safe=onnx_safe, + activation=activation, + bias=bias, + normalization=normalization, + transformer_block_type=transformer_block_type, + headscale=headscale, + parent_model_type=parent_model_type, + megatron_legacy=megatron_legacy, + normalize_attention_scores=normalize_attention_scores, + num_moe_experts=num_moe_experts, + moe_frequency=moe_frequency, + moe_dropout=moe_dropout, + position_embedding_type=position_embedding_type, + use_flash_attention=use_flash_attention, + ) + elif arch == "retro": encoder = MegatronRetrievalTransformerEncoderModule( config=config, diff --git a/nemo/collections/nlp/modules/common/megatron/megatron_transformer_decoder.py b/nemo/collections/nlp/modules/common/megatron/megatron_transformer_decoder.py index 4a05a08820e7..14677552492b 100644 --- a/nemo/collections/nlp/modules/common/megatron/megatron_transformer_decoder.py +++ b/nemo/collections/nlp/modules/common/megatron/megatron_transformer_decoder.py @@ -52,8 +52,7 @@ class MegatronTransformerDecoderModule(MegatronModule, Exportable, MegatronDecoderModule): - """Transformer decoder model. - """ + """Transformer decoder model.""" def __init__( self, @@ -97,6 +96,7 @@ def __init__( moe_dropout=0.0, position_embedding_type='learned_absolute', use_flash_attention=False, + layer_type=LayerType.decoder, ): super(MegatronTransformerDecoderModule, self).__init__(config=config) @@ -121,7 +121,7 @@ def __init__( # Transformer. self.model = ParallelTransformer( config=config, - layer_type=LayerType.decoder, + layer_type=layer_type, init_method=self.init_method, output_layer_init_method=self.output_layer_init_method, num_layers=self.num_layers, @@ -165,7 +165,7 @@ def __init__( self._model_key = 'model' def set_input_tensor(self, input_tensor): - """ See megatron.model.transformer.set_input_tensor()""" + """See megatron.model.transformer.set_input_tensor()""" self.model.set_input_tensor(input_tensor) def forward( @@ -178,15 +178,41 @@ def forward( get_key_value=False, dec_self_attention_relative_position_bias=None, dec_cross_attention_relative_position_bias=None, + return_all_crossattention_probs=False, + set_inference_key_value_memory=False, + decoder_max_sequence_len=None, + encoder_max_sequence_len=None, + enc_output_to_layers=None, ): # convert to Megatron mask dec_attn_mask_3d = build_attention_mask_3d( - source_mask=dec_attn_mask, target_mask=dec_attn_mask, attn_mask_type=self.model_attn_mask_type, - ) - enc_dec_attn_mask_3d = build_attention_mask_3d( - source_mask=dec_attn_mask, target_mask=enc_attn_mask, attn_mask_type=AttnMaskType.padding, + source_mask=dec_attn_mask, + target_mask=dec_attn_mask, + attn_mask_type=self.model_attn_mask_type, ) + if isinstance(enc_output, list): + assert len(enc_output) == len(enc_attn_mask) + enc_dec_attn_mask_3d = [] + for i in range(len(enc_output)): + enc_dec_attn_mask_3d.append( + attn_mask_postprocess( + build_attention_mask_3d( + source_mask=dec_attn_mask, + target_mask=enc_attn_mask[i], + attn_mask_type=AttnMaskType.padding, + ) + ) + ) + else: + enc_dec_attn_mask_3d = attn_mask_postprocess( + build_attention_mask_3d( + source_mask=dec_attn_mask, + target_mask=enc_attn_mask, + attn_mask_type=AttnMaskType.padding, + ) + ) + # transformer decoder dec_output = self.model( dec_input, @@ -194,9 +220,14 @@ def forward( layer_past=layer_past, get_key_value=get_key_value, encoder_output=enc_output, - enc_dec_attn_mask=attn_mask_postprocess(enc_dec_attn_mask_3d), + enc_dec_attn_mask=enc_dec_attn_mask_3d, self_attention_relative_position_bias=dec_self_attention_relative_position_bias, cross_attention_relative_position_bias=dec_cross_attention_relative_position_bias, + return_all_crossattention_probs=return_all_crossattention_probs, + set_inference_key_value_memory=set_inference_key_value_memory, + decoder_max_sequence_len=decoder_max_sequence_len, + encoder_max_sequence_len=encoder_max_sequence_len, + enc_output_to_layers=enc_output_to_layers, ) return dec_output diff --git a/nemo/collections/nlp/modules/common/megatron/megatron_transformer_encoder.py b/nemo/collections/nlp/modules/common/megatron/megatron_transformer_encoder.py index 7a41e1300066..a9b80868558f 100644 --- a/nemo/collections/nlp/modules/common/megatron/megatron_transformer_encoder.py +++ b/nemo/collections/nlp/modules/common/megatron/megatron_transformer_encoder.py @@ -13,6 +13,8 @@ # limitations under the License. """Transformer based language model.""" +import torch + from nemo.collections.nlp.modules.common.megatron.layer_type import LayerType from nemo.collections.nlp.modules.common.megatron.megatron_encoder_module import MegatronEncoderModule from nemo.collections.nlp.modules.common.megatron.module import MegatronModule @@ -163,7 +165,7 @@ def __init__( self._model_key = 'model' def set_input_tensor(self, input_tensor): - """ See megatron.model.transformer.set_input_tensor()""" + """See megatron.model.transformer.set_input_tensor()""" self.model.set_input_tensor(input_tensor) def forward( @@ -173,6 +175,7 @@ def forward( layer_past=None, get_key_value=False, enc_self_attention_relative_position_bias=None, + set_inference_key_value_memory=False, ): # convert to Megatron mask if self.use_flash_attention: @@ -180,7 +183,9 @@ def forward( else: enc_attn_mask_3d = attn_mask_postprocess( build_attention_mask_3d( - source_mask=enc_attn_mask, target_mask=enc_attn_mask, attn_mask_type=self.model_attn_mask_type, + source_mask=enc_attn_mask, + target_mask=enc_attn_mask, + attn_mask_type=self.model_attn_mask_type, ) ) @@ -192,6 +197,7 @@ def forward( get_key_value=get_key_value, self_attention_relative_position_bias=enc_self_attention_relative_position_bias, cross_attention_relative_position_bias=None, + set_inference_key_value_memory=set_inference_key_value_memory, ) return enc_output @@ -231,3 +237,214 @@ def load_state_dict(self, state_dict, strict=True): state_dict_ = state_dict_self_attention self.model.load_state_dict(state_dict_, strict=strict) + + +class MultiMegatronTransformerEncoderModule(MegatronModule, Exportable, MegatronEncoderModule): + """Transformer encoder model.""" + + def __init__( + self, + config: ModelParallelConfig, + n_transformers, + init_method, + output_layer_init_method, + hidden_size, + ffn_hidden_size, + num_layers, + num_attention_heads, + apply_query_key_layer_scaling=True, + kv_channels=None, + pre_process=True, + post_process=True, + encoder_attn_mask_type=AttnMaskType.padding, + hidden_dropout=0.1, + attention_dropout=0.1, + ffn_dropout=0.0, + precision=16, + fp32_residual_connection=False, + activations_checkpoint_method=None, + activations_checkpoint_num_layers=1, + activations_checkpoint_granularity=None, + layernorm_epsilon=1e-5, + bias_activation_fusion=True, + bias_dropout_add_fusion=True, + masked_softmax_fusion=True, + persist_layer_norm=False, + openai_gelu=False, + onnx_safe=False, + activation='gelu', + bias=True, + normalization='layernorm', + transformer_block_type='pre_ln', + headscale=False, + parent_model_type=ModelType.encoder_or_decoder, + megatron_legacy=False, + normalize_attention_scores=True, + num_moe_experts=1, + moe_frequency=1, + moe_dropout=0.0, + position_embedding_type='learned_absolute', + use_flash_attention=False, + ): + super(MultiMegatronTransformerEncoderModule, self).__init__(config=config) + + self.pre_process = pre_process + self.post_process = post_process + self.hidden_size = hidden_size + self.num_layers = num_layers + self.init_method = init_method + self.model_attn_mask_type = encoder_attn_mask_type + self.hidden_dropout = hidden_dropout + self.output_layer_init_method = output_layer_init_method + self.parent_model_type = parent_model_type + self.normalization = normalization + self.transformer_block_type = transformer_block_type + self.use_flash_attention = use_flash_attention + + if kv_channels is None: + + assert ( + hidden_size % num_attention_heads == 0 + ), 'hidden_size must be divisible by num_attention_heads if kv_channels is None' + kv_channels = hidden_size // num_attention_heads + + # Transformer List + self.model = [] + for i in range(n_transformers): + transformer = ParallelTransformer( + config=config, + layer_type=LayerType.encoder, + init_method=self.init_method, + output_layer_init_method=self.output_layer_init_method, + num_layers=self.num_layers, + hidden_size=self.hidden_size, + num_attention_heads=num_attention_heads, + apply_query_key_layer_scaling=apply_query_key_layer_scaling, + kv_channels=kv_channels, + ffn_hidden_size=ffn_hidden_size, + self_attn_mask_type=self.model_attn_mask_type, + pre_process=self.pre_process, + post_process=self.post_process, + precision=precision, + fp32_residual_connection=fp32_residual_connection, + activations_checkpoint_method=activations_checkpoint_method, + activations_checkpoint_num_layers=activations_checkpoint_num_layers, + activations_checkpoint_granularity=activations_checkpoint_granularity, + layernorm_epsilon=layernorm_epsilon, + hidden_dropout=hidden_dropout, + attention_dropout=attention_dropout, + ffn_dropout=ffn_dropout, + bias_activation_fusion=bias_activation_fusion, + bias_dropout_add_fusion=bias_dropout_add_fusion, + masked_softmax_fusion=masked_softmax_fusion, + persist_layer_norm=persist_layer_norm, + openai_gelu=openai_gelu, + onnx_safe=onnx_safe, + activation=activation, + bias=bias, + normalization=normalization, + transformer_block_type=transformer_block_type, + headscale=headscale, + model_type=parent_model_type, + megatron_legacy=megatron_legacy, + normalize_attention_scores=normalize_attention_scores, + num_moe_experts=num_moe_experts, + moe_frequency=moe_frequency, + moe_dropout=moe_dropout, + position_embedding_type=position_embedding_type, + use_flash_attention=use_flash_attention, + ) + self.model.append(transformer) + + self.model = torch.nn.ModuleList(self.model) + + self._model_key = 'model' + + def set_input_tensor(self, input_tensor): + """See megatron.model.transformer.set_input_tensor()""" + for mi in range(len(self.model)): + self.model[mi].set_input_tensor(input_tensor) + + def forward( + self, + enc_input, + enc_attn_mask, + layer_past=None, + get_key_value=False, + enc_self_attention_relative_position_bias=None, + set_inference_key_value_memory=False, + ): + + assert isinstance(enc_input, list) + assert len(enc_input) == len(self.model) + assert isinstance(enc_attn_mask, list) + assert len(enc_attn_mask) == len(self.model) + assert isinstance(enc_self_attention_relative_position_bias, list) + # convert to Megatron mask + enc_outputs = [] + for encoder_number in range(len(self.model)): + enc_input_ = enc_input[encoder_number] + enc_attn_mask_ = enc_attn_mask[encoder_number] + enc_self_attention_relative_position_bias_ = enc_self_attention_relative_position_bias[encoder_number] + + if self.use_flash_attention: + enc_attn_mask_3d = enc_attn_mask_ < 0.5 + else: + enc_attn_mask_3d = attn_mask_postprocess( + build_attention_mask_3d( + source_mask=enc_attn_mask_, + target_mask=enc_attn_mask_, + attn_mask_type=self.model_attn_mask_type, + ) + ) + + # transformer encoder + enc_output = self.model[encoder_number]( + enc_input_, + enc_attn_mask_3d, + layer_past=layer_past, + get_key_value=get_key_value, + self_attention_relative_position_bias=enc_self_attention_relative_position_bias_, + cross_attention_relative_position_bias=None, + set_inference_key_value_memory=set_inference_key_value_memory, + ) + + enc_outputs.append(enc_output) + + return enc_outputs + + def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars=False): + """For easy load.""" + + state_dict_ = {} + + state_dict_[self._model_key] = self.model.state_dict_for_save_checkpoint(destination, prefix, keep_vars) + + return state_dict_ + + def load_state_dict(self, state_dict, strict=True): + """Customized load.""" + + # Encoder. + if self._model_key in state_dict: + state_dict_ = state_dict[self._model_key] + # for backward compatibility. + elif 'transformer' in state_dict: + state_dict_ = state_dict['transformer'] + else: + # for backward compatibility. + state_dict_ = {} + for key in state_dict.keys(): + if 'transformer.' in key: + state_dict_[key.split('transformer.')[1]] = state_dict[key] + + # for backward compatibility. + state_dict_self_attention = {} + for key in state_dict_.keys(): + if '.attention.' in key: + state_dict_self_attention[key.replace(".attention.", ".self_attention.")] = state_dict_[key] + else: + state_dict_self_attention[key] = state_dict_[key] + state_dict_ = state_dict_self_attention + + self.model.load_state_dict(state_dict_, strict=strict) diff --git a/nemo/collections/nlp/modules/common/megatron/module.py b/nemo/collections/nlp/modules/common/megatron/module.py index ccd485427c3c..a4efb2992166 100644 --- a/nemo/collections/nlp/modules/common/megatron/module.py +++ b/nemo/collections/nlp/modules/common/megatron/module.py @@ -113,7 +113,7 @@ def decoder_cross_attention_relative_position_embeddings_weight(self): def initialize_word_embeddings(self, init_method, vocab_size, hidden_size): if not self.share_token_embeddings: - raise Exception('initialize_word_embeddings() was called but ' 'share_token_embeddings is false') + raise Exception('initialize_word_embeddings() was called but share_token_embeddings is false') # This function just initializes the word embeddings in the final stage # when we are using pipeline parallelism. If we aren't using pipeline @@ -140,7 +140,10 @@ def initialize_word_embeddings(self, init_method, vocab_size, hidden_size): # set word_embeddings weights to 0 here, then copy first # stage's weights using all_reduce below. self.word_embeddings = tensor_parallel.VocabParallelEmbedding( - vocab_size, hidden_size, init_method=init_method, config=self.config, + vocab_size, + hidden_size, + init_method=init_method, + config=self.config, ) self.word_embeddings.weight.data.fill_(0) self.word_embeddings.weight.shared = True diff --git a/nemo/collections/nlp/modules/common/megatron/token_level_encoder_decoder.py b/nemo/collections/nlp/modules/common/megatron/token_level_encoder_decoder.py index b7b377940eb4..e68113949aa7 100644 --- a/nemo/collections/nlp/modules/common/megatron/token_level_encoder_decoder.py +++ b/nemo/collections/nlp/modules/common/megatron/token_level_encoder_decoder.py @@ -42,6 +42,7 @@ ) from nemo.collections.nlp.modules.common.megatron.vocab_parallel_cross_entropy import vocab_parallel_cross_entropy from nemo.core.classes.mixins import adapter_mixins +from nemo.utils import logging try: from apex.transformer.enums import AttnMaskType, ModelType @@ -67,7 +68,11 @@ HAVE_MEGATRON_CORE = False -__all__ = ["MegatronTokenLevelHead", "MegatronTokenLevelEncoderDecoderModule"] +__all__ = [ + "MegatronTokenLevelHead", + "MegatronTokenLevelEncoderDecoderModule", + "MegatronTokenLevelEncoderDecoderSpeechLLMModule", +] class MegatronTokenLevelHead(MegatronModule): @@ -252,6 +257,7 @@ def __init__( moe_dropout=encoder_cfg.get('moe_dropout', 0.0), position_embedding_type=encoder_cfg.get('position_embedding_type', 'learned_absolute'), use_flash_attention=encoder_cfg.get('use_flash_attention', False), + n_transformers=encoder_cfg.get('n_transformers', 1), ) if add_decoder: @@ -388,6 +394,7 @@ def __init__( moe_dropout=decoder_cfg.get('moe_dropout', 0.0), position_embedding_type=decoder_cfg.get('position_embedding_type', 'learned_absolute'), use_flash_attention=decoder_cfg.get('use_flash_attention', False), + layer_type=decoder_cfg.get('layer_type', LayerType.decoder), ) hiddens_module = get_hiddens_module(hiddens_cfg, model_parallel_cfg=config) @@ -410,6 +417,7 @@ def __init__( if add_decoder and post_process: if share_decoder_tokens_head_embeddings: + # parallel_output is True if TP > 1 (3b model) self.tokens_head = MegatronTokenLevelHead( self.word_embeddings_weight().size(0), parallel_output, bias=tokens_head_bias ) @@ -469,7 +477,7 @@ def _validate_config(self): return encoder_kv_channels, decoder_kv_channels def set_input_tensor(self, input_tensor): - """ See megatron.model.transformer.set_input_tensor()""" + """See megatron.model.transformer.set_input_tensor()""" # This is usually handled in schedules.py but some inference code still # gives us non-lists or None @@ -566,7 +574,8 @@ def forward( if self.add_encoder and self.encoder_relative_position_embedding is not None: encoder_self_attention_relative_position_bias = self.encoder_relative_position_embedding( - query_seq_length=enc_seq_length, key_seq_length=enc_seq_length, + query_seq_length=enc_seq_length, + key_seq_length=enc_seq_length, ) if output_enc_hidden_only: @@ -604,8 +613,11 @@ def forward( query_seq_length=dec_input_ids.size(1), key_seq_length=dec_input_ids.size(1) ) if not self.decoder_cfg.relative_position_bias_self_attention_only: - decoder_cross_attention_relative_position_bias = self.decoder_cross_attention_relative_position_embedding( - query_seq_length=dec_input_ids.size(1), key_seq_length=enc_seq_length, + decoder_cross_attention_relative_position_bias = ( + self.decoder_cross_attention_relative_position_embedding( + query_seq_length=dec_input_ids.size(1), + key_seq_length=enc_seq_length, + ) ) else: decoder_cross_attention_relative_position_bias = None @@ -656,7 +668,8 @@ def forward( # check if hiddens is used if self.hiddens_cfg is not None: loss_dict = self.enc_dec_model.hiddens_module.apply_loss_transforms( - outputs=enc_output, batch_data=batch_data, + outputs=enc_output, + batch_data=batch_data, ) loss_dict["tokens_loss"] = tokens_loss # We need to store default output in a known key, so that we can mimic default behaviour @@ -708,8 +721,437 @@ def state_dict_for_save_checkpoint(self, destination=None, prefix='', keep_vars= def load_state_dict(self, state_dict, strict=True): """Customized load.""" - - self.encoder_embedding.encoder_embeddingload_state_dict(state_dict[self._encoder_embedding_key], strict=strict) + self.encoder_embedding.load_state_dict(state_dict[self._encoder_embedding_key], strict=strict) self.decoder_embedding.load_state_dict(state_dict[self._decoder_embedding_key], strict=strict) self.enc_dec_model.load_state_dict(state_dict[self._enc_dec_model_key], strict=strict) self.tokens_head.load_state_dict(state_dict[self._tokens_head_key], strict=strict) + + +class MegatronTokenLevelEncoderDecoderSpeechLLMModule(MegatronTokenLevelEncoderDecoderModule): + def __init__(self, *args, **kwargs): + super(MegatronTokenLevelEncoderDecoderSpeechLLMModule, self).__init__(*args, **kwargs) + # Overridden in MegatronT5SpeechLMModel constructor + self.seq_pattern = "parallel" + self.speech_head_type = "token_level" + self.attn_prior_scaledown_start_step = 10000 + self.attn_prior_end_step = 11000 + self.use_alignment_loss = False + self.return_all_crossattention_probs = False + self.logging_step = False + self.num_cross_attention_heads = 12 # 12 for 220m T5, 16 for 11b T5 + self.enc_output_to_layers = None + + def get_decoder_embeddings(self, dec_input_ids, dec_position_ids, token_type_ids): + if dec_input_ids.dim() <= 2: + dec_input = self.decoder_embedding(dec_input_ids, dec_position_ids, token_type_ids=token_type_ids) + else: + dec_input = None + for i in range(dec_input_ids.size()[1]): + if i == 0: + # For the first channel (text + first layer of speech), use the decoder embedding layer + dec_input = self.decoder_embedding( + dec_input_ids[:, i, :], dec_position_ids, token_type_ids=token_type_ids + ) + else: + # For the rest of the channels (speech), use the speech embedding layer. No need for position, since already added in first layer. + current = self.speech_tokens_embeddings[i - 1](dec_input_ids[:, i, :]).permute(1, 0, 2) + # @pneekhara - Commenting the below because we always want to include all channels for speech. + # @pneekhara - include_channel_flag can become 0 when doing autoregressive inference and the first timestep is zeros + # For text inputs, only include 1st channel embeddings. Zero-out others. + # include_channel_flag = (torch.sum(dec_input_ids[:, i, :], dim=1) > 0).float() # [B] + # current = current * include_channel_flag.unsqueeze(0).unsqueeze(2) + dec_input = dec_input + current + + return dec_input + + def forward( + self, + enc_input_ids=None, + enc_attn_mask=None, + dec_input_ids=None, + dec_attn_mask=None, + token_type_ids=None, + labels=None, + batch_data=None, # additional data to be passed to hiddens module + enc_output=None, # Result of running the entire encoder + enc_output_attn_mask=None, + enc_input=None, # Result of running encoder embedding only + output_enc_hidden_only=False, + speech_mask=None, + cross_attention_prior=None, + text_limits=None, + global_step=None, + set_inference_key_value_memory=False, + decoder_max_sequence_len=None, + encoder_max_sequence_len=None, + ): + """ + Return value is per token / per dimension (i.e., non collapsed loss value) + """ + ( + encoder_self_attention_relative_position_bias, + decoder_self_attention_relative_position_bias, + decoder_cross_attention_relative_position_bias, + ) = (None, None, None) + + if enc_input is not None and enc_output is not None: + raise ValueError( + """Both enc_input and enc_output are not None. + You should only be passing one of them. + enc_input is the result of the encoder embedding layer + enc_output is the result of running the entire transformer encoder.""" + ) + + # In order of precedence, we use enc_output, enc_input, and then enc_input_ids to determine the encoder sequence length. + if enc_output is not None: + # If enc_output is provided in `batch_for_pipeline`, we need to transpose it from [B x S x H] -> [S x B x H]. + if isinstance(enc_output, list): + encoder_self_attention_relative_position_bias = [None for _ in enc_output] + enc_output = [x.transpose(0, 1) for x in enc_output] + enc_seq_length = [x.size(0) for x in enc_output] + else: + enc_output = enc_output.transpose(0, 1) + enc_seq_length = enc_output.size(0) + elif enc_input is not None: + # If enc_input is provided, we need to transpose it from [B x S x H] -> [S x B x H]. + if isinstance(enc_input, list): + encoder_self_attention_relative_position_bias = [None for _ in enc_input] + enc_input = [x.transpose(0, 1) for x in enc_input] + enc_seq_length = [x.size(0) for x in enc_input] + else: + enc_input = enc_input.transpose(0, 1) + enc_seq_length = enc_input.size(0) + # Only need to run encoder embedding and position ids if enc_input or enc_output is not provided. + elif enc_input_ids is not None: + enc_seq_length = enc_input_ids.size(1) + if self.pre_process and self.add_encoder: + # We don't need position ids for RPE, because the embedding layer does not have position embeddings. + if self.encoder_relative_position_embedding is None: + enc_input_ids_p = enc_input_ids[:, 0, :] if enc_input_ids.dim() == 3 else enc_input_ids + enc_position_ids = build_position_ids(enc_input_ids_p) + else: + enc_position_ids = None + enc_input = self.encoder_embedding(enc_input_ids, enc_position_ids, token_type_ids=token_type_ids) + if self.is_adapter_available(): + _sq, _bs, _hs = enc_input.size() + ptuning_adapter = self.get_adapter_module(AdapterName.PTUNING_ADAPTER) + v = ptuning_adapter.virtual_tokens + if ( + ptuning_adapter and _sq >= v + ): # The sequence should be longer the v to insert virtual embeddings. + virtual_embeddings = ptuning_adapter(_bs) + enc_input = enc_input[ + v:, :, : + ] # the first v tokens are pads so that they can be swapped out with virtual embeddings. + enc_input = torch.concat([virtual_embeddings, enc_input], dim=0) + else: + enc_input = None + else: + # This should only happen with PP > 1 for enc-dec prompt learning models + enc_seq_length = enc_attn_mask.size(1) + + if self.add_encoder and self.encoder_relative_position_embedding is not None: + encoder_self_attention_relative_position_bias = self.encoder_relative_position_embedding( + query_seq_length=enc_seq_length, + key_seq_length=enc_seq_length, + ) + + if output_enc_hidden_only: + # When pipeline parallel > 1 we need to make sure encoder exist (will be missing in decoder) + # SpeechT5 should not go here for inference + if enc_output is None and self.enc_dec_model.encoder is not None: + enc_output = self.enc_dec_model.encode( + enc_input=enc_input, + enc_attn_mask=enc_attn_mask, + enc_layer_past=None, + enc_get_key_value=False, + enc_self_attention_relative_position_bias=encoder_self_attention_relative_position_bias, + batch_data=batch_data, + ) + else: + enc_output = self.enc_dec_model.encoder_hidden_state + + return enc_output + else: + if enc_output_attn_mask is None: + enc_output_attn_mask = enc_attn_mask + + if self.pre_process and self.add_decoder: + # We don't need position ids for RPE, because the embedding layer does not have position embeddings. + if self.decoder_relative_position_embedding is None: + dec_input_ids_p = dec_input_ids[:, 0, :] if dec_input_ids.dim() == 3 else dec_input_ids + dec_position_ids = build_position_ids(dec_input_ids_p) + else: + dec_position_ids = None + dec_input = self.get_decoder_embeddings(dec_input_ids, dec_position_ids, token_type_ids) + if not set_inference_key_value_memory and (decoder_max_sequence_len or encoder_max_sequence_len): + # In inference + # On step 0 when set_inference_key_value_memory is True, we need all inputs in case + # we are using decoder context + # Else on step >= 1, only need last input + logging.debug("Clipping dec_input and only keep the last input.") + dec_input = dec_input[-1, :, :].unsqueeze(0) # shape (b, embed_dim) + else: + # Note: This is when the decoder itself is split across PP ranks. + dec_input = None + + if self.add_decoder and self.decoder_relative_position_embedding is not None: + decoder_self_attention_relative_position_bias = self.decoder_relative_position_embedding( + query_seq_length=dec_input_ids.size(1), key_seq_length=dec_input_ids.size(1) + ) + if not self.decoder_cfg.relative_position_bias_self_attention_only: + decoder_cross_attention_relative_position_bias = ( + self.decoder_cross_attention_relative_position_embedding( + query_seq_length=dec_input_ids.size(1), + key_seq_length=enc_seq_length, + ) + ) + else: + decoder_cross_attention_relative_position_bias = None + + return_all_crossattention_probs = self.return_all_crossattention_probs + single_encoder = False + if not isinstance(cross_attention_prior, list): + single_encoder = True + cross_attention_prior = [cross_attention_prior] + + decoder_cross_attention_relative_position_bias = [] + for _cross_attention_prior in cross_attention_prior: + _decoder_cross_attention_relative_position_bias = None + if _cross_attention_prior is not None: + # cross_attention_prior shape [B, dec_len, enc_len] + # Repeat it to make it [B, 12, dec_len, enc_len] + attn_prior_end_step = self.attn_prior_end_step + attn_prior_scaledown_start_step = self.attn_prior_scaledown_start_step + num_attention_heads = self.num_cross_attention_heads + assert attn_prior_scaledown_start_step <= attn_prior_end_step + logging.debug( + f"attn_prior_scaledown_start_step: {attn_prior_scaledown_start_step}, attn_prior_scaledown_start_step: {attn_prior_end_step}" + ) + if global_step >= attn_prior_end_step: + _decoder_cross_attention_relative_position_bias = None + elif global_step > attn_prior_scaledown_start_step and global_step < attn_prior_end_step: + total_annealing_steps = attn_prior_end_step - attn_prior_scaledown_start_step + curr_annealing_step = global_step - attn_prior_scaledown_start_step + curr_cross_attention_prior = _cross_attention_prior + ( + (1.0 - _cross_attention_prior) * curr_annealing_step / total_annealing_steps + ) + _decoder_cross_attention_relative_position_bias = curr_cross_attention_prior.unsqueeze( + 1 + ).repeat(1, num_attention_heads, 1, 1) + _decoder_cross_attention_relative_position_bias = torch.log( + _decoder_cross_attention_relative_position_bias + 1e-8 + ) + else: + _decoder_cross_attention_relative_position_bias = _cross_attention_prior.unsqueeze(1).repeat( + 1, num_attention_heads, 1, 1 + ) + _decoder_cross_attention_relative_position_bias = torch.log( + _decoder_cross_attention_relative_position_bias + 1e-8 + ) + decoder_cross_attention_relative_position_bias.append(_decoder_cross_attention_relative_position_bias) + + return_all_crossattention_probs = return_all_crossattention_probs or self.logging_step + + if single_encoder: + decoder_cross_attention_relative_position_bias = decoder_cross_attention_relative_position_bias[0] + + output = self.enc_dec_model( + enc_input=enc_input, + enc_attn_mask=enc_attn_mask, + dec_input=dec_input, + dec_attn_mask=dec_attn_mask, + enc_layer_past=None, + enc_get_key_value=False, + enc_output=enc_output, + enc_output_attn_mask=enc_output_attn_mask, + dec_layer_past=None, + dec_get_key_value=False, + enc_self_attention_relative_position_bias=encoder_self_attention_relative_position_bias, + dec_self_attention_relative_position_bias=decoder_self_attention_relative_position_bias, + dec_cross_attention_relative_position_bias=decoder_cross_attention_relative_position_bias, + return_all_crossattention_probs=return_all_crossattention_probs, + batch_data=batch_data, + set_inference_key_value_memory=set_inference_key_value_memory, + decoder_max_sequence_len=decoder_max_sequence_len, + encoder_max_sequence_len=encoder_max_sequence_len, + enc_output_to_layers=self.enc_output_to_layers, + ) + + alignment_loss = None + if self.post_process and self.add_decoder: + dec_output, enc_output = output # [s, b, h] + if return_all_crossattention_probs: + dec_output, attention_scores = dec_output + attention_probs = [ + torch.softmax(attention_score, dim=-1) + for lidx, attention_score in enumerate(attention_scores) + if lidx in self.alignment_decoder_layerids + ] + + if text_limits is not None and self.use_alignment_loss and hasattr(self, "forward_sum_loss"): + attention_scores_filtered = [ + attention_scores[lidx] for lidx in self.alignment_decoder_layerids + ] + attention_scores_combined = torch.cat(attention_scores_filtered, dim=1) + text_start_idx = text_limits[0, 0].item() + assert torch.all( + text_limits[:, 0] == text_start_idx + ) # all texts should start at the same index + end_offset = self.alignment_text_end_offset + # align_every_n_head: eg if set to 2, will skip every other head + # if set to 12, will select 1 head from every layer + align_every_n_head = self.align_every_n_head + dec_start_idx = self.decoder_context_len + 1 # +1 to remove bos + attention_scores_sliced = attention_scores_combined[ + :, ::align_every_n_head, dec_start_idx:, text_start_idx : -(2 + end_offset) + ] # -2 to remove eos and pad + attention_logprobs = ( + attention_scores_sliced # not taking log_softmax, since we will do that in loss function + ) + attention_logprobs = torch.mean(attention_logprobs, dim=1, keepdim=True) + dec_len = torch.sum(dec_attn_mask, dim=1) - dec_start_idx + enc_len = text_limits[:, 1] - text_limits[:, 0] - end_offset + alignment_loss = self.forward_sum_loss( + attn_logprob=attention_logprobs, in_lens=enc_len, out_lens=dec_len + ) + else: + attention_probs = None + # project decoder output to vocabulary-size dimensions + if self.share_decoder_tokens_head_embeddings: + first_layer_vocabsize = ( + self.speech_offset + self.speech_codebook_size + ) # variables set in __init__ of speechlm model + token_logits = self.tokens_head(dec_output, self.word_embeddings_weight()) # s, b, vocab + if self.seq_pattern in ["parallel", "delay_parallel"]: + # For flat seq_pattern we need all the logits + token_logits = token_logits[:, :, :first_layer_vocabsize] + speech_layers = self.num_speech_codebooks - 1 + + # speech_logits_list will be used in loss calculation (parallel output) + speech_logits_list = [] + if self.seq_pattern in ["parallel", "delay_parallel"] and torch.count_nonzero(speech_mask) > 0: + for i in range(speech_layers): + last_layer_logits = self.speech_tokens_heads[i](dec_output)[0] # T, B, 1024 + speech_logits_list.append(last_layer_logits) # T, B, 1024 + else: + token_logits = self.tokens_head(dec_output)[0] # T, B, WordEmbSize + + if labels is not None: + if labels.dim() == 2: + # [b, s] -> [s, b] + labels = labels.transpose(0, 1).contiguous() + elif labels.dim() == 3: + # [b, c, s] -> [c, s, b] + labels = labels.permute(1, 2, 0).contiguous() + + # Set label smoothing to 0 if in eval mode. + label_smoothing = self.label_smoothing if self.training else 0.0 + + # tensor_parallel.vocab_parallel_cross_entropy performs log_softmax and return log p(x_i|z) per token i + if self.fp16_cross_entropy: + assert token_logits.dtype == torch.half + if labels.dim() == 3: + raise NotImplementedError("fp16_cross_entropy is not support for labels of dimension 3") + tokens_loss = vocab_parallel_cross_entropy(token_logits, labels, label_smoothing) + else: + if labels.dim() == 2: + tokens_loss = vocab_parallel_cross_entropy(token_logits.float(), labels, label_smoothing) + elif labels.dim() == 3: + if token_logits.size()[0] != labels[0, :, :].size()[0]: + raise Exception("TODO: add a permute") + tokens_loss = vocab_parallel_cross_entropy( + token_logits.float(), labels[0, :, :], label_smoothing + ) + logging.debug(f"token_loss: {tokens_loss}") + logging.debug(f"token_loss: {torch.all(torch.isfinite(tokens_loss))}") + if ( + self.seq_pattern in ["parallel", "delay_parallel"] + and torch.count_nonzero(speech_mask) > 0 + ): + for i in range(speech_layers): + if speech_logits_list[i].size()[0] != labels[i + 1, :, :].size()[0]: + raise Exception("TODO: add a permute") + curr_codebook_loss = ( + vocab_parallel_cross_entropy( + speech_logits_list[i].float(), labels[i + 1, :, :], label_smoothing + ) + * speech_mask.T + ) + tokens_loss += curr_codebook_loss + logging.debug(f"token_loss_{i}: {tokens_loss}") + logging.debug(f"token_loss_{i}: {torch.all(torch.isfinite(tokens_loss))}") + + # [s, b] -> [b, s] + tokens_loss = tokens_loss.transpose(0, 1).contiguous() + + # check if hiddens is used + if self.hiddens_cfg is not None: + raise NotImplementedError("Not currently implemented for speechllm") + else: + return tokens_loss, [token_logits, speech_logits_list, attention_probs, alignment_loss] + else: + # else return token logits (and hiddens if needed) + # [s, b, h] -> [b, s, h] + # If labels is None then we are in inference mode and we return the gathered logits + if self.parallel_output: + # Gather logits from tensor parallel if in parallel_output mode + token_logits = tensor_parallel.gather_from_tensor_model_parallel_region( + token_logits + ) # T, B, 30208 + for _i in range(len(speech_logits_list)): + speech_logits_list[_i] = tensor_parallel.gather_from_tensor_model_parallel_region( + speech_logits_list[_i] + ) # T, B, 1024 + + token_logits = token_logits.transpose(0, 1).contiguous() # (B, T, 30208) + speech_logits = torch.stack(speech_logits_list, dim=-1) # T, B, 1024, 7 + speech_logits = speech_logits.transpose(0, 1).contiguous() # (B, T, 1024, 7) + + _si = self.speech_offset + _ei = _si + self.speech_codebook_size + first_layer_speech_logits = token_logits[:, :, _si:_ei].unsqueeze(-1) # (b, s, 1023, 1) + + all_speech_logits = torch.cat( + [first_layer_speech_logits, speech_logits], dim=-1 + ) # (b, s, 1024, 8) + + if self.hiddens_cfg is not None: + raise NotImplementedError("Not currently implemented for speechllm") + else: + # all_speech_logits: tensor, (b, s, 1024, 8), all layers of speech. + # token_logits: tensor, (b, s, vocab_size), text token logits. + # speech_logits: tensor, (b, s, 1024, 7), 1-7 layers of speech. + # attention_probs: tensor or None, (b, s, ) + # enc_output: tensor, (virtual_token_len+context_token_len+question_token_len+extra_id_0+[SEP], b, ) + return all_speech_logits, [token_logits, speech_logits, attention_probs, enc_output] + + elif self.add_decoder and not self.add_encoder: + decoder_output, _ = output + return decoder_output + else: + encoder_output = output + return encoder_output + + def state_dict(self): + """For easy load when model is combined with other heads, + add an extra key.""" + + state_dict_ = {} + state_dict_[self._encoder_embedding_key] = self.encoder_embedding.state_dict() + state_dict_[self._decoder_embedding_key] = self.decoder_embedding.state_dict() + state_dict_[self._enc_dec_model_key] = self.enc_dec_model.state_dict() + state_dict_[self._tokens_head_key] = self.tokens_head.state_dict() + if hasattr(self, "speech_tokens_heads"): + state_dict_["speech_tokens_heads"] = self.speech_tokens_heads.state_dict() + if hasattr(self, "speech_tokens_embeddings"): + state_dict_["speech_tokens_embeddings"] = self.speech_tokens_embeddings.state_dict() + return state_dict_ + + def load_state_dict(self, state_dict, strict=True): + """Customized load.""" + super().load_state_dict(state_dict, strict=strict) + if hasattr(self, "speech_tokens_heads"): + self.speech_tokens_heads.load_state_dict(state_dict["speech_tokens_heads"], strict=strict) + if hasattr(self, "speech_tokens_embeddings"): + self.speech_tokens_embeddings.load_state_dict(state_dict["speech_tokens_embeddings"], strict=strict) diff --git a/nemo/collections/nlp/modules/common/megatron/transformer.py b/nemo/collections/nlp/modules/common/megatron/transformer.py index ab10b0d0e8b3..c5108d8e3801 100644 --- a/nemo/collections/nlp/modules/common/megatron/transformer.py +++ b/nemo/collections/nlp/modules/common/megatron/transformer.py @@ -22,6 +22,7 @@ import torch import torch.nn as nn from einops import rearrange +from omegaconf.listconfig import ListConfig from nemo.collections.common.parts.adapter_modules import LinearAdapterConfig from nemo.collections.nlp.modules.common.megatron.adapters.parallel_adapters import ( @@ -479,6 +480,10 @@ def forward( self_attention_relative_position_bias=None, cross_attention_relative_position_bias=None, checkpoint_core_attention=False, + return_crossattention_scores=False, + return_selfattention_scores=False, + decoder_max_sequence_len=None, + encoder_max_sequence_len=None, ): # Self attention. if rotary_pos_emb is not None: @@ -489,6 +494,12 @@ def forward( self_attention_pos_emb = None cross_attention_pos_emb = None + if return_crossattention_scores and return_selfattention_scores: + raise NotImplementedError( + "We can only return 1 of cross attention scores or self attention scores. Not both yet." + ) + attention_probs = None + if self.layer_type != LayerType.retrieval_decoder_after_self_attn: # hidden_states: [b, s, h] @@ -507,12 +518,16 @@ def forward( layer_past=layer_past, get_key_value=get_key_value, set_inference_key_value_memory=set_inference_key_value_memory, - inference_max_sequence_len=inference_max_sequence_len, + inference_max_sequence_len=inference_max_sequence_len or decoder_max_sequence_len, rotary_pos_emb=self_attention_pos_emb, relative_position_bias=self_attention_relative_position_bias, checkpoint_core_attention=checkpoint_core_attention, + return_scores=return_selfattention_scores, ) + if return_selfattention_scores: + attention_output, attention_probs = attention_output + if get_key_value: attention_output, presents = attention_output @@ -526,7 +541,7 @@ def forward( attention_bias = None # jit scripting for a nn.module (with dropout) is not - # trigerring the fusion kernel. For now, we use two + # triggering the fusion kernel. For now, we use two # different nn.functional routines to account for varying # dropout semantics during training and inference phases. @@ -553,6 +568,9 @@ def forward( elif self.transformer_block_type in ['pre_ln', 'normformer']: # Layer norm post the self attention. normalization_output = self.post_attention_layernorm(layernorm_input) + else: + normalization_output = None + logging.warning(f"This is a rare case since `normalization_output=None`") else: layernorm_input, normalization_output = hidden_states @@ -579,7 +597,7 @@ def forward( checkpoint_core_attention=checkpoint_core_attention, ) else: - + # Return Scores is being passed only for inter_attention and not self attention attention_output, attention_bias = self.inter_attention( normalization_output, enc_dec_attn_mask, @@ -587,7 +605,12 @@ def forward( rotary_pos_emb=cross_attention_pos_emb, relative_position_bias=cross_attention_relative_position_bias, checkpoint_core_attention=checkpoint_core_attention, + return_scores=return_crossattention_scores, + set_inference_key_value_memory=set_inference_key_value_memory, + inference_max_sequence_len=encoder_max_sequence_len, ) + if return_crossattention_scores: + attention_output, attention_probs = attention_output # If normformer, apply norm on the output of the self attention. if self.transformer_block_type == 'normformer': @@ -632,6 +655,9 @@ def forward( if get_key_value: output = [output, presents] + if attention_probs is not None: + output = [output, attention_probs] + return output @@ -735,6 +761,10 @@ def forward( self_attention_relative_position_bias=None, cross_attention_relative_position_bias=None, checkpoint_core_attention=False, + return_crossattention_scores=False, + return_selfattention_scores=False, + decoder_max_sequence_len=None, + encoder_max_sequence_len=None, ): if self.dtype == torch.float32: return super().forward( @@ -750,6 +780,10 @@ def forward( self_attention_relative_position_bias, cross_attention_relative_position_bias, checkpoint_core_attention, + return_crossattention_scores=return_crossattention_scores, + return_selfattention_scores=return_selfattention_scores, + decoder_max_sequence_len=decoder_max_sequence_len, + encoder_max_sequence_len=encoder_max_sequence_len, ) with torch.autocast(device_type="cuda", dtype=self.dtype): return super().forward( @@ -765,6 +799,10 @@ def forward( self_attention_relative_position_bias, cross_attention_relative_position_bias, checkpoint_core_attention, + return_crossattention_scores=return_crossattention_scores, + return_selfattention_scores=return_selfattention_scores, + decoder_max_sequence_len=decoder_max_sequence_len, + encoder_max_sequence_len=encoder_max_sequence_len, ) @@ -1072,10 +1110,12 @@ def __init__( # Transformer layers. def build_layer(layer_number): - if isinstance(layer_type, list): + if isinstance(layer_type, (list, ListConfig)): lt = layer_type[layer_number - 1] else: lt = layer_type + if isinstance(lt, int): + lt = LayerType(lt) if self.transformer_engine: transformer_layer_args = { @@ -1493,7 +1533,16 @@ def forward( self_attention_relative_position_bias=None, cross_attention_relative_position_bias=None, checkpoint_activations_all_layers=None, + return_all_crossattention_probs=False, + return_all_selfattention_probs=False, + decoder_max_sequence_len=None, + encoder_max_sequence_len=None, + enc_output_to_layers=None, ): + if return_all_crossattention_probs and return_all_selfattention_probs: + raise NotImplementedError( + "We can only return 1 of cross attention probs or self attention probs. Not both yet." + ) # Checks. if inference_max_sequence_len: assert self.activations_checkpoint_method is None, 'inference does not work with activation checkpointing' @@ -1580,6 +1629,7 @@ def forward( if self.inference_params != None: self.inference_params.sequence_len_offset = self.inference_current_sequence_len + attention_probs_list = [] if self.return_select_layer < 0: assert ( parallel_state.get_pipeline_model_parallel_world_size() == 1 @@ -1588,10 +1638,32 @@ def forward( logging.warning("Returning embeddings states only!") return hidden_states + layer_to_encoder_num_mapping = {} + if enc_output_to_layers is not None: + assert len(enc_output_to_layers) == len(encoder_output) + for encoder_idx in range(len(encoder_output)): + for layer_idx in enc_output_to_layers[encoder_idx]: + layer_to_encoder_num_mapping[layer_idx] = encoder_idx + for index in range(self.num_layers): layer = self._get_layer(index) past = None + _encoder_output = encoder_output + _enc_dec_attn_mask = enc_dec_attn_mask + _cross_attention_relative_position_bias = cross_attention_relative_position_bias + _encoder_max_sequence_len = encoder_max_sequence_len + if index in layer_to_encoder_num_mapping: + _encoder_output = encoder_output[layer_to_encoder_num_mapping[index]] + _enc_dec_attn_mask = enc_dec_attn_mask[layer_to_encoder_num_mapping[index]] + _cross_attention_relative_position_bias = cross_attention_relative_position_bias[ + layer_to_encoder_num_mapping[index] + ] + if encoder_max_sequence_len is not None: + _encoder_max_sequence_len = encoder_max_sequence_len[ + layer_to_encoder_num_mapping[index] + ] + if layer_past is not None: past = layer_past[index] @@ -1625,27 +1697,65 @@ def forward( hidden_states = layer( hidden_states, attention_mask, - encoder_output=encoder_output, - enc_dec_attn_mask=enc_dec_attn_mask, + encoder_output=_encoder_output, + enc_dec_attn_mask=_enc_dec_attn_mask, inference_params=self.inference_params, is_first_microbatch=is_first_microbatch, checkpoint_core_attention=checkpoint_core_attention, ) else: - hidden_states = layer( - hidden_states, - attention_mask, - encoder_output=encoder_output, - enc_dec_attn_mask=enc_dec_attn_mask, - layer_past=past, - get_key_value=get_key_value, - set_inference_key_value_memory=set_inference_key_value_memory, - inference_max_sequence_len=inference_max_sequence_len, - rotary_pos_emb=rotary_pos_emb, - self_attention_relative_position_bias=self_attention_relative_position_bias, - cross_attention_relative_position_bias=cross_attention_relative_position_bias, - checkpoint_core_attention=checkpoint_core_attention, - ) + if layer.layer_type == LayerType.decoder and return_all_crossattention_probs: + hidden_states, attention_probs = layer( + hidden_states, + attention_mask, + encoder_output=_encoder_output, + enc_dec_attn_mask=_enc_dec_attn_mask, + layer_past=past, + set_inference_key_value_memory=set_inference_key_value_memory, + inference_max_sequence_len=inference_max_sequence_len, + rotary_pos_emb=rotary_pos_emb, + self_attention_relative_position_bias=self_attention_relative_position_bias, + cross_attention_relative_position_bias=_cross_attention_relative_position_bias, + checkpoint_core_attention=checkpoint_core_attention, + return_crossattention_scores=return_all_crossattention_probs, + decoder_max_sequence_len=decoder_max_sequence_len, + encoder_max_sequence_len=_encoder_max_sequence_len, + ) + attention_probs_list.append(attention_probs) + elif layer.layer_type == LayerType.encoder and return_all_selfattention_probs: + hidden_states, attention_probs = layer( + hidden_states, + attention_mask, + encoder_output=_encoder_output, + enc_dec_attn_mask=_enc_dec_attn_mask, + layer_past=past, + get_key_value=get_key_value, + set_inference_key_value_memory=set_inference_key_value_memory, + inference_max_sequence_len=inference_max_sequence_len, + rotary_pos_emb=rotary_pos_emb, + self_attention_relative_position_bias=self_attention_relative_position_bias, + cross_attention_relative_position_bias=_cross_attention_relative_position_bias, + checkpoint_core_attention=checkpoint_core_attention, + return_selfattention_scores=return_all_selfattention_probs, + ) + attention_probs_list.append(attention_probs) + else: + hidden_states = layer( + hidden_states, + attention_mask, + encoder_output=_encoder_output, + enc_dec_attn_mask=_enc_dec_attn_mask, + layer_past=past, + get_key_value=get_key_value, + set_inference_key_value_memory=set_inference_key_value_memory, + inference_max_sequence_len=inference_max_sequence_len, + rotary_pos_emb=rotary_pos_emb, + self_attention_relative_position_bias=self_attention_relative_position_bias, + cross_attention_relative_position_bias=_cross_attention_relative_position_bias, + checkpoint_core_attention=checkpoint_core_attention, + decoder_max_sequence_len=decoder_max_sequence_len, + encoder_max_sequence_len=_encoder_max_sequence_len, + ) if self.return_select_layer < 0: assert ( @@ -1679,4 +1789,7 @@ def forward( if get_key_value: output = [output, presents] + if return_all_crossattention_probs or return_all_selfattention_probs: + output = [output, attention_probs_list] + return output diff --git a/nemo/collections/nlp/modules/common/megatron/utils.py b/nemo/collections/nlp/modules/common/megatron/utils.py index 601cb7a4d7e8..b0a6f755a9cc 100644 --- a/nemo/collections/nlp/modules/common/megatron/utils.py +++ b/nemo/collections/nlp/modules/common/megatron/utils.py @@ -18,7 +18,6 @@ from typing import Dict, Iterator, List, Optional, Tuple, Union import torch -import torch.nn as nn from torch import Tensor from nemo.utils import logging, logging_mode @@ -474,9 +473,25 @@ def get_iterator_k_split( else: # Split a list of torch tensors assert batch[0].shape[0] % num_microbatches == 0, "Issue with batch size configuration!" - split_batch = [ - torch.tensor_split(item, num_microbatches, dim=0) if torch.is_tensor(item) else item for item in batch - ] + split_batch = [] + for item in batch: + if torch.is_tensor(item): + split_batch.append(torch.tensor_split(item, num_microbatches, dim=0)) + elif isinstance(item, list): + if isinstance(item[0], torch.Tensor): + split_tensors = [torch.tensor_split(elem, num_microbatches, dim=0) for elem in item] + split_tuple = [] + for mbi in range(num_microbatches): + split_tuple.append([split_tensors[i][mbi] for i in range(len(split_tensors))]) + split_tuple = tuple(split_tuple) + split_batch.append(split_tuple) + else: + split_batch.append(split_list(item, num_microbatches)) + elif item is None: + split_batch.append(item) + else: + raise ValueError(f"Unsupported item type: {type(item)}") + microbatches = [ [elem[i] if elem is not None else elem for elem in split_batch] for i in range(num_microbatches) ] diff --git a/nemo/collections/nlp/modules/common/text_generation_utils.py b/nemo/collections/nlp/modules/common/text_generation_utils.py index eeaaea26beac..4743c3216e6a 100644 --- a/nemo/collections/nlp/modules/common/text_generation_utils.py +++ b/nemo/collections/nlp/modules/common/text_generation_utils.py @@ -24,7 +24,7 @@ import numpy as np import torch import torch.nn.functional as F -from lightning_fabric.utilities.seed import seed_everything +from lightning.fabric.utilities.seed import seed_everything from nemo.collections.common.tokenizers.tabular_tokenizer import TabularTokenizer from nemo.collections.multimodal.data.neva.conversation import ( diff --git a/nemo/collections/nlp/parts/megatron_trainer_builder.py b/nemo/collections/nlp/parts/megatron_trainer_builder.py index 7c7360ba3400..11f79baa819a 100644 --- a/nemo/collections/nlp/parts/megatron_trainer_builder.py +++ b/nemo/collections/nlp/parts/megatron_trainer_builder.py @@ -15,11 +15,11 @@ import sys from typing import Optional, Union -from lightning_fabric.utilities.exceptions import MisconfigurationException +from lightning.fabric.utilities.exceptions import MisconfigurationException +from lightning.pytorch import Trainer +from lightning.pytorch.callbacks import ModelSummary +from lightning.pytorch.plugins.environments import TorchElasticEnvironment from omegaconf import DictConfig, open_dict -from pytorch_lightning import Trainer -from pytorch_lightning.callbacks import ModelSummary -from pytorch_lightning.plugins.environments import TorchElasticEnvironment from nemo.collections.common.metrics.perf_metrics import FLOPsMeasurementCallback from nemo.collections.nlp.parts.nlp_overrides import ( diff --git a/nemo/collections/nlp/parts/nlp_overrides.py b/nemo/collections/nlp/parts/nlp_overrides.py index 2100e9c1ba8f..73263896af82 100644 --- a/nemo/collections/nlp/parts/nlp_overrides.py +++ b/nemo/collections/nlp/parts/nlp_overrides.py @@ -23,24 +23,24 @@ from pathlib import Path from typing import Any, Callable, Dict, Generator, Iterator, List, Literal, Mapping, Optional, Sized, Union -import pytorch_lightning as pl +import lightning.pytorch as pl import torch -from lightning_fabric.plugins import TorchCheckpointIO -from lightning_fabric.utilities.cloud_io import get_filesystem -from lightning_fabric.utilities.optimizer import _optimizer_to_device +from lightning.fabric.plugins import TorchCheckpointIO +from lightning.fabric.utilities.cloud_io import get_filesystem +from lightning.fabric.utilities.optimizer import _optimizer_to_device +from lightning.pytorch.callbacks.progress import TQDMProgressBar +from lightning.pytorch.callbacks.progress.tqdm_progress import _update_n +from lightning.pytorch.core.optimizer import LightningOptimizer +from lightning.pytorch.loops.fetchers import _DataFetcher +from lightning.pytorch.plugins import ClusterEnvironment +from lightning.pytorch.plugins.io.checkpoint_plugin import CheckpointIO +from lightning.pytorch.plugins.io.wrapper import _WrappingCheckpointIO +from lightning.pytorch.plugins.precision import MixedPrecisionPlugin +from lightning.pytorch.plugins.precision.fsdp import FSDPPrecision +from lightning.pytorch.strategies import DDPStrategy, FSDPStrategy +from lightning.pytorch.trainer.states import TrainerFn +from lightning.pytorch.trainer.trainer import Trainer from omegaconf import OmegaConf -from pytorch_lightning.callbacks.progress import TQDMProgressBar -from pytorch_lightning.callbacks.progress.tqdm_progress import _update_n -from pytorch_lightning.core.optimizer import LightningOptimizer -from pytorch_lightning.loops.fetchers import _DataFetcher -from pytorch_lightning.plugins import ClusterEnvironment -from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO -from pytorch_lightning.plugins.io.wrapper import _WrappingCheckpointIO -from pytorch_lightning.plugins.precision import MixedPrecisionPlugin -from pytorch_lightning.plugins.precision.fsdp import FSDPPrecision -from pytorch_lightning.strategies import DDPStrategy, FSDPStrategy -from pytorch_lightning.trainer.states import TrainerFn -from pytorch_lightning.trainer.trainer import Trainer from torch._C._distributed_c10d import ReduceOp from torch.distributed.algorithms.ddp_comm_hooks.debugging_hooks import noop_hook from torch.distributed.fsdp import BackwardPrefetch, FullStateDictConfig @@ -107,6 +107,7 @@ from megatron.core.tensor_parallel.layers import param_is_not_tensor_parallel_duplicate from megatron.core.transformer.module import Float16Module as MCoreFloat16Module from megatron.core.transformer.transformer_layer import TransformerLayer as MCoreTransformerLayer + from nemo.utils.callbacks.dist_ckpt_io import DistributedCheckpointIO HAVE_MEGATRON_CORE = True @@ -175,9 +176,14 @@ def init_model_parallel( app_state.data_parallel_size = parallel_state.get_data_parallel_world_size() app_state.pipeline_model_parallel_group = parallel_state.get_pipeline_model_parallel_group() - # create MPI process group for UCX-based communication APIs if app_state.init_mpi_proc_group: - torch.distributed.new_group(backend='mpi') + import packaging + + te_version = packaging.version.Version(version('transformer_engine')) + if te_version < packaging.version.Version("1.9"): + # Create MPI process group for bootstrapping at old TE versions. + # From TE version v1.9, the process group is initialized in TE. + torch.distributed.new_group(backend='mpi') class NLPDDPStrategy(DDPStrategy): @@ -376,7 +382,7 @@ def save_checkpoint( self, checkpoint: Dict[str, Any], filepath: Union[str, Path], storage_options: Optional[Any] = None ) -> None: app_state = AppState() - """ PTL method which we override to accomodate distributed checkpoints and + """ PTL method which we override to accomodate distributed checkpoints and the legacy model parallel checkpoints. When using megatron core, the distributed checkpointing library expects save functions to be @@ -1269,6 +1275,7 @@ def restore_from( return_config: bool = False, trainer: Trainer = None, validate_access_integrity: bool = True, + replace_sharded_tensor_key: Optional[str] = None, ): """ Restores model instance (weights and configuration) into .nemo file @@ -1356,6 +1363,9 @@ def dummy(): checkpoint = {} sharded_state_dict = instance.sharded_state_dict() checkpoint['state_dict'] = sharded_state_dict + if replace_sharded_tensor_key: + for v in checkpoint["state_dict"].values(): + v.key = v.key.replace("model", replace_sharded_tensor_key) checkpoint_io = DistributedCheckpointIO.from_config(conf) checkpoint = checkpoint_io.load_checkpoint( diff --git a/nemo/collections/nlp/parts/utils_funcs.py b/nemo/collections/nlp/parts/utils_funcs.py index a989ff3f606c..87fc1aa6f73c 100644 --- a/nemo/collections/nlp/parts/utils_funcs.py +++ b/nemo/collections/nlp/parts/utils_funcs.py @@ -28,9 +28,9 @@ import numpy as np import torch import torch.nn.functional as F +from lightning.pytorch.trainer.trainer import Trainer from matplotlib import pyplot as plt from omegaconf.dictconfig import DictConfig -from pytorch_lightning.trainer.trainer import Trainer from sklearn.metrics import classification_report, confusion_matrix from torch import Tensor diff --git a/nemo/collections/tts/data/speechllm/__init__.py b/nemo/collections/tts/data/speechllm/__init__.py new file mode 100644 index 000000000000..9df65818d226 --- /dev/null +++ b/nemo/collections/tts/data/speechllm/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/tts/data/speechllm/t5_speechllm_dataset.py b/nemo/collections/tts/data/speechllm/t5_speechllm_dataset.py new file mode 100644 index 000000000000..32f0a14f5e65 --- /dev/null +++ b/nemo/collections/tts/data/speechllm/t5_speechllm_dataset.py @@ -0,0 +1,1355 @@ +# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import enum +import json +import random +from dataclasses import dataclass +from pathlib import Path +from typing import ClassVar, List, Optional, Union + +import numpy as np +import torch +from hydra.utils import instantiate +from omegaconf import OmegaConf +from tqdm.auto import tqdm + +from nemo.collections.asr.parts.preprocessing.features import WaveformFeaturizer +from nemo.collections.asr.parts.preprocessing.segment import AudioSegment +from nemo.collections.common.tokenizers.text_to_speech.ipa_lexicon import get_ipa_punctuation_list +from nemo.collections.common.tokenizers.text_to_speech.tokenizer_utils import any_locale_text_preprocessing +from nemo.collections.nlp.data.language_modeling.megatron.base_prompt_learning_dataset import BasePromptLearningDataset +from nemo.collections.nlp.models.language_modeling.megatron_t5_model import T5Sentinel +from nemo.collections.nlp.modules.common import VirtualPromptSource +from nemo.collections.nlp.modules.common.megatron.utils import build_position_ids +from nemo.collections.tts.parts.utils.helpers import get_mask_from_lengths +from nemo.collections.tts.parts.utils.tts_dataset_utils import ( + BetaBinomialInterpolator, + beta_binomial_prior_distribution, + general_padding, + get_base_dir, +) +from nemo.utils import logging + +__all__ = ['T5SpeechLMDataset', "Lang"] + + +def get_full_list_puncts(): + punct_set = set() + for locale_id in ["en-US", "de-DE", "fr-FR"]: + punct_list = get_ipa_punctuation_list(locale=locale_id) + punct_set.update(punct_list) + return sorted(punct_set) + + +@dataclass +class G2PConfig: + _target_: str = "nemo.collections.tts.g2p.models.en_us_arpabet.EnglishG2p" + phoneme_dict: str = "scripts/tts_dataset_files/cmudict-0.7b_nv22.10" + heteronyms: str = "scripts/tts_dataset_files/heteronyms-052722" + phoneme_probability: float = 0.5 + + +@dataclass +class EnglishIpaG2pConfig: + _target_: str = "nemo.collections.tts.g2p.models.i18n_ipa.IpaG2p" + phoneme_dict: str = "scripts/tts_dataset_files/ipa_cmudict-0.7b_nv23.01.txt" + locale: str = "en-US" + heteronyms: str = "scripts/tts_dataset_files/heteronyms-052722" + phoneme_probability: float = 0.5 + grapheme_case: str = "upper" + use_stresses: bool = True + use_chars: bool = True + ignore_ambiguous_words: bool = False + + +@dataclass +class TextTokenizer: + _target_: str = "nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.EnglishPhonemesTokenizer" + punct: bool = True + stresses: bool = True + chars: bool = True + apostrophe: bool = True + pad_with_space: bool = True + add_blank_at: bool = True + g2p: G2PConfig = G2PConfig() + + +@dataclass +class EnglishIpaTextTokenizer: + _target_: str = "nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.IPATokenizer" + locale: str = "en-US" + punct: bool = True + # Define non_default_punct_list as a ClassVar to explicitly mark it as a class variable + non_default_punct_list: ClassVar[List[str]] = get_full_list_puncts() + apostrophe: bool = True + pad_with_space: bool = True + add_blank_at: bool = True + g2p: EnglishIpaG2pConfig = EnglishIpaG2pConfig() + + +@dataclass +class TextTokenizerConfig: + text_tokenizer: TextTokenizer = TextTokenizer() + + +@dataclass +class EnglishIpaTextTokenizerConfig: + text_tokenizer: EnglishIpaTextTokenizer = EnglishIpaTextTokenizer() + + +def _get_default_text_tokenizer_conf(phoneme_probability: float = 0.5, use_ipa: bool = False): + if use_ipa: + g2p = EnglishIpaG2pConfig(phoneme_probability=phoneme_probability) + _text_tokenizer = EnglishIpaTextTokenizer(g2p=g2p) + text_tokenizer: EnglishIpaTextTokenizerConfig = EnglishIpaTextTokenizerConfig(text_tokenizer=_text_tokenizer) + else: + g2p = G2PConfig(phoneme_probability=phoneme_probability) + _text_tokenizer = TextTokenizer(g2p=g2p) + text_tokenizer: TextTokenizerConfig = TextTokenizerConfig(text_tokenizer=_text_tokenizer) + return OmegaConf.create(OmegaConf.to_yaml(text_tokenizer)) + + +def pad_text_to_speech_dims(text_tensor, pad_id, pad_size=7): + token_len = text_tensor.shape[0] + empty_padding = torch.ones((pad_size, token_len), dtype=text_tensor.dtype, device=text_tensor.device) * pad_id + return torch.cat((text_tensor.unsqueeze(0), empty_padding), dim=0) + + +class Lang(enum.Enum): + en = 1 + es = 2 + fr = 3 + zh = 4 + de = 4 + + +class T5SpeechLMDataset(BasePromptLearningDataset): + """ + The dataset class for prompt-tuning or p-tuning pretrained T5 SpeechLM models. + """ + + def __init__( + self, + datasets, + tokenizer, + virtual_prompt_source: VirtualPromptSource, + task_templates: dict, + pseudo_tokens, + pad_token_id: str, + max_seq_length: int, + sample_rate: int, + min_seq_length: int = 1, + add_bos: bool = False, + add_eos: bool = True, + for_train: bool = True, + decoder_starts_with_pad: bool = False, + add_eos_to_decoder_output: bool = True, + add_sentinel_to_input: bool = True, + ul2_prompt_token: str = None, + segment_max_duration: Optional[int] = None, + trim: bool = False, + trim_ref: Optional[float] = None, + trim_top_db: Optional[int] = None, + trim_frame_length: Optional[int] = None, + trim_hop_length: Optional[int] = None, + pad_multiple: int = 1, + pitch_augment: bool = False, + sup_data_path: Optional[Union[Path, str]] = None, + speech_offset: Optional[int] = None, + train_task: Optional[str] = None, + seq_pattern: Optional[str] = "parallel", + use_attention_prior: Optional[bool] = False, + attention_prior_scaling_factor: Optional[float] = 1.0, + spec_aug=False, + spec_aug_time_width=0.2, + spec_aug_time_masks=2, + cross_attention_epsilon: Optional[float] = 0.0, + lm_vocab_size: Optional[int] = None, + num_speech_codebooks: Optional[int] = 8, + codebook_fps: Optional[int] = 86, + add_special_tokens_to_only_first_codebook: Optional[bool] = False, + context_pattern: Optional[str] = "parallel", + context_duration_min: Optional[float] = 3.0, + context_duration_max: Optional[float] = 5.0, + skip_datasets: Optional[List[str]] = [], # substrings of dataset names to skip + english_only_model: Optional[bool] = False, + context_conditioning: Optional[str] = "decoder", # encoder or decoder + use_beta_binomial_interpolator: Optional[str] = False, # encoder or decoder + context_slice_method: Optional[str] = "random", # random or fixed + phoneme_probability: Optional[float] = 0.5, + encoder_type: Optional[str] = "single_transformer", + use_ipa: bool = False, + **kwargs, + ): + """ + Only speech parameters are explained here. + segment_max_duration: Optional[int] = None, - Speech max segment duration + trim: bool = False, - speech parameter + trim_ref: Optional[float] = None, - speech parameter + trim_top_db: Optional[int] = None, - speech parameter + trim_frame_length: Optional[int] = None, - speech parameter + trim_hop_length: Optional[int] = None, - speech parameter + pad_multiple: int = 1, - speech parameter + pitch_augment: bool = False, - speech parameter + sup_data_path: Optional[Union[Path, str]] = None, - Supplementary folder path where codecs are stored. + speech_offset: Optional[int] = None, - if speech tokens then add this offset to the token indices to distinguish between text and speech tokens. + lm_vocab_size: Optional[int] = None, - vocab size of the original language model (phoneme tokens start from this index) + english_only_model: Optional[bool] = False, specify if monolingual or multi-lingual modeling. + use_ipa: bool = False, specify if using IPA tokens or default ARPABET tokens. Either choice still mixes chars. + **kwargs, + """ + # These two variables need to be set before calling super().__init__() because the parent class calls `load_data()` which requires these attributes. + self._rng = random.Random() + self.spec_aug = spec_aug if for_train else False + self.time_width = spec_aug_time_width + self.time_masks = spec_aug_time_masks + self.decoder_starts_with_pad = decoder_starts_with_pad + self.add_eos_to_decoder_output = add_eos_to_decoder_output + self.add_sentinel_to_input = add_sentinel_to_input + self.ul2_prompt_token = ul2_prompt_token + # Speech related variables + self.base_data_dir = None + self.segment_max_duration = segment_max_duration + self.sample_rate = sample_rate + self.featurizer = WaveformFeaturizer(sample_rate=self.sample_rate) + self.pad_multiple = pad_multiple + self.pitch_augment = pitch_augment + self.trim = trim + self.trim_ref = trim_ref if trim_ref is not None else np.max + self.trim_top_db = trim_top_db if trim_top_db is not None else 60 + self.trim_frame_length = trim_frame_length if trim_frame_length is not None else 2048 + self.trim_hop_length = trim_hop_length if trim_hop_length is not None else 512 + self.speech_offset = speech_offset if speech_offset is not None else 3 + self.seq_pattern = seq_pattern + self.use_attention_prior = use_attention_prior + self.attention_prior_scaling_factor = attention_prior_scaling_factor + self.cross_attention_epsilon = cross_attention_epsilon # value of prior for context tokens (b/w 0 and 1) + assert self.cross_attention_epsilon >= 0.0 and self.cross_attention_epsilon <= 1.0 + self.lm_vocab_size = tokenizer.vocab_size if lm_vocab_size is None else lm_vocab_size + self.num_speech_codebooks = num_speech_codebooks + self.codebook_fps = codebook_fps + self.add_special_tokens_to_only_first_codebook = add_special_tokens_to_only_first_codebook + # context_pattern and duration arguments are supported only if context_type is REFSPEAKERCODEC in the manifest + self.context_pattern = context_pattern + self.context_duration_min = context_duration_min + self.context_duration_max = context_duration_max + self.english_only_model = english_only_model + self.phoneme_tokenizer = None + if english_only_model: + self.phoneme_tokenizer = instantiate( + _get_default_text_tokenizer_conf(phoneme_probability=phoneme_probability, use_ipa=use_ipa) + ).text_tokenizer + else: + self.g2p = {"fr": lambda x: x} + if kwargs.get("g2p", None): + if "english" in kwargs["g2p"]: + english_g2p = instantiate(kwargs["g2p"]["english"]) + self.g2p["en"] = lambda x: english_g2p(x) + if "spanish" in kwargs["g2p"]: + spanish_g2p = instantiate(kwargs["g2p"]["spanish"]) + self.g2p["es"] = lambda x: spanish_g2p(x) + if "mandarin" in kwargs["g2p"]: + mandarin_g2p = instantiate(kwargs["g2p"]["mandarin"]) + self.g2p["zh"] = lambda x: mandarin_g2p(x) + if "german" in kwargs["g2p"]: + german_g2p = instantiate(kwargs["g2p"]["german"]) + self.g2p["de"] = lambda x: german_g2p(x) + + self.context_conditioning = context_conditioning + if self.context_conditioning == "decoder": + assert ( + self.context_duration_min == self.context_duration_max + ), "For decoder conditioning, context_duration_min and context_duration_max should be same" + self.decoder_context_len = int( + self.context_duration_min * self.codebook_fps + ) # TODO: Just take from model var? + + # Initialize sup_data_path, sup_data_types and run preprocessing methods for every supplementary data type\ + self.sup_data_path = None + if sup_data_path is not None: + Path(sup_data_path).mkdir(parents=True, exist_ok=True) + self.sup_data_path = sup_data_path + + self.codec_folder = kwargs.pop('codec_folder', None) + self.train_task = train_task + if self.codec_folder is None and self.sup_data_path is not None: + self.codec_folder = Path(self.sup_data_path) / "codec" + elif isinstance(self.codec_folder, str): + self.codec_folder = Path(self.codec_folder) + + self.codec_folder.mkdir(exist_ok=True, parents=True) + + self.context_length = kwargs.pop('context_length', None) # only used in gpt dataset atm + # self.attention_prior_strength = attention_prior_strength + self.transformer_type = kwargs.pop('transformer_type', 'T5') + self.skip_datasets = skip_datasets + + self.beta_binomial_interpolator = ( + BetaBinomialInterpolator(scaling_factor=self.attention_prior_scaling_factor) + if use_beta_binomial_interpolator + else None + ) + self.context_slice_method = context_slice_method + self.encoder_type = encoder_type + super().__init__( + datasets=datasets, + tokenizer=tokenizer, + virtual_prompt_source=virtual_prompt_source, + task_templates=task_templates, + pseudo_tokens=pseudo_tokens, + pad_token_id=pad_token_id, + max_seq_length=max_seq_length, + min_seq_length=min_seq_length, + add_bos=add_bos, + add_eos=add_eos, + for_train=for_train, + ) + + def load_data(self, dataset): + """ + Loads a dataset by filling in the task templates specified in the config file + with the information from each training/inference example. Converts all input + text into token ids. Also replaces the <|VIRTUAL_PROMPT_#|> placeholders in + the task templates with the actual virtual prompt token ids. + + params: + dataset: A list of json objects or a dictionary objects each + containing the information needed for a training example + """ + copy_dataset = list(dataset) + audio_filelist = [] + # This loop is needed to calculate self.base_data_dir. + for json_line in copy_dataset: + if type(json_line) == dict: + doc = json_line + else: + doc = json.loads(json_line) + taskname = doc["taskname"] + prompt_template_fields = self.task_templates[taskname]["prompt_template_fields"] + + for p in prompt_template_fields: + if f"{p}_type" in doc and doc[f"{p}_type"] == "SPEECH": + audio_filelist.append(doc[p]) + self.base_data_dir = get_base_dir(audio_filelist) + + skipped = 0 + tts = 0 + asr = 0 + i = 0 + logging.info(f"copy_dataset len === {len(copy_dataset)}") + examples = [] + for json_line in tqdm(copy_dataset): + i += 1 + + # Read example dict or load the information for a single example from .json file + if type(json_line) == dict: + doc = json_line + else: + doc = json.loads(json_line) + + if self.context_conditioning == "decoder": + # Modify doc to make combine context and anwer + assert ";" not in doc['context'], "Multiple contexts not supported in decoder conditioning" + doc['answer'] = "{};{}".format(doc['context'], doc['answer']) + doc['answer_duration'] = self.context_duration_min + doc['answer_duration'] + doc['answer_type'] = "CONTEXTANSWER" + doc['context_type'] = "DUMMYCONTEXT" + doc['context'] = "DUMMYCONTEXT" + + question_in_manifest = doc['question'] + + if "Text to speech this" in question_in_manifest or "Phoneme TTS" in question_in_manifest: + tts += 1 + if self.train_task not in ['tts', 'all']: + continue + elif "Next token prediction" in question_in_manifest: + if self.train_task != 'tts': + asr += 1 + else: + tts += 1 + continue + else: + if self.train_task == 'tts': + continue + asr += 1 + + if doc["context_type"] == "SPEECH": + assert "context_duration" in doc, f"context_duration key not in document {doc}" + approx_context_len = 3 * (self.codebook_fps + 1) # +1 just to be safe + if self.context_length is not None and doc["context_duration"] < self.context_length: + logging.debug( + f"skipped as context_length of {doc['context_duration']} is less than {self.context_length}" + ) + skipped += 1 + continue + elif "Remove Noise" in question_in_manifest: + approx_context_len = doc["answer_duration"] * (self.codebook_fps + 1) + elif "Extract Speaker Audio" in question_in_manifest: + approx_context_len = ( + doc["answer_duration"] * (self.codebook_fps + 1) + 400 + ) # 400 is the max ref speaker audio + elif ("Text to speech this" in question_in_manifest) or ('Phoneme TTS' in question_in_manifest): + # approx_context_len = 400 + approx_context_len = 5 * ( + self.codebook_fps + 1 + ) # better than 400. TODO: pneekhara: Need to change things for multi-encoder vs single encoder based filtering. + elif "Edit Speech" in question_in_manifest: + approx_context_len = doc["answer_duration"] * (self.codebook_fps + 1) + else: + raise NotImplementedError(f"Unknown context type {doc['context_type']}") + + approx_question_len = len(doc["question"].split(' ')) + 3 + if 'Phoneme TTS' in question_in_manifest: + # approx len is equal to num of characters + approx_question_len = len(question_in_manifest) + + if doc["answer_type"] in ["SPEECH", "AUDIOCODEC", "CONTEXTANSWER"]: + assert "answer_duration" in doc, f"answer_duration key not in document {doc}" + approx_answer_len = doc["answer_duration"] * (self.codebook_fps + 1) + 3 # +3 for EOS, BOS padding + if self.seq_pattern == "delay_parallel": + # In delay parallel, there is padding so add 8 frames + approx_answer_len = approx_answer_len + self.num_speech_codebooks + else: + approx_answer_len = len(doc["answer"].split(' ')) + 3 + + skip_record = False + for skip_dataset in self.skip_datasets: + if skip_dataset in doc['answer']: + skip_record = True + + if not skip_record: + if (self.transformer_type == "GPT") and ( + self.min_seq_length + < approx_context_len + approx_question_len + approx_answer_len + < self.max_seq_length + ): + examples.append(doc) + elif (self.transformer_type == "T5") and ( + self.min_seq_length < approx_context_len + approx_question_len < self.max_seq_length + and self.min_seq_length < approx_answer_len < self.max_seq_length + ): + examples.append(doc) + else: + logging.debug(f"skipped for {approx_context_len + approx_question_len} {approx_answer_len} len") + skipped += 1 + else: + print("Skipping", doc['answer']) + logging.debug(f"skipped for {doc['answer']} as it is in skip_datasets") + skipped += 1 + + logging.info(f'Skipped {skipped} sentences, sequence length too short or too long even after truncation') + + return examples + + def __getitem__(self, idx): + doc = self.examples[idx] + taskname = doc["taskname"] + prompt_template = self.task_templates[taskname]["prompt_template"] + prompt_template_fields = self.task_templates[taskname]["prompt_template_fields"] + total_virtual_tokens = self.task_templates[taskname]["total_virtual_tokens"] + virtual_token_splits = self.task_templates[taskname]["virtual_token_splits"] + truncation_field = self.task_templates[taskname]['truncate_field'] + answer_field = self.task_templates[taskname]["answer_field"] + + input_example = prompt_template + + self._input_sanity_checks( + total_virtual_tokens=total_virtual_tokens, + virtual_token_splits=virtual_token_splits, + prompt_template=prompt_template, + prompt_template_fields=doc.keys(), # Skip this check as we don't need it for TTS + truncation_field=truncation_field, + answer_field=answer_field, + doc=doc, + ) + question_in_manifest = doc['question'] + + # Format the input example according to the template + # Get context, question and answer codes in a dict. + # TODO @xueyang: declare the instructions when initializing the dataset so that they can be re-used. Temporally + # hardcode them here. + question_text = doc["question"].strip() + instructions = ["Phoneme TTS", "Text to speech this"] + for prefix in instructions: + if doc["question"].startswith(prefix): + question_text = doc["question"][len(prefix) :].strip() + break + + input_dict = self._insert_data_in_template(prompt_template_fields, doc, answer_field) + lang = Lang[doc.get("lang", "en")] + context_tokens = input_dict['context'] + question_tokens = input_dict['question'] + + # Logic to prune context + # In case of TTS task, the entire reference speech is not required, so we randomly select a portion + # of the reference audio. + # In case of Next token prediction, We want context[:T] to go in the encoder and context[T+1:] to be + # predicted by the decoder. + start_token_index = 0 + end_token_index = -1 + if ("Text to speech this" in question_in_manifest) and (doc["context_type"] == "SPEECH"): + total_context_len = context_tokens[0].size()[1] + reduced_len = min( + 400, + ( + int(total_context_len * 0.2) + if total_context_len > 600 + else int(total_context_len * random.uniform(0.2, 0.5)) + ), + ) + start_token_index = random.randint( + 0, total_context_len - reduced_len + ) # start index can be greater than 440 + context_tokens[0] = context_tokens[0][ + :, start_token_index : min(start_token_index + 440, start_token_index + reduced_len) + ] + elif "Next token prediction" in question_in_manifest: + total_context_len = context_tokens[0].size()[1] + end_token_index = int(total_context_len * random.uniform(0.01, 0.2)) + context_tokens[0] = context_tokens[0][:, :end_token_index] + + # Get virtual tokens + # `virtual_tokens` is "". + virtual_tokens = self._insert_virtual_token_placeholders(input_example.split(' ')[0], virtual_token_splits) + + # a trick to align with the data format in t5 pretraining + virtual_tokens = self.tokenizer.text_to_ids(virtual_tokens) + if self.add_sentinel_to_input: + question_tokens = question_tokens + self.tokenizer.text_to_ids(T5Sentinel.FIRST.value) + + # Add BOS/EOS to the input of encoder if desired, adds EOS by default + if self.ul2_prompt_token is not None: + ul2_prompt_token_id = self.tokenizer.text_to_ids(self.ul2_prompt_token) + assert len(ul2_prompt_token_id) == 1 + context_tokens = ul2_prompt_token_id + context_tokens + if self.add_bos: + context_tokens = [self.tokenizer.bos_id] + context_tokens + if self.add_eos: + question_tokens = question_tokens + [self.tokenizer.eos_id] + + # Try to truncate input text to fit into the max sequence length + if self._get_len(context_tokens, question_tokens, virtual_tokens) > self.max_seq_length: + context_tokens, question_tokens, virtual_tokens = self._truncate_input_speech( + context_tokens, question_tokens, virtual_tokens + ) + + virtual_tokens, virtual_tokens_len = self.list_to_tensor(virtual_tokens) + context_tokens, context_tokens_len = self.list_to_tensor(context_tokens) + question_tokens, question_tokens_len = self.list_to_tensor(question_tokens) + + if doc["question_type"] == "TEXT" and doc["context_type"] != "TEXT": + question_tokens = pad_text_to_speech_dims( + question_tokens, self.tokenizer.pad_id, self.num_speech_codebooks - 1 + ) + if doc["context_type"] == "TEXT" and doc["question_type"] != "TEXT": + context_tokens = pad_text_to_speech_dims( + context_tokens, self.tokenizer.pad_id, self.num_speech_codebooks - 1 + ) + if doc["context_type"] == "TEXT" and doc["question_type"] == "TEXT": + context_tokens = pad_text_to_speech_dims( + context_tokens, self.tokenizer.pad_id, self.num_speech_codebooks - 1 + ) + question_tokens = pad_text_to_speech_dims( + question_tokens, self.tokenizer.pad_id, self.num_speech_codebooks - 1 + ) + + # context_tokens: tensor, (num_speech_codebooks, audio_context_len) + # question_tokens: tensor, (num_speech_codebooks, instruction token len + question token len + 1 ( + 1 ([SEP])), only first row includes token ids while all other rows are all zeros (pad) + if self.encoder_type == "multi_transformer": + context_and_question_tokens = [context_tokens, question_tokens] + else: + context_and_question_tokens = torch.cat([context_tokens, question_tokens], dim=1) + + # get answer ids + if answer_field in doc.keys(): # training and validation + answer_ids = self._get_tokens(doc, answer_field, doc[answer_field]) + if end_token_index > -1: + answer_ids[0] = answer_ids[0][:, end_token_index:] + + if self.decoder_starts_with_pad: + answer_text_ids = [self.tokenizer.pad_id] + else: + answer_text_ids = [self.tokenizer.bos_id] + # a trick to align with the data format in t5 pretraining + # if self.add_sentinel_to_input: + # answer_text_ids += self.tokenizer.text_to_ids(T5Sentinel.FIRST.value) + answer_text_ids += answer_ids + + if self.add_eos_to_decoder_output: + answer_text_ids += [self.tokenizer.eos_id] + else: + answer_text_ids += self.tokenizer.text_to_ids(T5Sentinel.END.value) + + if self.virtual_prompt_source == VirtualPromptSource.PROMPT_ENCODER: + taskname_id = self.tokenizer.text_to_ids(taskname) + elif ( + self.virtual_prompt_source == VirtualPromptSource.NO_PROMPT + ): # TODO (@adithyare) this class and GPTPromptLearningDataset should be merged. + taskname_id = -1 + else: + raise ValueError("Invalid virtual prompt source specified") + + dec_input = None + dec_labels = None + + # if single-encoder and context_condition is decoder, answer_text_ids = [CLS_id, context audio code tensors, zero-pad, answer audio code tensor, SEP_id] + # if multi-encoder, answer_text_ids = [CLS_id, answer audio codec tensor, SEP_id], so dec_input will not include audio context anymore. + if answer_field in doc.keys(): # training and validation + dec_input = answer_text_ids[:-1] + dec_labels = answer_text_ids[1:] + + # if single-encoder and context_condition is decoder: + # dec_input: shape=(self.num_speech_codebooks, 1([CLS]) + len(context audio frames) + 1([PAD]) + len(answer audio frames)) + # dec_labels: shape=(self.num_speech_codebooks, len(context audio frames) + 1([PAD]) + len(answer audio frames) + 1([SEP])) + # if multi-encoder: + # dec_input: (num_speech_codebooks, 1([CLS]) + len(answer audio frames)) + # dec_labels: (num_speech_codebooks, len(answer audio frames) + 1([SEP])) + dec_input, dec_input_len = self.list_to_tensor(dec_input, True) + dec_labels, dec_labels_len = self.list_to_tensor(dec_labels, True) + is_speech = True if doc["answer_type"] != "TEXT" else False + if is_speech: + assert dec_input.dim() == 2 and dec_labels.dim() == 2 + if self.seq_pattern == "delay_parallel": + num_codebooks = dec_input.shape[0] + dec_input_padded = torch.cat( + [ + torch.zeros_like(dec_input[:, 0:num_codebooks]), + dec_input, + torch.zeros_like(dec_input[:, 0:num_codebooks]), + ], + dim=1, + ) + dec_labels_padded = torch.cat( + [ + torch.zeros_like(dec_labels[:, 0:num_codebooks]), + dec_labels, + torch.zeros_like(dec_labels[:, 0:num_codebooks]), + ], + dim=1, + ) + dec_input_new = [] + dec_labels_new = [] + for _c in range(self.num_speech_codebooks): + st = num_codebooks - _c + et_decoder_input = dec_input_padded.shape[1] - _c + et_decoder_labels = dec_labels_padded.shape[1] - _c + dec_input_new.append(dec_input_padded[_c, st:et_decoder_input]) + dec_labels_new.append(dec_labels_padded[_c, st:et_decoder_labels]) + dec_input = torch.stack(dec_input_new, dim=0) + dec_labels = torch.stack(dec_labels_new, dim=0) + dec_input_len = torch.tensor(dec_input.shape[1]).long() + dec_labels_len = torch.tensor(dec_labels.shape[1]).long() + + if self.encoder_type == "multi_transformer": + enc_len = question_tokens_len + virtual_tokens_len + else: + enc_len = context_tokens_len + question_tokens_len + virtual_tokens_len + # TODO: Remove hardcoding + start_of_question_offset = 4 # For both "Text to Speech this" and "Phoneme TTS" + end_of_question_offset = 2 + cross_attention_prior = torch.zeros(dec_labels_len, enc_len) + self.cross_attention_epsilon + if self.use_attention_prior: + prior_dec_len = dec_labels_len.item() + prior_dec_start_idx = 0 + if self.context_conditioning == "decoder": + prior_dec_len = dec_labels_len.item() - (self.decoder_context_len + 1) + prior_dec_start_idx = self.decoder_context_len + 1 + text_len = question_tokens_len.item() - start_of_question_offset - end_of_question_offset + audio_len = prior_dec_len + if self.beta_binomial_interpolator is not None: + cross_attention_question_prior = torch.from_numpy(self.beta_binomial_interpolator(audio_len, text_len)) + else: + cross_attention_question_prior = torch.from_numpy( + beta_binomial_prior_distribution( + text_len, + audio_len, + scaling_factor=self.attention_prior_scaling_factor, + ) + ) + if self.encoder_type == "multi_transformer": + cross_attention_prior[ + prior_dec_start_idx:, virtual_tokens_len + start_of_question_offset : -end_of_question_offset + ] = cross_attention_question_prior + else: + cross_attention_prior[ + prior_dec_start_idx:, + virtual_tokens_len + context_tokens_len + start_of_question_offset : -end_of_question_offset, + ] = cross_attention_question_prior + + if self.encoder_type == "multi_transformer": + context_and_question_len = [context_tokens_len, question_tokens_len] + else: + context_and_question_len = context_tokens_len + question_tokens_len + return ( + taskname_id, # List, only one item. token id for "squad" + virtual_tokens, # Tensor, shape=(3,). token id for ['', '', ''] + virtual_tokens_len, # tensor, 3 + context_tokens_len, # tensor, 1 + # tensor if single encoder and context_condition is encoder, shape=(self.num_speech_codebooks, 1(context) + question len + 1() + 1([SEP])). only first row includes token ids while all other rows are all zeros (pad). + # list if multi-encoder and context_condition is encoder. + context_and_question_tokens, + # tensor scalar if single encoder and context_condition is decoder, 1 + (question len + 1 + 1). + # list if multi-encoder and context_condition is encoder. + context_and_question_len, + dec_input, # tensor, shape=(self.num_speech_codebooks, 1 CLS + context audio frame len + 1 pad + answer audio frame len), first column is [CLS_id, 0*7]^T + dec_input_len, # scalar tensor, 1 CLS + context audio frame len + 1 pad + answer audio frame len. 1 corresponds to CLS id + dec_labels, # tensor, shape=(self.num_speech_codebooks, context audio frame len + 1 pad + answer frame len + 1 SEP). + dec_labels_len, # tensor, context audio frame len + 1 PAD + answer frame len + 1 SEP. 1 corresponds to SEP id. + is_speech, # True + cross_attention_prior, # tensor, shape=(dec_labels_len, context_tokens_len + question_tokens_len + virtual_tokens_len). + lang.value, # int, + question_text, # str, answer transcript without question type (Phoneme TTS or Text to speech this). + ) + + def _truncate_input_speech(self, context_tokens, question_tokens, virtual_tokens): + total_len = self._get_len(context_tokens, question_tokens, virtual_tokens) + context_len = self._get_element_len(context_tokens) + truncation_length = total_len - self.max_seq_length + 1 + context_tokens[0] = context_tokens[0][:, min(truncation_length, context_len) :] + return context_tokens, question_tokens, virtual_tokens + + def list_to_tensor(self, element, fill=False): + """ + Convert list to tensor. The list might contain integers, 2D-tensors (speech tokens) and combination of two. + If all of them are ints, simply convert to tensor + If combination of 2D-tensor and ints. Convert int to the dimension of the tensor. + example: [2, 4, 5] -> torch.tensor([2, 4, 5]) + example: [2, torch.tensor([[4, 5, 6], [6, 7, 8]])] -> torch.tensor( [[-1, 4, 5, 6], [2, 6, 7, 8]] ) + """ + ret, ln = None, None + if element is None: + return ret, ln + + max_len = max([1 if isinstance(item, int) else len(item) for item in element]) + if max_len == 1: + ret = torch.as_tensor(element).long() + ln = torch.tensor(ret.size()[0]).long() + else: + ret = [] + for e in element: + if isinstance(e, int): + tmp = torch.full((self.num_speech_codebooks, 1), e if fill else -1) + tmp[self.num_speech_codebooks - 1] = e + if self.add_special_tokens_to_only_first_codebook: + # Fill zeros in all other codebooks (to avoid out of range when getting embeddings) + tmp[1:] = 0 + else: + tmp = e + ret.append(tmp) + ret = torch.cat(ret, dim=1) + ln = torch.tensor(ret.size()[1]).long() + return ret, ln + + def _get_text_tokens(self, text): + input_ids = self.tokenizer.text_to_ids(text) + return input_ids + + def _get_phoneme_tokens(self, text, lang="en"): + if self.english_only_model: + input_ids = self.phoneme_tokenizer.encode(text) + input_ids_adjusted = [_id + self.lm_vocab_size for _id in input_ids] + return input_ids_adjusted + else: + text = any_locale_text_preprocessing(text) + input_ids = self.g2p[lang](text) + input_ids_adjusted = [] + for i in input_ids: + input_ids_adjusted.append(f"p{{{i}}}") + input_ids_adjusted = self.tokenizer.text_to_ids("".join(input_ids_adjusted)) + return input_ids_adjusted + + def _pad_wav_to_multiple(self, wav): + if self.pad_multiple > 1: + if wav.shape[0] % self.pad_multiple != 0: + wav = torch.cat( + [wav, torch.zeros(self.pad_multiple - wav.shape[0] % self.pad_multiple, dtype=torch.float)] + ) + return wav + + def _get_element_len(self, element): + length = 0 + if isinstance(element, list): + for e in element: + if isinstance(e, int): + length += 1 + else: + if e.dim() > 1: + length += e.size()[1] + else: + length += e.size()[0] + else: + if element.dim() > 1: + length += element.size()[1] + else: + length += element.size()[0] + return length + + def _get_len(self, context_tokens, question_tokens, virtual_tokens): + length = 0 + length += self._get_element_len(context_tokens) + length += self._get_element_len(question_tokens) + length += self._get_element_len(virtual_tokens) + return length + + def _load_audio(self, audio_filepath, dur=-1): + if self.segment_max_duration is not None and dur > 0 and dur > self.segment_max_duration: + # this case has been added for segmenting audio for speaker verification task of SSLDisentangler + n_segments = int(self.segment_max_duration * self.sample_rate) + features = AudioSegment.segment_from_file( + audio_filepath, target_sr=self.sample_rate, n_segments=n_segments, trim=self.trim + ) + + features = torch.tensor(features.samples) + if self.pad_multiple > 1: + features = self._pad_wav_to_multiple(features) + audio, audio_length = features, torch.tensor(features.shape[0]).long() + else: + features = self.featurizer.process( + audio_filepath, + trim=self.trim, + trim_ref=self.trim_ref, + trim_top_db=self.trim_top_db, + trim_frame_length=self.trim_frame_length, + trim_hop_length=self.trim_hop_length, + ) + + if self.pad_multiple > 1: + features = self._pad_wav_to_multiple(features) + + audio, audio_length = features, torch.tensor(features.shape[0]).long() + + return audio, audio_length + + def convert_audio(self, audio, sample_rate, target_sample_rate, target_channels): + if audio.dim() == 1: + audio = audio.unsqueeze(0).unsqueeze(0) + assert audio.shape[1] in [1, 2], "Audio must be mono or stereo." + # assert sample_rate == target_sample_rate, "sample rate of FastPitch and Encodec model has to be same" + if target_channels == 2: + *shape, _, length = audio.shape + audio = audio.expand(*shape, target_channels, length) + return audio + + def get_codec(self, audio): + wav1 = self.convert_audio(audio, self.sample_rate, self.encodec_model.sample_rate, self.encodec_model.channels) + encoded_frames = self.encodec_model.encode(wav1) + codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1) + return codes.squeeze(0) + + def get_quantizer_codebook(self, reference_codec, reference_codec_length): + out = torch.zeros((1, 128, reference_codec_length.item())) + for i in range(reference_codec.size()[0]): + out += self.encodec_model.quantizer.vq.layers[i].decode(reference_codec[i, :].unsqueeze(0)) + return out.squeeze(0) + + def _get_speech_tokens(self, audio_filepath, dur=-1): + # Let's keep audio name and all internal directories in rel_audio_path_as_text_id to avoid any collisions + rel_audio_path = Path(audio_filepath).relative_to(self.base_data_dir).with_suffix("") + rel_audio_path_as_text_id = str(rel_audio_path).replace("/", "_") + + # Load audio features + audio, audio_length = self._load_audio(audio_filepath, dur) + + # Convert to codes + codec_path = self.codec_folder / f"{rel_audio_path_as_text_id}.pt" + + if codec_path.exists(): + try: + codec_codes = torch.load(codec_path).long() + except Exception as e: + print(f"[ERROR IN LOADING {codec_path}] e") + codec_codes = self.get_codec(audio).long() + torch.save(codec_codes, codec_path) + else: + codec_codes = self.get_codec(audio).long() + torch.save(codec_codes, codec_path) + + # Convert codes to codes corresponding to megatron embedding layer + codec_codes[0] = (codec_codes[0] + self.speech_offset).long() + + return codec_codes + + def _get_tokens(self, doc, field, field_data): + if self.context_slice_method == "random": + # During training, we want a random slice of the context + rng = random.Random() # Custom random generator (since random uses fixed seeds) + elif self.context_slice_method == "fixed": + # During inference, we want a fixed slice of the context + rng = random + else: + raise ValueError(f"Invalid context_slice_method {self.context_slice_method}") + if f"{field}_type" not in doc.keys(): + field_tokens = self._get_text_tokens(field_data.strip(" ")) # list of ids + elif doc[f"{field}_type"] == 'TEXT': + _text = field_data.strip(" ") + if _text.startswith("Phoneme TTS"): + lang = doc.get("lang", "en") + instruction_tokens = self._get_text_tokens("Phoneme TTS") + field_tokens = self._get_phoneme_tokens(_text[len("Phoneme TTS") :].strip(), lang=lang) + field_tokens = instruction_tokens + field_tokens + elif _text.startswith("Edit Speech"): + # Always use phoneme tokenizer for edit speech + instruction_tokens = self._get_text_tokens("Edit Speech") + field_tokens = self._get_phoneme_tokens(_text[len("Edit Speech") :].strip()) + field_tokens = instruction_tokens + field_tokens + elif _text.startswith("TEXT CONTEXT:"): + # Speaker id conditioning + field_tokens = self._get_text_tokens(_text) + # pad field tokens to fixed length + # assert self.context_duration_min == self.context_duration_max, "TEXT CONTEXT only supports fixed context duration" + # To keep context length the same for audio or tex context + # _fixed_context_len = int(self.context_duration_min * self.codebook_fps) + field_tokens = field_tokens + [self.tokenizer.eos_id] + else: + # if starts with Text to speech this + field_tokens = self._get_text_tokens(field_data.strip(" ")) # list of ids + elif doc[f"{field}_type"] == 'SPEECH': + dur = -1 + if f"{field}_duration" in doc: + dur = doc[f"{field}_duration"] + field_tokens = self._get_speech_tokens(field_data, dur) # list of ids + if not isinstance(field_tokens, list): + field_tokens = [field_tokens] + elif doc[f"{field}_type"] == 'AUDIOCODEC': + reference_codec_paths = field_data.split(";") + reference_codec_path = rng.choice(reference_codec_paths) + if self.codec_folder is not None: + reference_codec_path = self.codec_folder / reference_codec_path + field_tokens = torch.load(reference_codec_path).long() + field_tokens[0] = (field_tokens[0] + self.speech_offset).long() + field_tokens = [field_tokens] + # print("AUDIOCODEC", field_tokens.shape) + elif doc[f"{field}_type"] == 'REFSPEAKERCODEC': + reference_codec_paths = field_data.split(";") + reference_codec_path = rng.choice(reference_codec_paths) + if self.codec_folder is not None: + reference_codec_path = self.codec_folder / reference_codec_path + field_tokens = torch.load(reference_codec_path).long() + field_tokens[0] = (field_tokens[0] + self.speech_offset).long() + _min_len = int(self.context_duration_min * self.codebook_fps) + _max_len = int(self.context_duration_max * self.codebook_fps) + reference_codec_len = rng.randint(_min_len, _max_len) + reference_codec_len = min(reference_codec_len, field_tokens.shape[1]) + si = rng.randint(0, field_tokens.shape[1] - reference_codec_len) + field_tokens = field_tokens[:, si : si + reference_codec_len] + if self.context_pattern == "delay_parallel": + field_tokens = torch.cat( + [ + torch.zeros(self.num_speech_codebooks, self.num_speech_codebooks).long(), + field_tokens, + torch.zeros(self.num_speech_codebooks, self.num_speech_codebooks).long(), + ], + dim=1, + ) + new_field_tokens = [] + for _c in range(self.num_speech_codebooks): + st = self.num_speech_codebooks - _c + et = field_tokens.shape[1] - _c + new_field_tokens.append(field_tokens[_c, st:et]) + field_tokens = torch.stack(new_field_tokens, dim=0) + field_tokens = [field_tokens] + elif doc[f"{field}_type"] == 'DUMMYCONTEXT': + field_tokens = torch.zeros(self.num_speech_codebooks, 1).long() + return [field_tokens] + elif doc[f"{field}_type"] == 'CONTEXTANSWER': + # Both Context and Answer are in the field + context_info, answer_codec_path = field_data.split(";") + if self.codec_folder is not None: + context_codec_path = self.codec_folder / context_info + answer_codec_path = self.codec_folder / answer_codec_path + if context_info.startswith("TEXT CONTEXT:"): + context_tokens = self._get_text_tokens(context_info.strip(" ")) + # pad field tokens to fixed length + assert ( + self.context_duration_min == self.context_duration_max + ), "TEXT CONTEXT only supports fixed context duration" + _fixed_context_len = int(self.context_duration_min * self.codebook_fps) + context_tokens = context_tokens + [self.tokenizer.pad_id] * (_fixed_context_len - len(context_tokens)) + + answer_tokens = torch.load(answer_codec_path).long() + answer_tokens[0] = (answer_tokens[0] + self.speech_offset).long() + field_tokens = context_tokens + [self.tokenizer.pad_id] + [answer_tokens] + else: + context_tokens = torch.load(context_codec_path).long() + context_tokens[0] = (context_tokens[0] + self.speech_offset).long() + assert ( + self.context_duration_min == self.context_duration_max + ), "CONTEXTANSWER only supports fixed context duration" + reference_codec_len = int(self.context_duration_min * self.codebook_fps) + if context_tokens.shape[1] < reference_codec_len: + # Repeat the context to match the reference_codec_len + context_tokens = torch.cat( + [context_tokens] * (reference_codec_len // context_tokens.shape[1] + 1), dim=1 + ) + assert ( + context_tokens.shape[1] >= reference_codec_len + ), "CONTEXTANSWER context duration is less than min duration {} {} {}".format( + context_tokens.shape[1], reference_codec_len, context_codec_path + ) + si = rng.randint(0, context_tokens.shape[1] - reference_codec_len) + context_tokens = context_tokens[:, si : si + reference_codec_len] + + answer_tokens = torch.load(answer_codec_path).long() + answer_tokens[0] = (answer_tokens[0] + self.speech_offset).long() + pad_tokens = torch.zeros(self.num_speech_codebooks, 1).long() + # padding between context and answer + field_tokens = torch.cat([context_tokens, pad_tokens, answer_tokens], dim=1) + field_tokens = [field_tokens] + elif doc[f"{field}_type"] == 'SEPARATIONCODECS': + mixed_codec_path, reference_codec_paths = field_data.split(",") + reference_codec_paths = reference_codec_paths.split(";") + reference_codec_path = rng.choice(reference_codec_paths) + mixed_codec = torch.load(mixed_codec_path).long() + reference_codec = torch.load(reference_codec_path).long() + reference_codec_len = rng.randint(240, 400) + reference_codec = reference_codec[:, :reference_codec_len] + # MIXED AUDIO AND REF AUDIO ARE SEPARATED BY 8 TIMESTEPS OF 1023 TOKENS IN ALL CODEBOOKS + mask_tokens = (torch.ones(self.num_speech_codebooks, self.num_speech_codebooks) * 1023).long() + field_tokens = torch.cat([mixed_codec, mask_tokens, reference_codec], dim=1) + field_tokens[0] = (field_tokens[0] + self.speech_offset).long() + field_tokens = [field_tokens] + elif doc[f"{field}_type"] == 'EDITINGCODECS': + reference_audio_path = field_data + reference_codec = torch.load(reference_audio_path).long() + assert reference_codec.shape[1] > 80 # ensure reference audio is atleast 1 second + mask_len = rng.randint(40, 320) # ~0.5 second to 4 seconds + mask_len = min(mask_len, reference_codec.shape[1] - 80) + mask_start = rng.randint(0, reference_codec.shape[1] - mask_len) + mask_end = mask_start + mask_len + mask_tokens = (torch.ones(self.num_speech_codebooks, self.num_speech_codebooks) * 1023).long() + seg1 = reference_codec[:, :mask_start] + seg2 = reference_codec[:, mask_end:] + field_tokens = torch.cat([seg1, mask_tokens, seg2], dim=1) + # MISSING AUDIO IS REPLACED WITH 8 TIMESTEPS OF 1023 TOKENS IN ALL CODEBOOKS + field_tokens[0] = (field_tokens[0] + self.speech_offset).long() + field_tokens = [field_tokens] + else: + raise Exception(f"{field}_type not recognized") + return field_tokens + + def _insert_data_in_template(self, prompt_template_fields, doc, answer_field): + """Format the input example according to the template""" + out_dict = {} + for field in prompt_template_fields: + # discard the last one, {label} / {answer} + # Or if some fields from the template aren't present, e.g. {answer} during inference + # just remove that field from the template, leaving the space blank + if field == answer_field or field not in doc.keys(): + continue + # out_dict[field] = "" + + elif field in doc.keys(): + field_data = doc[field] + if f"{field}_type" not in doc.keys(): + doc[f"{field}_type"] = "TEXT" + raise Exception(f"{field}_type does not exist in doc") + else: + out_dict[field] = self._get_tokens(doc, field, field_data) + return out_dict + + def get_position_ids(self, virtual_token, context_and_qquestion): + enc_input = [] + enc_input.append(virtual_token) + if context_and_qquestion.dim() > 2: + enc_input.append(context_and_qquestion[:, 0, :]) + else: + enc_input.append(context_and_qquestion) + + enc_input = torch.cat(enc_input, dim=1) + + enc_input_p = enc_input[:, 0, :] if enc_input.dim() == 3 else enc_input + return build_position_ids(enc_input_p).contiguous() + + def collate_fn(self, batch): + """Prepares enc_input, dec_input, labels, loss_mask, enc_mask, dec_mask, position_ids, taskname_ids for global batch""" + + data_dict = self.pad_batch_and_build_loss_mask(batch) + + if self.encoder_type == "multi_transformer": + position_ids = [ + self.get_position_ids(data_dict['virtual_tokens'], data_dict['context_and_question_tokens'][0]), + self.get_position_ids(data_dict['virtual_tokens'], data_dict['context_and_question_tokens'][1]), + ] + else: + position_ids = self.get_position_ids(data_dict['virtual_tokens'], data_dict['context_and_question_tokens']) + + return ( + data_dict['virtual_tokens'], + data_dict['context_and_question_tokens'], + data_dict['enc_mask'], + data_dict['dec_input'], + data_dict['dec_input_mask'], + data_dict['dec_labels'], + data_dict['dec_labels_mask'], + position_ids, + data_dict['taskname_id'], + data_dict['speech_mask'], + data_dict['context_and_question_tokens_lens'], + data_dict['cross_attention_prior'], + data_dict['text_limits'], + data_dict['lang'], + data_dict['question_texts'], + ) + + def pad_batch_and_build_loss_mask(self, batch): + """Pad enc_input, dec_input, labels in batch to max batch length while building loss_mask, enc_mask, and dec_mask""" + ( + taskname_ids, + _, + virtual_tokens_len, + _, + _, + context_and_question_tokens_len, + _, + dec_input_len, + _, + dec_labels_len, + _, + _, + _, + question_texts, + ) = zip(*batch) + + taskname_ids = self.pad_taskname_ids(taskname_ids) + + max_virtual_tokens_len = max(virtual_tokens_len).item() if virtual_tokens_len is not None else 0 + if isinstance(virtual_tokens_len, tuple): + virtual_tokens_len = torch.stack(virtual_tokens_len) + virtual_mask = get_mask_from_lengths(virtual_tokens_len) + + if self.encoder_type == "multi_transformer": + max_context_len = ( + max(_c[0] for _c in context_and_question_tokens_len) + if context_and_question_tokens_len is not None + else 0 + ) + max_question_len = ( + max(_c[1] for _c in context_and_question_tokens_len) + if context_and_question_tokens_len is not None + else 0 + ) + max_context_and_question_tokens_len = [max_context_len, max_question_len] + context_len = torch.stack([_c[0] for _c in context_and_question_tokens_len]) + question_len = torch.stack([_c[1] for _c in context_and_question_tokens_len]) + context_mask = get_mask_from_lengths(context_len) + question_mask = get_mask_from_lengths(question_len) + context_and_question_tokens_len = [context_len, question_len] + context_and_question_mask = [context_mask, question_mask] + enc_mask = [ + torch.cat([virtual_mask, context_and_question_mask[0]], dim=1), + torch.cat([virtual_mask, context_and_question_mask[1]], dim=1), + ] + # import ipdb; ipdb.set_trace() + else: + max_context_and_question_tokens_len = ( + max(context_and_question_tokens_len).item() if context_and_question_tokens_len is not None else 0 + ) + if isinstance(context_and_question_tokens_len, tuple): + context_and_question_tokens_len = torch.stack(context_and_question_tokens_len) + context_and_question_mask = get_mask_from_lengths(context_and_question_tokens_len) + enc_mask = torch.cat([virtual_mask, context_and_question_mask], dim=1) + + max_dec_input_len = max(dec_input_len).item() if dec_input_len is not None else 0 + max_dec_labels_len = max(dec_labels_len).item() if dec_labels_len is not None else 0 + + ( + virtual_tokens_list, + context_question_tokens_list, + dec_input_list, + dec_input_mask_list, + dec_labels_list, + dec_labels_mask_list, + speech_mask_list, + cross_attention_prior_list, + text_limits, + lang_list, + ) = ( + [], + [], + [], + [], + [], + [], + [], + [], + [], + [], + ) + + for i, sample_tuple in enumerate(batch): + ( + _, + virtual_token, + virtual_token_len, + context_token_len, + context_and_question_token, + context_and_question_token_len, + dec_input, + dec_input_len, + dec_label, + dec_label_len, + is_speech, + cross_attention_prior, + lang, + _, + ) = sample_tuple + + virtual_tokens_list.append( + general_padding( + virtual_token, virtual_token_len.item(), max_virtual_tokens_len, pad_value=self.tokenizer.pad_id + ) + ) + + if self.encoder_type == "multi_transformer": + context_tokens_padded = general_padding( + context_and_question_token[0], + context_and_question_token_len[0].item(), + max_context_and_question_tokens_len[0], + pad_value=self.tokenizer.pad_id, + ) + if len(context_tokens_padded.shape) < 2: + context_tokens_padded = pad_text_to_speech_dims( + context_tokens_padded, self.tokenizer.pad_id, self.num_speech_codebooks - 1 + ) + question_tokens_padded = general_padding( + context_and_question_token[1], + context_and_question_token_len[1].item(), + max_context_and_question_tokens_len[1], + pad_value=self.tokenizer.pad_id, + ) + if len(question_tokens_padded.shape) < 2: + question_tokens_padded = pad_text_to_speech_dims( + question_tokens_padded, self.tokenizer.pad_id, self.num_speech_codebooks - 1 + ) + context_question_tokens_list.append([context_tokens_padded, question_tokens_padded]) + else: + # This means context and questions are concatenated together + context_tokens_padded = general_padding( + context_and_question_token, + context_and_question_token_len.item(), + max_context_and_question_tokens_len, + pad_value=self.tokenizer.pad_id, + ) + if len(context_tokens_padded.shape) < 2: + context_tokens_padded = pad_text_to_speech_dims( + context_tokens_padded, self.tokenizer.pad_id, self.num_speech_codebooks - 1 + ) + context_question_tokens_list.append(context_tokens_padded) + + if max_dec_input_len > 0: + dec_input_padded = general_padding( + dec_input, dec_input_len.item(), max_dec_input_len, pad_value=self.tokenizer.pad_id + ) + if len(dec_input_padded.shape) < 2: + dec_input_padded = pad_text_to_speech_dims( + dec_input_padded, self.tokenizer.pad_id, self.num_speech_codebooks - 1 + ) + dec_input_list.append(dec_input_padded) + dec_mask = ( + torch.as_tensor(([1] * dec_input_len) + ([0] * (max_dec_input_len - dec_input_len))) + .long() + .contiguous() + ) + dec_input_mask_list.append(dec_mask) + speech_mask = dec_mask if is_speech else torch.zeros(dec_mask.shape) + speech_mask_list.append(speech_mask) + + if max_dec_labels_len > 0: + loss_mask = ( + torch.as_tensor(([1] * dec_label_len) + ([0] * (max_dec_labels_len - dec_label_len))) + .long() + .contiguous() + ) + dec_label_padded = general_padding( + dec_label, dec_label_len.item(), max_dec_labels_len, pad_value=self.tokenizer.pad_id + ) + if len(dec_label_padded.shape) < 2: + dec_label_padded = pad_text_to_speech_dims( + dec_label_padded, self.tokenizer.pad_id, self.num_speech_codebooks - 1 + ) + dec_labels_list.append(dec_label_padded) + dec_labels_mask_list.append(loss_mask) + + _p0 = max_dec_labels_len - dec_label_len + if self.encoder_type == "multi_transformer": + _p1 = ( + max_virtual_tokens_len + + max_context_and_question_tokens_len[1] + - context_and_question_token_len[1] + - virtual_token_len + ) + else: + _p1 = ( + max_virtual_tokens_len + + max_context_and_question_tokens_len + - context_and_question_token_len + - virtual_token_len + ) + + cross_attention_prior_padded = torch.nn.functional.pad( + cross_attention_prior, + pad=(0, _p1, 0, _p0), + mode="constant", + value=1, + ) + cross_attention_prior_list.append(cross_attention_prior_padded) + + if self.encoder_type == "multi_transformer": + _start_of_text_id = virtual_token_len + 4 + _end_of_text_id = _start_of_text_id + ( + context_and_question_token_len[1] - 2 - 4 + ) # -2 for some end tokens + else: + _start_of_text_id = virtual_token_len + context_token_len + 4 + _end_of_text_id = _start_of_text_id + ( + context_and_question_token_len - context_token_len - 2 - 4 + ) # -2 for some end tokens + text_limits.append(torch.tensor([_start_of_text_id.item(), _end_of_text_id.item()])) + lang_list.append(torch.tensor(lang)) + + dec_labels_mask = torch.stack(dec_labels_mask_list) if len(dec_labels_mask_list) > 0 else None + if dec_labels_mask is not None and self.context_conditioning == 'decoder': + # Mask out context tokens from loss computation. +1 for bos/pad in the beginning + dec_labels_mask[:, : self.decoder_context_len + 1] = 0 + + if self.encoder_type == "multi_transformer": + context_batch = torch.stack([c[0] for c in context_question_tokens_list]) + question_batch = torch.stack([c[1] for c in context_question_tokens_list]) + context_and_question_tokens = [context_batch, question_batch] + else: + context_and_question_tokens = torch.stack(context_question_tokens_list) + + data_dict = { + "taskname_id": taskname_ids, + "virtual_tokens": torch.stack(virtual_tokens_list), + "context_and_question_tokens": context_and_question_tokens, + "enc_mask": enc_mask, + "dec_input": torch.stack(dec_input_list) if len(dec_input_list) > 0 else None, + "dec_input_mask": torch.stack(dec_input_mask_list) if len(dec_input_mask_list) > 0 else None, + "dec_labels": torch.stack(dec_labels_list) if len(dec_labels_list) > 0 else None, + "dec_labels_mask": dec_labels_mask, + "speech_mask": torch.stack(speech_mask_list) if len(speech_mask_list) > 0 else None, + "context_and_question_tokens_lens": context_and_question_tokens_len, + "cross_attention_prior": ( + torch.stack(cross_attention_prior_list) if len(cross_attention_prior_list) > 0 else None + ), + "text_limits": ( + torch.stack(text_limits) if len(text_limits) > 0 else None + ), # tensor, valid range of answer transcripts without virtual/instruction/end tokens. + "lang": torch.stack(lang_list), + "question_texts": question_texts, + } + + return data_dict diff --git a/nemo/collections/tts/data/speechllm/t5_speechllm_tarred_dataset.py b/nemo/collections/tts/data/speechllm/t5_speechllm_tarred_dataset.py new file mode 100644 index 000000000000..9b0a4f8d06c2 --- /dev/null +++ b/nemo/collections/tts/data/speechllm/t5_speechllm_tarred_dataset.py @@ -0,0 +1,986 @@ +# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import io +import os +import random +from dataclasses import dataclass +from typing import List, Optional, Union + +import numpy as np +import torch +import webdataset as wd +from omegaconf import OmegaConf + +from nemo.collections.asr.data.audio_to_text import ( + _speech_collate_fn, + cache_datastore_manifests, + expand_sharded_filepaths, + shard_manifests_if_needed, +) +from nemo.collections.common.parts.preprocessing import collections +from nemo.collections.nlp.models.language_modeling.megatron_t5_model import T5Sentinel +from nemo.collections.nlp.modules.common import VirtualPromptSource +from nemo.collections.nlp.modules.common.megatron.utils import build_position_ids +from nemo.collections.tts.parts.utils.helpers import get_mask_from_lengths +from nemo.collections.tts.parts.utils.tts_dataset_utils import beta_binomial_prior_distribution, general_padding +from nemo.core.classes import IterableDataset +from nemo.utils import logging + +__all__ = ['T5SpeechLMTarredDataset'] + + +@dataclass +class G2PConfig: + _target_: str = "nemo.collections.tts.g2p.models.en_us_arpabet.EnglishG2p" + phoneme_dict: str = "scripts/tts_dataset_files/cmudict-0.7b_nv22.10" + heteronyms: str = "scripts/tts_dataset_files/heteronyms-052722" + phoneme_probability: float = 0.5 + + +@dataclass +class TextTokenizer: + _target_: str = "nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers.EnglishPhonemesTokenizer" + punct: bool = True + stresses: bool = True + chars: bool = True + apostrophe: bool = True + pad_with_space: bool = True + add_blank_at: bool = True + g2p: G2PConfig = G2PConfig() + + +@dataclass +class TextTokenizerConfig: + text_tokenizer: TextTokenizer = TextTokenizer() + + +def _get_default_text_tokenizer_conf(): + text_tokenizer: TextTokenizerConfig = TextTokenizerConfig() + return OmegaConf.create(OmegaConf.to_yaml(text_tokenizer)) + + +def pad_text_to_speech_dims(text_tensor, pad_id): + token_len = text_tensor.shape[0] + empty_padding = torch.ones((7, token_len), dtype=text_tensor.dtype, device=text_tensor.device) * pad_id + return torch.cat((text_tensor.unsqueeze(0), empty_padding), dim=0) + + +class InstructionTuningManifestProcessor: + """ + Class that processes a manifest json file containing paths to audio files, transcripts, and durations (in seconds). + Each new line is a different sample. Example below: + {"audio_filepath": "/path/to/audio.wav", "text_filepath": "/path/to/audio.txt", "duration": 23.147} + ... + {"audio_filepath": "/path/to/audio.wav", "text": "the transcription", "offset": 301.75, "duration": 0.82, "utt": + "utterance_id", "ctm_utt": "en_4156", "side": "A"} + Args: + manifest_filepath: Path to manifest json as described above. Can be comma-separated paths. + parser: Str for a language specific preprocessor or a callable. + max_duration: If audio exceeds this length, do not include in dataset. + min_duration: If audio is less than this length, do not include in dataset. + max_utts: Limit number of utterances. + bos_id: Id of beginning of sequence symbol to append if not None. + eos_id: Id of end of sequence symbol to append if not None. + pad_id: Id of pad symbol. Defaults to 0. + """ + + def __init__( + self, + manifest_filepath: str, + max_duration: Optional[float] = None, + min_duration: Optional[float] = None, + max_seq_length: Optional[float] = None, + max_utts: int = 0, + index_by_file_id: bool = False, + decoder_only_model: bool = False, + use_phoneme_tokenizer: bool = False, + ): + + # ASRAudioText( + self.collection = collections.InstructionTuningAudioText( + manifests_files=manifest_filepath, + min_duration=min_duration, + max_duration=max_duration, + max_seq_length=max_seq_length, + max_number=max_utts, + index_by_file_id=index_by_file_id, + decoder_only_model=decoder_only_model, + use_phoneme_tokenizer=use_phoneme_tokenizer, + ) + + +class _TarredInstructionTuningDataset(IterableDataset): + """ + A similar Dataset to the AudioToCharDataset/AudioToBPEDataset, but which loads tarred audio files. + """ + + def __init__( + self, + audio_tar_filepaths: Union[str, List[str]], + manifest_filepath: str, + sample_rate: int, + shuffle_n: int = 0, + min_duration: Optional[float] = None, + max_duration: Optional[float] = None, + max_seq_length: Optional[float] = None, + shard_strategy: str = "scatter", + shard_manifests: bool = False, + global_rank: int = 0, + world_size: int = 0, + return_sample_id: bool = False, + decoder_only_model: bool = False, + use_phoneme_tokenizer: bool = False, + ): + self.shard_manifests = shard_manifests + + # Shard manifests if necessary and possible and then expand the paths + manifest_filepath = shard_manifests_if_needed( + shard_manifests=shard_manifests, + shard_strategy=shard_strategy, + manifest_filepaths=manifest_filepath, + world_size=world_size, + global_rank=global_rank, + ) + + # If necessary, cache manifests from object store + cache_datastore_manifests(manifest_filepaths=manifest_filepath) + + self.manifest_processor = InstructionTuningManifestProcessor( + manifest_filepath=manifest_filepath, + max_duration=max_duration, + min_duration=min_duration, + max_seq_length=max_seq_length, + max_utts=0, + index_by_file_id=True, # Must set this so the manifest lines can be indexed by file ID + decoder_only_model=decoder_only_model, + use_phoneme_tokenizer=use_phoneme_tokenizer, + ) + + self.len = self._compute_len() + self.return_sample_id = return_sample_id + + audio_tar_filepaths = expand_sharded_filepaths( + sharded_filepaths=audio_tar_filepaths, + shard_strategy=shard_strategy, + world_size=world_size, + global_rank=global_rank, + ) + + if shuffle_n > 0: + # Only shuffle training data tar files + logging.info("Shuffling Tar files") + custom_rng = random.Random() + custom_rng.shuffle(audio_tar_filepaths) + logging.info("Done shuffling Tar files") + logging.info(audio_tar_filepaths[:10]) + + self.sample_rate = sample_rate + + # Put together WebDataset + self._dataset = wd.WebDataset(urls=audio_tar_filepaths, nodesplitter=None) + + if shuffle_n > 0: + self._dataset = self._dataset.shuffle(shuffle_n) + else: + logging.info("WebDataset will not shuffle files within the tar files.") + + self._dataset = ( + self._dataset.rename(key='__key__', answer='pt', context='context.pt') + .to_tuple('key', 'answer', 'context') + .pipe(self._filter) + .pipe(self._loop_offsets) + .map(f=self._build_sample) + ) + + def _filter(self, iterator): + """This function is used to remove samples that have been filtered out by ASRAudioText already. + Otherwise, we would get a KeyError as _build_sample attempts to find the manifest entry for a sample + that was filtered out (e.g. for duration). + Note that if using multi-GPU training, filtering may lead to an imbalance in samples in each shard, + which may make your code hang as one process will finish before the other. + """ + + class TarredAudioFilter: + def __init__(self, collection): + self.iterator = iterator + self.collection = collection + + def __iter__(self): + return self + + def __next__(self): + while True: + audio_filename, answer_bytes, context_bytes = next(self.iterator) + file_id, _ = os.path.splitext(os.path.basename(audio_filename)) + if file_id in self.collection.mapping: + return audio_filename, answer_bytes, context_bytes + + return TarredAudioFilter(self.manifest_processor.collection) + + def _loop_offsets(self, iterator): + """This function is used to iterate through utterances with different offsets for each file.""" + + class TarredAudioLoopOffsets: + def __init__(self, collection): + self.iterator = iterator + self.collection = collection + self.current_fn = None + self.current_bytes = None + self.current_context_bytes = None + self.offset_id = 0 + + def __iter__(self): + return self + + def __next__(self): + if self.current_fn is None: + self.current_fn, self.current_bytes, self.current_context_bytes = next(self.iterator) + self.offset_id = 0 + else: + offset_list = self.collection.mapping[self.current_fn] + if len(offset_list) == self.offset_id + 1: + self.current_fn, self.current_bytes, self.current_context_bytes = next(self.iterator) + self.offset_id = 0 + else: + self.offset_id += 1 + + return self.current_fn, self.current_bytes, self.current_context_bytes, self.offset_id + + return TarredAudioLoopOffsets(self.manifest_processor.collection) + + def _collate_fn(self, batch): + return _speech_collate_fn(batch) + + def _build_sample(self, tup): + """Builds the training sample by combining the data from the WebDataset with the manifest info.""" + audio_filename, encodec, ref_encodec, offset_id = tup + return audio_filename, encodec, ref_encodec, offset_id + + def get_manifest_sample(self, sample_id): + return self.manifest_processor.collection[sample_id] + + def __iter__(self): + return self._dataset.__iter__() + + def _compute_len(self): + if self.shard_manifests and torch.distributed.is_available() and torch.distributed.is_initialized(): + my_len = torch.tensor(len(self.manifest_processor.collection), dtype=torch.int32).cuda() + torch.distributed.all_reduce(my_len) + my_len = my_len.int() + logging.info(f'Sharded manifests: Total length: {my_len}') + else: + my_len = len(self.manifest_processor.collection) + + return my_len + + def __len__(self): + return self.len + + +class T5SpeechLMTarredDataset(_TarredInstructionTuningDataset): + """ + The dataset class for prompt-tuning or p-tuning pretrained T5 SpeechLM models. + """ + + def __init__( + self, + audio_tar_filepaths: Union[str, List[str]], + manifest_filepath: str, + tokenizer, + virtual_prompt_source: VirtualPromptSource, + task_templates: dict, + pseudo_tokens, + pad_token_id: str, + max_seq_length: int, + sample_rate: int, + shuffle_n: int = 0, + min_seq_length: int = 1, + add_bos: bool = False, + add_eos: bool = True, + for_train: bool = True, + decoder_starts_with_pad: bool = False, + add_eos_to_decoder_output: bool = True, + add_sentinel_to_input: bool = True, + ul2_prompt_token: str = None, + segment_max_duration: Optional[int] = None, + trim: bool = False, + trim_ref: Optional[float] = None, + trim_top_db: Optional[int] = None, + trim_frame_length: Optional[int] = None, + trim_hop_length: Optional[int] = None, + pad_multiple: int = 1, + pitch_augment: bool = False, + speech_offset: Optional[int] = None, + train_task: Optional[str] = None, + seq_pattern: Optional[str] = "parallel", + shard_strategy: str = "scatter", + shard_manifests: bool = False, + global_rank: int = 0, + world_size: int = 0, + return_sample_id: bool = False, + decoder_only_model: bool = False, + use_phoneme_tokenizer: Optional[bool] = False, + lm_vocab_size: Optional[int] = None, + use_attention_prior: Optional[bool] = False, + attention_prior_scaling_factor: Optional[float] = 1.0, + cross_attention_epsilon: Optional[float] = 0.0, + num_speech_codebooks: Optional[int] = 8, + **kwargs, + ): + """ + Only speech parameters are explained here. + segment_max_duration: Optional[int] = None, - Speech max segment duration + trim: bool = False, - speech parameter + trim_ref: Optional[float] = None, - speech parameter + trim_top_db: Optional[int] = None, - speech parameter + trim_frame_length: Optional[int] = None, - speech parameter + trim_hop_length: Optional[int] = None, - speech parameter + pad_multiple: int = 1, - speech parameter + pitch_augment: bool = False, - speech parameter + speech_offset: Optional[int] = None, - if speech tokens then add this offset to the token indices to distinguish between text and speech tokens. + **kwargs, + """ + # These two variables need to be set before calling super().__init__() because the parent class calls `load_data()` which requires these attributes. + self.decoder_starts_with_pad = decoder_starts_with_pad + self.add_eos_to_decoder_output = add_eos_to_decoder_output + self.add_sentinel_to_input = add_sentinel_to_input + self.ul2_prompt_token = ul2_prompt_token + # Speech related variables + # self.encodec_model = EncodecModel.encodec_model_24khz() + # self.encodec_model.set_target_bandwidth(6.0) + self.base_data_dir = None + self.segment_max_duration = segment_max_duration + self.sample_rate = sample_rate + # self.featurizer = WaveformFeaturizer(sample_rate=self.sample_rate) + self.pad_multiple = pad_multiple + self.pitch_augment = pitch_augment + self.trim = trim + self.trim_ref = trim_ref if trim_ref is not None else np.max + self.trim_top_db = trim_top_db if trim_top_db is not None else 60 + self.trim_frame_length = trim_frame_length if trim_frame_length is not None else 2048 + self.trim_hop_length = trim_hop_length if trim_hop_length is not None else 512 + self.speech_offset = speech_offset if speech_offset is not None else 3 + self.seq_pattern = seq_pattern + self.min_duration = kwargs.get('min_duration', 0.1) + self.max_duration = kwargs.get('max_duration', 20) + self.use_attention_prior = use_attention_prior + self.attention_prior_scaling_factor = attention_prior_scaling_factor + self.cross_attention_epsilon = cross_attention_epsilon # value of prior for context tokens (b/w 0 and 1) + assert self.cross_attention_epsilon >= 0.0 and self.cross_attention_epsilon <= 1.0 + + self.train_task = train_task + + # Initialized super part + self.tokenizer = tokenizer + self.virtual_prompt_source = virtual_prompt_source + self.task_templates = task_templates + self.pseudo_tokens = pseudo_tokens + self.pseudo_token_ids = set(self.tokenizer.tokens_to_ids(self.pseudo_tokens)) + self.pad_token_id = pad_token_id + self.max_seq_length = max_seq_length + self.min_seq_length = min_seq_length + self.add_bos = add_bos + self.add_eos = add_eos + self.for_train = for_train + self.use_phoneme_tokenizer = use_phoneme_tokenizer + self.examples = [] + self.lm_vocab_size = tokenizer.vocab_size if lm_vocab_size is None else lm_vocab_size + self.num_speech_codebooks = num_speech_codebooks + + assert self.min_seq_length <= max_seq_length, "Min sequence length should be less than or equal to max" + assert self.max_seq_length > 0, "Max sequence length should be greater than 0" + + self.context_length = kwargs.pop('context_length', None) # only used in gpt dataset atm + + logging.info("Loading and tokenizing dataset ... ") + + super().__init__( + audio_tar_filepaths=audio_tar_filepaths, + manifest_filepath=manifest_filepath, + sample_rate=sample_rate, + shuffle_n=shuffle_n, + min_duration=self.min_duration, + max_duration=self.max_duration, + max_seq_length=max_seq_length, + shard_strategy=shard_strategy, + shard_manifests=shard_manifests, + global_rank=global_rank, + world_size=world_size, + return_sample_id=return_sample_id, + decoder_only_model=decoder_only_model, + use_phoneme_tokenizer=use_phoneme_tokenizer, + ) + + self.encodec, self.ref_encodec = None, None + + def _insert_virtual_token_placeholders(self, input_example, virtual_token_splits): + """Insert the correct number of pseudo tokens at the <|VIRTUAL_PROMPT_n|> markers""" + total_inserted_tokens = 0 + + for idx in range(len(virtual_token_splits)): + split_start = total_inserted_tokens + split_end = total_inserted_tokens + virtual_token_splits[idx] + pseudo_tokens_for_split = "".join(self.pseudo_tokens[split_start:split_end]) + input_example = input_example.replace(f'<|VIRTUAL_PROMPT_{idx}|>', pseudo_tokens_for_split) + total_inserted_tokens = split_end + + return input_example + + def pad_taskname_ids(self, taskname_ids): + # Pad taskname_ids to be the same length for the prompt encoder + if self.virtual_prompt_source == VirtualPromptSource.PROMPT_ENCODER: + max_taskname_length = max(len(ids) for ids in taskname_ids) + taskname_ids = [ids + [self.pad_token_id] * (max_taskname_length - len(ids)) for ids in taskname_ids] + taskname_ids = torch.tensor(taskname_ids) + + # Task ids are just used for a look up embeddings for prompt-table + elif self.virtual_prompt_source == VirtualPromptSource.NO_PROMPT: + taskname_ids = torch.tensor(taskname_ids) + + return taskname_ids + + def _build_sample(self, tup): + audio_filename, self.encodec, self.ref_encodec, offset_id = tup + + file_id, _ = os.path.splitext(os.path.basename(audio_filename)) + manifest_idx = self.manifest_processor.collection.mapping[file_id][offset_id] + manifest_entry = self.manifest_processor.collection[manifest_idx] + doc = {} + doc['context'] = manifest_entry.context + doc['context_type'] = manifest_entry.context_type + doc['context_duration'] = manifest_entry.context_duration + doc['answer'] = manifest_entry.answer + doc['answer_type'] = manifest_entry.answer_type + doc['answer_duration'] = manifest_entry.answer_duration + doc['question'] = manifest_entry.question + doc['question_type'] = manifest_entry.question_type + + taskname = "squad" + prompt_template = self.task_templates[taskname]["prompt_template"] + prompt_template_fields = self.task_templates[taskname]["prompt_template_fields"] + virtual_token_splits = self.task_templates[taskname]["virtual_token_splits"] + answer_field = self.task_templates[taskname]["answer_field"] + + input_example = prompt_template + + question_in_manifest = manifest_entry.question + + # Format the input example according to the template + # Get context, question and answer codes in a dict. + input_dict = self._insert_data_in_template(input_example, prompt_template_fields, doc, answer_field) + context_tokens = input_dict['context'] + question_tokens = input_dict['question'] + + # Logic to prune context + # In case of TTS task, the entire reference speech is not required, so we randomly select a portion + # of the reference audio. + # In case of Next token prediction, We want context[:T] to go in the encoder and context[T+1:] to be + # predicted by the decoder. + start_token_index = 0 + end_token_index = -1 + if "Text to speech this" in question_in_manifest: + total_context_len = context_tokens[0].size()[1] + reduced_len = min( + 400, + ( + int(total_context_len * 0.2) + if total_context_len > 600 + else int(total_context_len * random.uniform(0.2, 0.5)) + ), + ) + start_token_index = random.randint( + 0, total_context_len - reduced_len + ) # start index can be greater than 440 + context_tokens[0] = context_tokens[0][ + :, start_token_index : min(start_token_index + 440, start_token_index + reduced_len) + ] + elif "Next token prediction" in question_in_manifest: + total_context_len = context_tokens[0].size()[1] + end_token_index = int(total_context_len * random.uniform(0.01, 0.2)) + context_tokens[0] = context_tokens[0][:, :end_token_index] + + # Get virtual tokens + virtual_tokens = self._insert_virtual_token_placeholders(input_example.split(' ')[0], virtual_token_splits) + + # a trick to align with the data format in t5 pretraining + # new + virtual_tokens = self.tokenizer.text_to_ids(virtual_tokens) + if self.add_sentinel_to_input: + question_tokens = question_tokens + self.tokenizer.text_to_ids(T5Sentinel.FIRST.value) + + # Add BOS/EOS to the input of encoder if desired, adds EOS by default + if self.ul2_prompt_token is not None: + ul2_prompt_token_id = self.tokenizer.text_to_ids(self.ul2_prompt_token) + assert len(ul2_prompt_token_id) == 1 + context_tokens = ul2_prompt_token_id + context_tokens + if self.add_bos: + context_tokens = [self.tokenizer.bos_id] + context_tokens + if self.add_eos: + question_tokens = question_tokens + [self.tokenizer.eos_id] + + # Try to truncate input text to fit into the max sequence length + if self._get_len(context_tokens, question_tokens, virtual_tokens) > self.max_seq_length: + context_tokens, question_tokens, virtual_tokens = self._truncate_input_speech( + context_tokens, question_tokens, virtual_tokens + ) + + virtual_tokens, virtual_tokens_len = self.list_to_tensor(virtual_tokens) + context_tokens, context_tokens_len = self.list_to_tensor(context_tokens) + question_tokens, question_tokens_len = self.list_to_tensor(question_tokens) + + if doc["question_type"] != "SPEECH" and doc["context_type"] == "SPEECH": + question_tokens = pad_text_to_speech_dims(question_tokens, self.tokenizer.pad_id) + if doc["context_type"] != "SPEECH" and doc["question_type"] == "SPEECH": + context_tokens = pad_text_to_speech_dims(context_tokens, self.tokenizer.pad_id) + context_tokens = context_tokens.to(question_tokens.device) + context_and_question_tokens = torch.cat([context_tokens, question_tokens], dim=1) + + # get answer ids + if answer_field in doc.keys(): # training and validation + answer_ids = self._get_tokens(doc, answer_field, doc[answer_field]) + if end_token_index > -1: + answer_ids[0] = answer_ids[0][:, end_token_index:] + + if self.decoder_starts_with_pad: + answer_text_ids = [self.tokenizer.pad_id] + else: + answer_text_ids = [self.tokenizer.bos_id] + + answer_text_ids += answer_ids + + if self.add_eos_to_decoder_output: + answer_text_ids += [self.tokenizer.eos_id] + else: + answer_text_ids += self.tokenizer.text_to_ids(T5Sentinel.END.value) + + # Skip example if the final length doesn't fit length requirements even after truncation + if ( + self.min_seq_length + <= self._get_element_len(context_and_question_tokens) + self._get_element_len(virtual_tokens) + <= self.max_seq_length + and self.min_seq_length <= self._get_element_len(answer_text_ids) <= self.max_seq_length + ): + if self.virtual_prompt_source == VirtualPromptSource.PROMPT_ENCODER: + taskname_id = self.tokenizer.text_to_ids(taskname) + elif ( + self.virtual_prompt_source == VirtualPromptSource.NO_PROMPT + ): # TODO (@adithyare) this class and GPTPromptLearningDataset should be merged. + taskname_id = -1 + else: + raise ValueError("Invalid virtual prompt source specified") + + dec_input = None + dec_labels = None + + if answer_field in doc.keys(): # training and validation + dec_input = answer_text_ids[:-1] + dec_labels = answer_text_ids[1:] + + dec_input, dec_input_len = self.list_to_tensor(dec_input, True) + dec_labels, dec_labels_len = self.list_to_tensor(dec_labels, True) + is_speech = True if doc["answer_type"] == "SPEECH" else False + if is_speech: + assert dec_input.dim() == 2 and dec_labels.dim() == 2 + if self.seq_pattern == "delay_parallel": + num_codebooks = dec_input.shape[0] + dec_input_padded = torch.cat( + [ + torch.zeros_like(dec_input[:, 0:num_codebooks]), + dec_input, + torch.zeros_like(dec_input[:, 0:num_codebooks]), + ], + dim=1, + ) + dec_labels_padded = torch.cat( + [ + torch.zeros_like(dec_labels[:, 0:num_codebooks]), + dec_labels, + torch.zeros_like(dec_labels[:, 0:num_codebooks]), + ], + dim=1, + ) + dec_input_new = [] + dec_labels_new = [] + for _c in range(self.num_speech_codebooks): + st = num_codebooks - _c + et_decoder_input = dec_input_padded.shape[1] - _c + et_decoder_labels = dec_labels_padded.shape[1] - _c + dec_input_new.append(dec_input_padded[_c, st:et_decoder_input]) + dec_labels_new.append(dec_labels_padded[_c, st:et_decoder_labels]) + dec_input = torch.stack(dec_input_new, dim=0) + dec_labels = torch.stack(dec_labels_new, dim=0) + dec_input_len = torch.tensor(dec_input.shape[1]).long() + dec_labels_len = torch.tensor(dec_labels.shape[1]).long() + + enc_len = context_tokens_len + question_tokens_len + virtual_tokens_len + # TODO: Remove hardcoding + num_question_offset = 4 # For "Text to Speech this" + + cross_attention_prior = torch.zeros(dec_labels_len, enc_len) + self.cross_attention_epsilon + if self.use_attention_prior: + cross_attention_question_prior = torch.from_numpy( + beta_binomial_prior_distribution( + question_tokens_len.item() - num_question_offset, + dec_labels_len.item(), + scaling_factor=self.attention_prior_scaling_factor, + ) + ) + cross_attention_prior[:, virtual_tokens_len + context_tokens_len + num_question_offset :] = ( + cross_attention_question_prior + ) + + return ( + taskname_id, + virtual_tokens, + virtual_tokens_len, + context_and_question_tokens, + context_tokens_len + question_tokens_len, + dec_input, + dec_input_len, + dec_labels, + dec_labels_len, + is_speech, + cross_attention_prior, + ) + else: + return None + + def _truncate_input_speech(self, context_tokens, question_tokens, virtual_tokens): + total_len = self._get_len(context_tokens, question_tokens, virtual_tokens) + context_len = self._get_element_len(context_tokens) + truncation_length = total_len - self.max_seq_length + 1 + context_tokens[0] = context_tokens[0][:, min(truncation_length, context_len) :] + return context_tokens, question_tokens, virtual_tokens + + def list_to_tensor(self, element, fill=False): + """ + Convert list to tensor. The list might contain integers, 2D-tensors (speech tokens) and combination of two. + If all of them are ints, simply convert to tensor + If combination of 2D-tensor and ints. Convert int to the dimension of the tensor. + example: [2, 4, 5] -> torch.tensor([2, 4, 5]) + example: [2, torch.tensor([[4, 5, 6], [6, 7, 8]])] -> torch.tensor( [[-1, 4, 5, 6], [2, 6, 7, 8]] ) + """ + ret, ln = None, None + if element is None: + return ret, ln + + max_len = max([1 if isinstance(item, int) else len(item) for item in element]) + if max_len == 1: + ret = torch.as_tensor(element).long() + ln = torch.tensor(ret.size()[0]).long() + else: + ret = [] + for e in element: + if isinstance(e, int): + tmp = torch.full((8, 1), e if fill else -1) + tmp[7] = e + else: + tmp = e + ret.append(tmp) + ret = torch.cat(ret, dim=1) + ln = torch.tensor(ret.size()[1]).long() + return ret, ln + + def _get_text_tokens(self, text): + input_ids = self.tokenizer.text_to_ids(text) + return input_ids + + def _get_phoneme_tokens(self, text): + input_ids = phoneme_tokenizer.encode(text) + input_ids_adjusted = [_id + self.lm_vocab_size for _id in input_ids] + return input_ids_adjusted + + def _pad_wav_to_multiple(self, wav): + if self.pad_multiple > 1: + if wav.shape[0] % self.pad_multiple != 0: + wav = torch.cat( + [wav, torch.zeros(self.pad_multiple - wav.shape[0] % self.pad_multiple, dtype=torch.float)] + ) + return wav + + def _get_element_len(self, element): + length = 0 + if isinstance(element, list): + for e in element: + if isinstance(e, int): + length += 1 + else: + if e.dim() > 1: + length += e.size()[1] + else: + length += e.size()[0] + else: + if element.dim() > 1: + length += element.size()[1] + else: + length += element.size()[0] + return length + + def _get_len(self, context_tokens, question_tokens, virtual_tokens): + length = 0 + length += self._get_element_len(context_tokens) + length += self._get_element_len(question_tokens) + length += self._get_element_len(virtual_tokens) + return length + + def _get_speech_tokens(self, field): + + # Convert to codes + codec_codes, codec_codes_length = None, None # Codes + + if self.train_task == 'tts': + if field == 'context': + self.ref_encodec = torch.load(io.BytesIO(self.ref_encodec), map_location="cpu").long() + codec_codes = self.ref_encodec + elif field == 'answer': + self.encodec = torch.load(io.BytesIO(self.encodec), map_location="cpu").long() + codec_codes = self.encodec + elif self.train_task == 'asr': + if field == 'context': + self.ref_encodec = torch.load(io.BytesIO(self.ref_encodec), map_location="cpu").long() + codec_codes = self.ref_encodec + + # codec_codes_length = torch.tensor(codec_codes.shape[1]).long() + + # Convert codes to codes corresponding to megatron embedding layer + codec_codes[0] = (codec_codes[0] + self.speech_offset).long() + + return codec_codes + + def _get_tokens(self, doc, field, field_data): + if f"{field}_type" not in doc.keys(): + field_tokens = self._get_text_tokens(field_data.strip(" ")) # list of ids + elif doc[f"{field}_type"] == 'TEXT': + _text = field_data.strip(" ") + if self.use_phoneme_tokenizer: + instruction_tokens = self._get_text_tokens("Phoneme TTS") + field_tokens = self._get_phoneme_tokens(_text.replace("Text to speech this ", "")) + field_tokens = instruction_tokens + field_tokens + else: + field_tokens = self._get_text_tokens(_text) # list of ids + elif doc[f"{field}_type"] == 'SPEECH': + field_tokens = self._get_speech_tokens(field) # list of ids + if not isinstance(field_tokens, list): + field_tokens = [field_tokens] + elif doc[f"{field}_type"] == 'TOKENS': + # Do nothing; already tokenized + field_tokens = field_data + else: + raise Exception(f"{field}_type not recognized") + return field_tokens + + def _insert_data_in_template(self, input_example, prompt_template_fields, doc, answer_field): + """Format the input example according to the template""" + out_dict = {} + for field in prompt_template_fields: + # discard the last one, {label} / {answer} + # Or if some fields from the template aren't present, e.g. {answer} during inference + # just remove that field from the template, leaving the space blank + if field == answer_field or field not in doc.keys(): + continue + # out_dict[field] = "" + + elif field in doc.keys(): + field_data = doc[field] + if f"{field}_type" not in doc.keys(): + doc[f"{field}_type"] = "TEXT" + raise Exception(f"{field}_type does not exist in doc") + else: + out_dict[field] = self._get_tokens(doc, field, field_data) + return out_dict + + def get_position_ids(self, virtual_token, context_and_qquestion): + enc_input = [] + enc_input.append(virtual_token) + if context_and_qquestion.dim() > 2: + enc_input.append(context_and_qquestion[:, 0, :]) + else: + enc_input.append(context_and_qquestion) + + enc_input = torch.cat(enc_input, dim=1) + + enc_input_p = enc_input[:, 0, :] if enc_input.dim() == 3 else enc_input + return build_position_ids(enc_input_p).contiguous() + + def collate_fn(self, batch): + """Prepares enc_input, dec_input, labels, loss_mask, enc_mask, dec_mask, position_ids, taskname_ids for global batch""" + + data_dict = self.pad_batch_and_build_loss_mask(batch) + + position_ids = self.get_position_ids(data_dict['virtual_tokens'], data_dict['context_and_question_tokens']) + + return ( + data_dict['virtual_tokens'], + data_dict['context_and_question_tokens'], + data_dict['enc_mask'], + data_dict['dec_input'], + data_dict['dec_input_mask'], + data_dict['dec_labels'], + data_dict['dec_labels_mask'], + position_ids, + data_dict['taskname_id'], + data_dict['speech_mask'], + data_dict['context_and_question_tokens_lens'], + data_dict['cross_attention_prior'], + ) + + def pad_batch_and_build_loss_mask(self, batch): + """Pad enc_input, dec_input, labels in batch to max batch length while building loss_mask, enc_mask, and dec_mask""" + ( + taskname_ids, + _, + virtual_tokens_len, + _, + context_and_question_tokens_len, + _, + dec_input_len, + _, + dec_labels_len, + _, + _, + ) = zip(*batch) + + taskname_ids = self.pad_taskname_ids(taskname_ids) + + max_virtual_tokens_len = max(virtual_tokens_len).item() if virtual_tokens_len is not None else 0 + if isinstance(virtual_tokens_len, tuple): + virtual_tokens_len = torch.stack(virtual_tokens_len) + virtual_mask = get_mask_from_lengths(virtual_tokens_len) + + max_context_and_question_tokens_len = ( + max(context_and_question_tokens_len).item() if context_and_question_tokens_len is not None else 0 + ) + if isinstance(context_and_question_tokens_len, tuple): + context_and_question_tokens_len = torch.stack(context_and_question_tokens_len) + context_and_question_mask = get_mask_from_lengths(context_and_question_tokens_len) + + max_dec_input_len = max(dec_input_len).item() if dec_input_len is not None else 0 + max_dec_labels_len = max(dec_labels_len).item() if dec_labels_len is not None else 0 + enc_mask = torch.cat([virtual_mask, context_and_question_mask], dim=1) + + ( + virtual_tokens_list, + context_question_tokens_list, + dec_input_list, + dec_input_mask_list, + dec_labels_list, + dec_labels_mask_list, + speech_mask_list, + cross_attention_prior_list, + ) = ( + [], + [], + [], + [], + [], + [], + [], + [], + ) + + for i, sample_tuple in enumerate(batch): + ( + _, + virtual_token, + virtual_token_len, + context_and_question_token, + context_and_question_token_len, + dec_input, + dec_input_len, + dec_label, + dec_label_len, + is_speech, + cross_attention_prior, + ) = sample_tuple + + virtual_tokens_list.append( + general_padding( + virtual_token, virtual_token_len.item(), max_virtual_tokens_len, pad_value=self.tokenizer.pad_id + ) + ) + + context_tokens_padded = general_padding( + context_and_question_token, + context_and_question_token_len.item(), + max_context_and_question_tokens_len, + pad_value=self.tokenizer.pad_id, + ) + if len(context_tokens_padded.shape) < 2: + context_tokens_padded = pad_text_to_speech_dims(context_tokens_padded, self.tokenizer.pad_id) + context_question_tokens_list.append(context_tokens_padded) + + if max_dec_input_len > 0: + dec_input_padded = general_padding( + dec_input, dec_input_len.item(), max_dec_input_len, pad_value=self.tokenizer.pad_id + ) + if len(dec_input_padded.shape) < 2: + dec_input_padded = pad_text_to_speech_dims(dec_input_padded, self.tokenizer.pad_id) + dec_input_list.append(dec_input_padded) + dec_mask = ( + torch.as_tensor(([1] * dec_input_len) + ([0] * (max_dec_input_len - dec_input_len))) + .long() + .contiguous() + ) + dec_input_mask_list.append(dec_mask) + speech_mask = dec_mask if is_speech else torch.zeros(dec_mask.shape) + speech_mask_list.append(speech_mask) + + if max_dec_labels_len > 0: + loss_mask = ( + torch.as_tensor(([1] * dec_label_len) + ([0] * (max_dec_labels_len - dec_label_len))) + .long() + .contiguous() + ) + dec_label_padded = general_padding( + dec_label, dec_label_len.item(), max_dec_labels_len, pad_value=self.tokenizer.pad_id + ) + if len(dec_label_padded.shape) < 2: + dec_label_padded = pad_text_to_speech_dims(dec_label_padded, self.tokenizer.pad_id) + dec_labels_list.append(dec_label_padded) + dec_labels_mask_list.append(loss_mask) + + _p0 = max_dec_labels_len - dec_label_len + _p1 = ( + max_virtual_tokens_len + + max_context_and_question_tokens_len + - context_and_question_token_len + - virtual_token_len + ) + + cross_attention_prior_padded = torch.nn.functional.pad( + cross_attention_prior, + pad=(0, _p1, 0, _p0), + mode="constant", + value=1, + ) + cross_attention_prior_list.append(cross_attention_prior_padded) + + data_dict = { + "taskname_id": taskname_ids, + "virtual_tokens": torch.stack(virtual_tokens_list), + "context_and_question_tokens": torch.stack(context_question_tokens_list), + "enc_mask": enc_mask, + "dec_input": torch.stack(dec_input_list) if len(dec_input_list) > 0 else None, + "dec_input_mask": torch.stack(dec_input_mask_list) if len(dec_input_mask_list) > 0 else None, + "dec_labels": torch.stack(dec_labels_list) if len(dec_labels_list) > 0 else None, + "dec_labels_mask": torch.stack(dec_labels_mask_list) if len(dec_labels_mask_list) > 0 else None, + "speech_mask": torch.stack(speech_mask_list) if len(speech_mask_list) > 0 else None, + "context_and_question_tokens_lens": context_and_question_tokens_len, + "cross_attention_prior": ( + torch.stack(cross_attention_prior_list) if len(cross_attention_prior_list) > 0 else None + ), + } + + return data_dict diff --git a/nemo/collections/tts/g2p/models/ctc.py b/nemo/collections/tts/g2p/models/ctc.py index 2e180e766211..1859b09594ff 100644 --- a/nemo/collections/tts/g2p/models/ctc.py +++ b/nemo/collections/tts/g2p/models/ctc.py @@ -19,8 +19,8 @@ import torch from hydra.utils import instantiate +from lightning.pytorch import Trainer from omegaconf import DictConfig, ListConfig, OmegaConf, open_dict -from pytorch_lightning import Trainer from transformers import AutoConfig, AutoModel, AutoTokenizer from nemo.collections.tts.g2p.data.ctc import CTCG2PBPEDataset @@ -101,11 +101,21 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): self.decoding = CTCBPEDecoding(self.cfg.decoding, tokenizer=self.tokenizer) - self.wer = WER(decoding=self.decoding, use_cer=False, log_prediction=False, dist_sync_on_step=True,) - self.per = WER(decoding=self.decoding, use_cer=True, log_prediction=False, dist_sync_on_step=True,) + self.wer = WER( + decoding=self.decoding, + use_cer=False, + log_prediction=False, + dist_sync_on_step=True, + ) + self.per = WER( + decoding=self.decoding, + use_cer=True, + log_prediction=False, + dist_sync_on_step=True, + ) def setup_grapheme_tokenizer(self, cfg): - """ Initialized grapheme tokenizer """ + """Initialized grapheme tokenizer""" if self.mode == "byt5": # Load appropriate tokenizer from HuggingFace @@ -315,7 +325,10 @@ def _setup_infer_dataloader(self, cfg: DictConfig) -> 'torch.utils.data.DataLoad ) @torch.no_grad() - def _infer(self, config: DictConfig,) -> List[int]: + def _infer( + self, + config: DictConfig, + ) -> List[int]: """ Runs model inference. diff --git a/nemo/collections/tts/g2p/models/heteronym_classification.py b/nemo/collections/tts/g2p/models/heteronym_classification.py index 54b9a8b07413..47d08eb16e17 100644 --- a/nemo/collections/tts/g2p/models/heteronym_classification.py +++ b/nemo/collections/tts/g2p/models/heteronym_classification.py @@ -19,8 +19,8 @@ import torch from hydra.utils import instantiate +from lightning.pytorch import Trainer from omegaconf import DictConfig -from pytorch_lightning import Trainer from nemo.collections.common.losses import CrossEntropyLoss from nemo.collections.nlp.metrics.classification_report import ClassificationReport @@ -113,9 +113,9 @@ def make_step(self, batch): def training_step(self, batch, batch_idx): """ - Lightning calls this inside the training loop with the data from the training dataloader - passed in as `batch`. - """ + Lightning calls this inside the training loop with the data from the training dataloader + passed in as `batch`. + """ loss, logits = self.make_step(batch) self.log('train_loss', loss) @@ -267,7 +267,11 @@ def disambiguate( item = {"text_graphemes": cur_sentence, "start_end": cur_start_ends, "heteronym_span": cur_heteronyms} f.write(json.dumps(item, ensure_ascii=False) + '\n') - all_preds = self._disambiguate(manifest=tmp_manifest, batch_size=batch_size, num_workers=num_workers,) + all_preds = self._disambiguate( + manifest=tmp_manifest, + batch_size=batch_size, + num_workers=num_workers, + ) if wordid_to_phonemes_file is not None: self.set_wordid_to_phonemes(wordid_to_phonemes_file) diff --git a/nemo/collections/tts/g2p/models/t5.py b/nemo/collections/tts/g2p/models/t5.py index 19f976081687..4c673b18dc4a 100644 --- a/nemo/collections/tts/g2p/models/t5.py +++ b/nemo/collections/tts/g2p/models/t5.py @@ -17,8 +17,8 @@ import torch from hydra.utils import instantiate +from lightning.pytorch import Trainer from omegaconf import DictConfig, OmegaConf -from pytorch_lightning import Trainer from transformers import AutoTokenizer, T5ForConditionalGeneration from nemo.collections.asr.metrics.wer import word_error_rate diff --git a/nemo/collections/tts/g2p/models/zh_cn_pinyin.py b/nemo/collections/tts/g2p/models/zh_cn_pinyin.py index 985897d8df3f..2fe0ac3f6077 100644 --- a/nemo/collections/tts/g2p/models/zh_cn_pinyin.py +++ b/nemo/collections/tts/g2p/models/zh_cn_pinyin.py @@ -93,7 +93,7 @@ def __init__( self.ascii_letter_dict = { x: ascii_letter_prefix + x for x in get_grapheme_character_set(locale="en-US", case=ascii_letter_case) } - self.ascii_letter_list = sorted(self.ascii_letter_dict) + self.ascii_letter_list = sorted(self.ascii_letter_dict.values()) self.ascii_letter_case = ascii_letter_case if apply_to_oov_word is None: @@ -181,6 +181,7 @@ def __call__(self, text: str) -> List[str]: `['wo3', 'jin1', 'tian1', 'qu4', 'le5', 'A', 'p', 'p', 'l', 'e', ' ', 'S', 't', 'o', 'r', 'e', ',', ' ', 'mai3', 'le5', 'yi2', 'ge4', 'i', 'P', 'h', 'o', 'n', 'e', '。']` """ + err = False text = set_grapheme_case(text, case=self.ascii_letter_case) pinyin_seq = [] @@ -201,7 +202,15 @@ def __call__(self, text: str) -> List[str]: tone_hyp = pinyin[-1] if tone_hyp in self.tone_dict: syllable = pinyin[:-1] - assert syllable in self.phoneme_dict, f"Syllable <{syllable}> does not exist in the dictionary." + # TODO: skipping the syllable that does not exist in the dictionary will lead to deletion errors in the + # synthesized speech. Even though this case is uncommon, it should be fixed in future. + if syllable not in self.phoneme_dict: + err = True + logging.error( + f"Syllable <{syllable}> does not exist in the dictionary. You should expect symbol " + f"deletion risks!!" + ) + continue phoneme_seq += self.phoneme_dict[syllable] phoneme_seq.append(self.tone_dict[tone_hyp]) # All pinyin would end up with a number in 1-5, which represents tones of the pinyin. @@ -211,4 +220,6 @@ def __call__(self, text: str) -> List[str]: phoneme_seq.append(self.ascii_letter_dict[tone_hyp]) else: phoneme_seq.append(pinyin) + if err: + logging.error(f"|{text}| contained unknown syllables") return phoneme_seq diff --git a/nemo/collections/tts/models/aligner.py b/nemo/collections/tts/models/aligner.py index d8e65d6e6821..5fea8615f7f2 100644 --- a/nemo/collections/tts/models/aligner.py +++ b/nemo/collections/tts/models/aligner.py @@ -18,9 +18,9 @@ import omegaconf import torch from hydra.utils import instantiate +from lightning.pytorch import Trainer +from lightning.pytorch.loggers import WandbLogger from omegaconf import DictConfig -from pytorch_lightning import Trainer -from pytorch_lightning.loggers import WandbLogger from torch import nn from nemo.collections.tts.losses.aligner_loss import BinLoss, ForwardSumLoss diff --git a/nemo/collections/tts/models/audio_codec.py b/nemo/collections/tts/models/audio_codec.py index 0c5e41157613..230a24e36cb0 100644 --- a/nemo/collections/tts/models/audio_codec.py +++ b/nemo/collections/tts/models/audio_codec.py @@ -21,8 +21,8 @@ import torch.nn.functional as F from einops import rearrange from hydra.utils import instantiate +from lightning.pytorch import Trainer from omegaconf import DictConfig, OmegaConf, open_dict -from pytorch_lightning import Trainer from nemo.collections.tts.losses.audio_codec_loss import ( FeatureMatchingLoss, diff --git a/nemo/collections/tts/models/fastpitch.py b/nemo/collections/tts/models/fastpitch.py index b1e702c89124..34213303abf4 100644 --- a/nemo/collections/tts/models/fastpitch.py +++ b/nemo/collections/tts/models/fastpitch.py @@ -18,9 +18,9 @@ import torch from hydra.utils import instantiate +from lightning.pytorch import Trainer +from lightning.pytorch.loggers import TensorBoardLogger from omegaconf import DictConfig, OmegaConf, open_dict -from pytorch_lightning import Trainer -from pytorch_lightning.loggers import TensorBoardLogger from nemo.collections.common.parts.preprocessing import parsers from nemo.collections.tts.losses.aligner_loss import BinLoss, ForwardSumLoss diff --git a/nemo/collections/tts/models/fastpitch_ssl.py b/nemo/collections/tts/models/fastpitch_ssl.py index fe743edf8783..f2384c41c5b5 100644 --- a/nemo/collections/tts/models/fastpitch_ssl.py +++ b/nemo/collections/tts/models/fastpitch_ssl.py @@ -16,9 +16,9 @@ import torch from hydra.utils import instantiate +from lightning.pytorch import Trainer +from lightning.pytorch.loggers import TensorBoardLogger from omegaconf import DictConfig -from pytorch_lightning import Trainer -from pytorch_lightning.loggers import TensorBoardLogger from nemo.collections.tts.losses.fastpitchloss import DurationLoss, MelLoss, PitchLoss from nemo.collections.tts.modules.fastpitch import FastPitchSSLModule, average_features @@ -34,7 +34,7 @@ class FastPitchModel_SSL(ModelPT): """ FastPitch based model that can synthesize mel spectrograms from content and speaker embeddings - obtained from SSLDisentangler. This model can be used for voice conversion by swapping the speaker embedding + obtained from SSLDisentangler. This model can be used for voice conversion by swapping the speaker embedding of a given source utterance, with the speaker embedding of a target speaker. """ @@ -133,9 +133,21 @@ def tb_logger(self): return self._tb_logger def forward( - self, *, enc_out=None, enc_mask=None, durs=None, pitch=None, pace=1.0, + self, + *, + enc_out=None, + enc_mask=None, + durs=None, + pitch=None, + pace=1.0, ): - return self.fastpitch(enc_out=enc_out, enc_mask=enc_mask, durs=durs, pitch=pitch, pace=pace,) + return self.fastpitch( + enc_out=enc_out, + enc_mask=enc_mask, + durs=durs, + pitch=pitch, + pace=pace, + ) def compute_encoding(self, content_embedding, speaker_embedding, dataset_id=None): # content embedding is (B, C, T) @@ -177,7 +189,11 @@ def training_step(self, batch, batch_idx): enc_mask = enc_mask[:, :, None] mels_pred, _, _, log_durs_pred, pitch_pred, pitch = self( - enc_out=enc_out, enc_mask=enc_mask, durs=durs, pitch=pitch, pace=1.0, + enc_out=enc_out, + enc_mask=enc_mask, + durs=durs, + pitch=pitch, + pace=1.0, ) loss = 0 @@ -208,7 +224,10 @@ def training_step(self, batch, batch_idx): ) spec_predict = mels_pred[0].data.cpu().float().numpy() self.tb_logger.add_image( - "train_mel_predicted", plot_spectrogram_to_numpy(spec_predict), self.global_step, dataformats="HWC", + "train_mel_predicted", + plot_spectrogram_to_numpy(spec_predict), + self.global_step, + dataformats="HWC", ) return loss @@ -286,7 +305,10 @@ def on_validation_epoch_end(self, outputs): ) spec_predict = spec_predict[_rand_idx].data.cpu().float().numpy() self.tb_logger.add_image( - "val_mel_predicted", plot_spectrogram_to_numpy(spec_predict), self.global_step, dataformats="HWC", + "val_mel_predicted", + plot_spectrogram_to_numpy(spec_predict), + self.global_step, + dataformats="HWC", ) if self.pitch_conditioning: @@ -321,10 +343,10 @@ def generate_wav( ): """ Args: - content_embedding : Content embedding from SSL backbone (B, C, T) + content_embedding : Content embedding from SSL backbone (B, C, T) speaker_embedding : Speaker embedding from SSL backbone (B, C) pitch_contour : Normalized Pitch contour derived from the mel spectrogram - encoded_len: Length of each content embedding, optional if batch size is 1. + encoded_len: Length of each content embedding, optional if batch size is 1. compute_pitch: if true, predict pitch contour from content and speaker embedding. compute_duration: if true, predict duration from content and speaker embedding. durs_gt: Ground truth duration of each content embedding, ignored if compute_duration is True. diff --git a/nemo/collections/tts/models/hifigan.py b/nemo/collections/tts/models/hifigan.py index 7a9a6d30671f..1a5462349c4d 100644 --- a/nemo/collections/tts/models/hifigan.py +++ b/nemo/collections/tts/models/hifigan.py @@ -18,8 +18,8 @@ import torch import torch.nn.functional as F from hydra.utils import instantiate +from lightning.pytorch.loggers.wandb import WandbLogger from omegaconf import DictConfig, OmegaConf, open_dict -from pytorch_lightning.loggers.wandb import WandbLogger from nemo.collections.tts.losses.hifigan_losses import DiscriminatorLoss, FeatureMatchingLoss, GeneratorLoss from nemo.collections.tts.models.base import Vocoder @@ -313,7 +313,7 @@ def stft(x): comp = torch.stft(x.squeeze(1), n_fft=1024, hop_length=256, win_length=1024, return_complex=True) comp = torch.view_as_real(comp) real, imag = comp[..., 0], comp[..., 1] - mags = torch.sqrt(real ** 2 + imag ** 2) + mags = torch.sqrt(real**2 + imag**2) phase = torch.atan2(imag, real) return mags, phase diff --git a/nemo/collections/tts/models/mixer_tts.py b/nemo/collections/tts/models/mixer_tts.py index c260df22e3c0..58b7f6f9706b 100644 --- a/nemo/collections/tts/models/mixer_tts.py +++ b/nemo/collections/tts/models/mixer_tts.py @@ -20,9 +20,9 @@ import transformers import wandb from hydra.utils import instantiate +from lightning.pytorch import Trainer +from lightning.pytorch.loggers import WandbLogger from omegaconf import DictConfig -from pytorch_lightning import Trainer -from pytorch_lightning.loggers import WandbLogger from torch import nn from torch.nn import functional as F from transformers import AlbertTokenizer diff --git a/nemo/collections/tts/models/radtts.py b/nemo/collections/tts/models/radtts.py index 82f85d1ed6a2..3f04f2ca3908 100644 --- a/nemo/collections/tts/models/radtts.py +++ b/nemo/collections/tts/models/radtts.py @@ -15,9 +15,9 @@ import torch from hydra.utils import instantiate +from lightning.pytorch import Trainer +from lightning.pytorch.loggers import TensorBoardLogger from omegaconf import DictConfig, OmegaConf -from pytorch_lightning import Trainer -from pytorch_lightning.loggers import TensorBoardLogger from nemo.collections.common.tokenizers.text_to_speech.tts_tokenizers import BaseTokenizer from nemo.collections.tts.losses.radttsloss import AttentionBinarizationLoss, RADTTSLoss diff --git a/nemo/collections/tts/models/spectrogram_enhancer.py b/nemo/collections/tts/models/spectrogram_enhancer.py index 65934d9a10ce..3644a77eb6fe 100644 --- a/nemo/collections/tts/models/spectrogram_enhancer.py +++ b/nemo/collections/tts/models/spectrogram_enhancer.py @@ -43,9 +43,9 @@ import torch.nn.functional as F from einops import rearrange from hydra.utils import instantiate +from lightning.pytorch import Trainer +from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger from omegaconf import DictConfig -from pytorch_lightning import Trainer -from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger from torch.utils.tensorboard.writer import SummaryWriter from nemo.collections.common.parts.utils import mask_sequence_tensor diff --git a/nemo/collections/tts/models/speechllm/__init__.py b/nemo/collections/tts/models/speechllm/__init__.py new file mode 100644 index 000000000000..9df65818d226 --- /dev/null +++ b/nemo/collections/tts/models/speechllm/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/tts/models/speechllm/megatron_base_speechllm_prompt_model.py b/nemo/collections/tts/models/speechllm/megatron_base_speechllm_prompt_model.py new file mode 100644 index 000000000000..658ace21726f --- /dev/null +++ b/nemo/collections/tts/models/speechllm/megatron_base_speechllm_prompt_model.py @@ -0,0 +1,444 @@ +# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re + +import torch +from lightning.pytorch.trainer.trainer import Trainer +from omegaconf.dictconfig import DictConfig +from torch import Tensor + +from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer +from nemo.collections.nlp.metrics.prompt_learning_metrics import AccuracyScore, BLEUScore, ROUGEScores +from nemo.collections.nlp.models.language_modeling.megatron_base_model import MegatronBaseModel +from nemo.collections.nlp.modules.common import ( + PromptEncoder, + PromptEncoderType, + VirtualPromptPlaceholderToken, + VirtualPromptSource, + VirtualPromptStyle, +) +from nemo.collections.nlp.modules.common.transformer.text_generation import TextGeneration +from nemo.collections.nlp.parts import utils_funcs +from nemo.utils import AppState + +try: + from apex.transformer.pipeline_parallel.utils import _reconfigure_microbatch_calculator + + HAVE_APEX = True + +except (ImportError, ModuleNotFoundError): + HAVE_APEX = False + +try: + from megatron.core import parallel_state + + HAVE_MEGATRON_CORE = True + +except (ImportError, ModuleNotFoundError): + + HAVE_MEGATRON_CORE = False + + +__all__ = ['MegatronBaseSpeechLM'] + + +class MegatronBaseSpeechLM(MegatronBaseModel, TextGeneration): + """ + Model class for prompt-tuning or p-tuning a pretrained Megatron model. + + Prompt Tuning initalizes virtual prompt embeddings directly from a copy of + certain token embeddings from the the pretrained model's vocabulary + and directly tunes these embedding weights. The token embeddings used in + initalization are specified by the user in the config file. The model can + be prompt-tuned for multiple tasks at once. virtual prompts are stored in a + prompt table and can be added or deleted without disrupting virtual prompts + for other tasks. + + P-tuning initializes an LSTM encoder model that generates virtual prompt + embeddings for every task. Each task shares the same encoder. After ptuning + is compelete, the learned virtual prompts can be saved to the prompt table + using add_ptuned_prompts_to_prompt_table(). Thus, if a user wants to add a + new virtual prompt via p-tuning, they do not need to retrain on all previous + tasks. This gives p-tuning the same task flexiblity as prompt-tuning. + """ + + def __init__(self, cfg: DictConfig, trainer: Trainer): + super().__init__(cfg, trainer) + self.init_model(cfg, trainer) + self.config = self.model_parallel_config + + def init_model(self, cfg: DictConfig, trainer: Trainer): + self.cfg = cfg + + self.load_frozen_model(cfg, trainer) + self.prompt_encoder = None + self.tokenizer = self.frozen_model.tokenizer + + if hasattr(self.frozen_model.cfg, "encoder") and hasattr(self.frozen_model.cfg, "decoder"): + self.hidden_size = ( + self.frozen_model.cfg.encoder.hidden_size + ) # Encoder and decoder need to have the same hidden size and we check for this in the frozen enc-dec model. + else: + self.hidden_size = self.frozen_model.cfg.hidden_size + + self.existing_tasks = list(self.cfg.get('existing_tasks', [])) + self.new_tasks = list(self.cfg.get('new_tasks', [])) + self.virtual_prompt_style = VirtualPromptStyle(cfg.virtual_prompt_style) + + # Load templates for assigning virtual prompt token positions + self.load_task_templates(self.cfg.task_templates) + + if self.first_stage_of_pipeline() and self.virtual_prompt_style in [ + VirtualPromptStyle.P_TUNING, + ]: + # TODO: Handle this when moving GPT prompt learning to the base class. + self.word_embeddings = self.frozen_model.enc_dec_model.encoder_embedding.word_embeddings + + # P-Tuning uses an LSTM Encoder to produce virtual token embeddings + if self.virtual_prompt_style == VirtualPromptStyle.P_TUNING: + self.virtual_prompt_source = VirtualPromptSource.PROMPT_ENCODER + elif self.virtual_prompt_style == VirtualPromptStyle.NO_PROMPT: + self.virtual_prompt_source = VirtualPromptSource.NO_PROMPT + else: + raise ValueError(f"\nvirtual prompt style '{cfg.virtual_prompt_style}'") + + self._reduced_loss_buffer = [] + self._inference_config = None + + # Prepare pseudo token ids for virtual/virtual prompt tokens + self.pseudo_tokens = get_pseudo_tokens(self.max_virtual_tokens) + if isinstance(self.tokenizer, SentencePieceTokenizer): + self.tokenizer.add_special_tokens(self.pseudo_tokens) + else: + self.tokenizer.add_special_tokens({'additional_special_tokens': self.pseudo_tokens}) + self.pseudo_token_ids = self.tokenizer.tokens_to_ids(self.pseudo_tokens) + self.pseudo_token_ids_start = self.pseudo_token_ids[0] if self.pseudo_token_ids else None + self.pad_token_id = self.tokenizer.pad_id if self.tokenizer.pad_id is not None else self.tokenizer.unk_id + self.decoder_seq_length = cfg.get('decoder_seq_length', 40) + + self.autocast_dtype = utils_funcs.torch_dtype_from_precision(self.cfg.precision) # Mixed precision datatype + # make sure the default pytorch lightning gradient clipping in the basemodel + self.grad_clip_pl_default = True + self.lowest_val_loss = None + self.prompt_encoder = None + + self.enable_autocast = not self.megatron_amp_O2 and self.autocast_dtype in [torch.float16, torch.bfloat16] + + # define validation metric + if self.cfg.get('report_validation_metric', False): + validation_metric = self.cfg.get('validation_metric', 'accuracy') + if validation_metric == 'accuracy': + self.validation_metric = AccuracyScore() + elif validation_metric == 'bleu': + self.validation_metric = BLEUScore() + elif validation_metric == 'rouge': + self.validation_metric = ROUGEScores() + + def load_task_templates(self, task_templates): + """ + Takes in the task template portion of the config and turns + it into a table where each task's prompt template and + the number of virtual tokens to insert in a given part of + the prompt template are specified. + """ + self.task_templates = {} + self.task_id_num_to_name = {} + self.max_virtual_tokens = 0 + + task_id_num = 0 + for task in task_templates: + self.task_templates[task.taskname] = { + "prompt_template": task.prompt_template, + "prompt_template_fields": re.findall("\{(.*?)\}", task.prompt_template), + "answer_only_loss": task.get("answer_only_loss", False), + "answer_field": task.get("answer_field", None), + "truncate_field": task.truncate_field, + "total_virtual_tokens": task.total_virtual_tokens, + "virtual_token_splits": task.virtual_token_splits, + "task_id_num": task_id_num, + } + + self.max_virtual_tokens = max(self.max_virtual_tokens, task.total_virtual_tokens) + self.task_id_num_to_name[task_id_num] = task.taskname + task_id_num += 1 + + # Check that all new tasks have the same total num virtual tokens + # Num virtual tokens for new tasks don't need to match num used for previously tuned tasks + if self.new_tasks: + new_task_name = self.new_tasks[0] + self.total_new_task_virtual_tokens = self.task_templates[new_task_name]["total_virtual_tokens"] + + assert all( + self.task_templates[taskname]["total_virtual_tokens"] == self.total_new_task_virtual_tokens + for taskname in self.new_tasks + ), "Total virtual tokens for each task tuned simultaneously must match. If you want to use a different number of virtual tokens for different tasks, tune them separately." + + def init_prompt_encoder(self): + """ + Init the prompt encoder needed for p-tuning on a new task + """ + # Total virtual tokens should be the same across all new tasks, so just need one + new_task = self.new_tasks[0] + total_virtual_tokens = self.task_templates[new_task]["total_virtual_tokens"] + + encoder_type = PromptEncoderType(self.cfg.p_tuning.get("encoder_type", "tpmlp").lower()) + self.prompt_encoder = PromptEncoder( + config=self.model_parallel_config, + encoder_type=encoder_type, + total_virtual_tokens=total_virtual_tokens, + token_dim=self.hidden_size, + hidden_size=self.cfg.p_tuning.get("encoder_hidden", self.hidden_size // 2), + lstm_dropout=self.cfg.p_tuning.get("dropout", 0.0), + num_layers=self.cfg.p_tuning.get("num_layers", 2), + init_std=self.cfg.p_tuning.get("init_std", 0.023), + taskname=new_task, + ) + + def freeze_existing_word_embeddings(self): + """Freeze params of existing virtual prompts that should not be tuned further""" + # Make sure word embeddings are frozen + for params in self.word_embeddings.parameters(): + params.requires_grad = False + + def state_dict(self): + """ + Custom state dict that only contains prompt table and prompt encoder parameters. + No frozen model parameters are stored in the state dict. Prompt encoder parameters + are only in state dict for intermediate checkpoints saved during training. Final + nemo checkpoints at the end of training will contain prompt table parameters only. + """ + state_dict_ = {} + state_dict_["frozen_model_enc_dec_model"] = self.frozen_model.enc_dec_model.state_dict() + state_dict_["word_embeddings"] = self.word_embeddings.state_dict() + if self.prompt_encoder is not None: + state_dict_["prompt_encoder"] = self.prompt_encoder.state_dict() + + return state_dict_ + + def load_state_dict(self, state_dict, strict: bool = True): + """ + Custom load state dict method that only loads prompt table and prompt encoder + parameters. Matching load method for this class' custom state dict method. + """ + self.init_prompt_encoder() + self.frozen_model.enc_dec_model.load_state_dict(state_dict["frozen_model_enc_dec_model"], strict) + self.word_embeddings.load_state_dict(state_dict["word_embeddings"], strict) + if 'prompt_encoder' in state_dict: + self.prompt_encoder.load_state_dict(state_dict["prompt_encoder"], strict) + + # Not sure why when we resume training the prompt encoder is on cpu + # Because it's not created on init - Should really be moved to init + self.prompt_encoder.to("cuda") + + def embed_input(self, input_ids: Tensor, taskname_ids: Tensor, use_cached_reps: bool): + """ + Replaces the virtual tokens in the input_ids with embeddings + calculated from either the 'prompt_table' or 'prompt_encoder'. + The virtual token placeholders have token_ids listed in + `self.pseudo_token_ids`. + + params: + input_ids: the input token ids + taskname_ids: the NLP task tag token ids + returns: + the token embedding for the LM model. + """ + # Replace virtual token ids with padding for forward pass through vocab embeddings + discrete_token_ids = input_ids.clone() + discrete_token_ids[(input_ids >= self.pseudo_token_ids_start)] = self.pad_token_id + discrete_token_embeds = self.word_embeddings(discrete_token_ids).clone() + + # Find the indicies where virtual tokens should be inserted + virtual_token_locations = input_ids >= self.pseudo_token_ids_start + + # If there are no virtual tokens, just return discrete token embeds + if not virtual_token_locations.any(): + return discrete_token_embeds + + if self.virtual_prompt_source == VirtualPromptSource.PROMPT_ENCODER: + # taskname_embeddings = self.word_embeddings(taskname_ids) + batch_size, _ = taskname_ids.size() + virtual_token_embeds = self.prompt_encoder(batch_size=batch_size, use_cached_reps=use_cached_reps) + else: + raise ValueError("invalid VirtualPromptSource.") + + # Create index template specifying where virtual token embeddings should be placed + batch_size, _, embedding_size = discrete_token_embeds.shape + virtual_token_index = virtual_token_locations.nonzero().reshape((batch_size, -1, 2))[:, :, 1][:, :, None] + virtual_token_index = virtual_token_index.expand( + batch_size, self.total_new_task_virtual_tokens, embedding_size + ) + + # Make sure discrete_token_embeds and virtual_token_embeds share the same dtype + discrete_token_embeds = discrete_token_embeds.type(virtual_token_embeds.dtype) + + # Insert virtual token embeddings where they belong amoung the discrete token embeddings + discrete_token_embeds.scatter_(1, virtual_token_index, virtual_token_embeds) + input_embeds = discrete_token_embeds + + return input_embeds + + def on_train_end(self): + # Save p-tuned prompts to prompt table for inference or future task training + self.save_to(save_path=self.cfg.nemo_path) + + def setup(self, stage=None): + if stage == 'predict' and self.first_stage_of_pipeline(): + return + + self.setup_test_data() + if stage == 'test': + return + + if self.first_stage_of_pipeline(): + if self.virtual_prompt_style == VirtualPromptStyle.P_TUNING: + if self.prompt_encoder is None: + self.init_prompt_encoder() + + self.setup_training_data() + self.setup_validation_data() + + def setup_training_data(self, training_data_config=None): + if self.cfg.data.get('train_ds', None): + self._train_ds, self._train_dl = self.build_virtual_prompt_dataset( + dataset_paths=self.cfg.data.train_ds, + batch_size=self.cfg.global_batch_size, + for_train=True, + drop_last=True, + shuffle=True, + num_workers=self.cfg.data.num_workers, + pin_memory=True, + ) + elif self.cfg.data.get('train_manifest', None): + self._train_ds, self._train_dl = self.build_virtual_prompt_tarred_dataset( + dataset_paths=self.cfg.data.train_manifest, + audio_path=self.cfg.data.train_audio_path, + batch_size=self.cfg.global_batch_size, + for_train=True, + drop_last=True, + shuffle=self.cfg.data.shuffle, + num_workers=self.cfg.data.num_workers, + pin_memory=True, + ) + + def setup_validation_data(self, validation_data_config=None): + if self.cfg.data.get('validation_ds', None): + self._validation_ds, self._validation_dl = self.build_virtual_prompt_dataset( + dataset_paths=self.cfg.data.validation_ds, + batch_size=self.cfg.get("validation_global_batch_size", self.cfg.global_batch_size), + for_train=True, + drop_last=self.cfg.get("validation_drop_last", True), + shuffle=False, + num_workers=self.cfg.data.num_workers, + pin_memory=True, + ) + elif self.cfg.data.get('validation_manifest', None): + self._validation_ds, self._validation_dl = self.build_virtual_prompt_tarred_dataset( + dataset_paths=self.cfg.data.validation_manifest, + audio_path=self.cfg.data.validation_audio_path, + batch_size=self.cfg.get("validation_global_batch_size", self.cfg.global_batch_size), + for_train=True, + drop_last=self.cfg.get("validation_drop_last", True), + shuffle=0, + num_workers=self.cfg.data.num_workers, + pin_memory=True, + ) + + def setup_test_data(self, test_data_config=None): + if self.cfg.data.get('test_ds', None): + self._test_ds, self._test_dl = self.build_virtual_prompt_dataset( + dataset_paths=self.cfg.data.test_ds, + batch_size=self.cfg.get("validation_global_batch_size", self.cfg.global_batch_size), + for_train=False, + drop_last=False, + shuffle=False, + num_workers=self.cfg.data.num_workers, + pin_memory=True, + ) + elif self.cfg.data.get('test_manifest', None): + self._test_ds, self._test_dl = self.build_virtual_prompt_tarred_dataset( + dataset_paths=self.cfg.data.test_manifest, + audio_path=self.cfg.data.test_audio_path, + batch_size=self.cfg.global_batch_size, + for_train=False, + drop_last=False, + shuffle=0, + num_workers=self.cfg.data.num_workers, + pin_memory=True, + ) + + def _reconfigure_and_process_inference_batch(self, global_batch_size_per_gpu, gbs): + # This should happen only on the last batch of the dataset. + if global_batch_size_per_gpu != gbs // parallel_state.get_data_parallel_world_size(): + # NOTE: This is reconfiguring to make sure there is no grad-acc for validation batches. + app_state = AppState() + _reconfigure_microbatch_calculator( + rank=app_state.global_rank, + rampup_batch_size=None, + global_batch_size=global_batch_size_per_gpu * parallel_state.get_data_parallel_world_size(), + micro_batch_size=global_batch_size_per_gpu, + data_parallel_size=parallel_state.get_data_parallel_world_size(), + ) + + def _reconfigure_batch_sizes(self, gbs: int, mbs: int): + app_state = AppState() + _reconfigure_microbatch_calculator( + rank=app_state.global_rank, + rampup_batch_size=None, + global_batch_size=gbs, + micro_batch_size=mbs, + data_parallel_size=parallel_state.get_data_parallel_world_size(), + ) + + def set_inference_config(self, inference_config): + self._inference_config = inference_config + + def get_inference_config(self): + return self._inference_config + + def set_input_tensor(self, input_tensor): + pass + + def first_stage_of_pipeline(self): + pass + + @classmethod + def list_available_models(cls): + pass + + def load_frozen_model(self, cfg, trainer): + pass + + +def get_pseudo_tokens(num_virtual_tokens): + """ + Takes in an integer and returns a list of strings where each string + is a numbered virtual token placeholder. If + num_virtual_tokens = 3, then this function returns: + + ["", "", ""] + + Args: + num_virtual_tokens: (int) Number of virtual token strings you want to make + + returns a list of string. + + """ + pseudo_tokens = [ + VirtualPromptPlaceholderToken.BASE.value + str(i) + VirtualPromptPlaceholderToken.END.value + for i in range(num_virtual_tokens) + ] + + return pseudo_tokens diff --git a/nemo/collections/tts/models/speechllm/megatron_t5_speechllm_model.py b/nemo/collections/tts/models/speechllm/megatron_t5_speechllm_model.py new file mode 100644 index 000000000000..d35d53b3cac7 --- /dev/null +++ b/nemo/collections/tts/models/speechllm/megatron_t5_speechllm_model.py @@ -0,0 +1,2672 @@ +# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools +import json +import os +import random +import string +from functools import partial +from typing import Any, List + +import editdistance +import imageio +import numpy as np +import soundfile as sf +import torch +from lightning.pytorch.trainer.trainer import Trainer +from omegaconf import OmegaConf +from omegaconf.dictconfig import DictConfig +from omegaconf.omegaconf import open_dict + +import nemo.collections.asr as nemo_asr +from nemo.collections.asr.metrics.wer import word_error_rate +from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceSpeechLLMTTSTokenizer +from nemo.collections.nlp.models.language_modeling.megatron_t5_model import MegatronT5Model +from nemo.collections.nlp.models.language_modeling.megatron_t5_sft_model import MegatronT5SFTModel +from nemo.collections.nlp.modules.common.megatron.token_level_encoder_decoder import ( + MegatronTokenLevelEncoderDecoderSpeechLLMModule, +) +from nemo.collections.nlp.modules.common.megatron.utils import ( + average_losses_across_data_parallel_group, + get_iterator_k_split, + init_method_normal, +) +from nemo.collections.nlp.parts.nlp_overrides import NLPSaveRestoreConnector +from nemo.collections.nlp.parts.utils_funcs import get_last_rank +from nemo.collections.tts.data.speechllm.t5_speechllm_dataset import Lang, T5SpeechLMDataset +from nemo.collections.tts.data.speechllm.t5_speechllm_tarred_dataset import T5SpeechLMTarredDataset +from nemo.collections.tts.losses.aligner_loss import ForwardSumLoss +from nemo.collections.tts.models import AudioCodecModel +from nemo.collections.tts.models.speechllm.megatron_base_speechllm_prompt_model import MegatronBaseSpeechLM +from nemo.collections.tts.parts.utils.helpers import plot_alignment_to_numpy_for_speechllm, plot_codec_to_numpy +from nemo.utils import AppState, logging + +try: + from apex.transformer.pipeline_parallel.utils import get_micro_batch_size, get_num_microbatches + + HAVE_APEX = True + +except (ImportError, ModuleNotFoundError): + + HAVE_APEX = False + +try: + from megatron.core import parallel_state, tensor_parallel + from megatron.core.enums import ModelType + from megatron.core.pipeline_parallel.schedules import get_forward_backward_func + + HAVE_MEGATRON_CORE = True + +except (ImportError, ModuleNotFoundError): + + HAVE_MEGATRON_CORE = False + + +import time + +import librosa +from torchaudio.pipelines import SQUIM_SUBJECTIVE +from transformers import Wav2Vec2FeatureExtractor, WavLMForXVector + +__all__ = ['MegatronT5SpeechLMModel'] + + +class MegatronT5OverrideModel(MegatronT5Model): + def _build_tokenizer(self): + if self._cfg.tokenizer.library == "sentencepiece": + if hasattr(self._cfg.tokenizer, "sentencepiece_legacy"): + legacy = self._cfg.tokenizer.sentencepiece_legacy + else: + legacy = True if self._cfg.tokenizer.library == 'sentencepiece' else False + self.tokenizer = SentencePieceSpeechLLMTTSTokenizer( + model_path=self.register_artifact("tokenizer.model", self._cfg.tokenizer.get('model', None)), + legacy=legacy, + ) + + if self._cfg.tokenizer.get('additional_special_tokens', None) is not None: + tokens_list = OmegaConf.to_object(self._cfg.tokenizer.additional_special_tokens) + self.tokenizer.add_special_tokens(tokens_list) + else: + super()._build_tokenizer() + + def model_provider_func(self, pre_process, post_process, add_encoder, add_decoder): + if not hasattr(self.cfg, 'encoder') or not hasattr(self.cfg, 'decoder'): + logging.warning( + 'Could not find encoder or decoder in config. This is probably because of restoring an old checkpoint. Copying shared model configs to encoder and decoder configs.' + ) + # After the call below, self.cfg.encoder and self.cfg.decoder will be populated with the cfg.model configs from old checkpoints. + self._populate_encoder_decoder_configs_for_backward_compatibility(self.cfg) + + if parallel_state.get_pipeline_model_parallel_world_size() > 1 and self.cfg.encoder.arch == 'perceiver': + raise ValueError(f"Perceivers with pipeline parallel > 1 is not supported yet.") + + if not hasattr(self.cfg, 'embedding_init_method_std'): + embedding_init_method_std = self.cfg.encoder.init_method_std + else: + embedding_init_method_std = self.cfg.embedding_init_method_std + + if not hasattr(self.cfg, 'embedding_dropout'): + embedding_dropout = self.cfg.encoder.hidden_dropout + else: + embedding_dropout = self.cfg.embedding_dropout + + model = MegatronTokenLevelEncoderDecoderSpeechLLMModule( + config=self.model_parallel_config, + encoder_cfg=self.cfg.encoder, + decoder_cfg=self.cfg.decoder, + vocab_size=self.padded_vocab_size, + max_position_embeddings=self.cfg.max_position_embeddings, + num_tokentypes=0, + parallel_output=True, + pre_process=pre_process, + post_process=post_process, + fp16_cross_entropy=self.cfg.get('fp16_lm_cross_entropy', False), + precision=self.cfg.get('precision', 16), + embedding_init_method_std=embedding_init_method_std, + embedding_dropout=embedding_dropout, + label_smoothing=self.cfg.get('label_smoothing', 0.0), + add_encoder=add_encoder, + add_decoder=add_decoder, + share_token_embeddings=self.cfg.get('share_token_embeddings', True), + share_decoder_tokens_head_embeddings=self.cfg.get('share_decoder_tokens_head_embeddings', True), + tokens_head_bias=self.cfg.get('tokens_head_bias', True), + hiddens_cfg=self.cfg.get('hiddens', None), + ) + return model + + +class MegatronT5SpeechLMModel(MegatronBaseSpeechLM): + """ + Model class for prompt-tuning or p-tuning a pretrained Megatron T5 model. + + Prompt Tuning initializes virtual prompt embeddings directly from a copy of + certain token embeddings from the pretrained T5 model's vocabulary + and directly tunes these embedding weights. The token embeddings used in + initialization are specified by the user in the config file. The model can + be prompt-tuned for multiple tasks at once. Virtual prompts are stored in a + prompt table and can be added or deleted without disrupting virtual prompts + for other tasks. + + P-tuning initializes an LSTM encoder model that generates virtual prompt + embeddings for every task. Each task shares the same encoder. After p-tuning + is complete, the learned virtual prompts can be saved to the prompt table + using add_ptuned_prompts_to_prompt_table(). Thus, if a user wants to add a + new virtual prompt via p-tuning, they do not need to retrain on all previous + tasks. This gives p-tuning the same task flexibility as prompt-tuning. + """ + + def __init__(self, cfg: DictConfig, trainer: Trainer): + super().__init__(cfg, trainer) + self.model_type = ModelType.encoder_and_decoder + speech_codebook_size = cfg.data.get('speech_codebook_size', 1024) + num_speech_codebooks = cfg.data.get('num_speech_codebooks', 8) + speech_offset = cfg.data.get('speech_offset', 30000) + codecmodel_type = cfg.get('codecmodel_type', 'nemo_codec') + attn_prior_scaledown_start_step = cfg.get('attn_prior_scaledown_start_step', 10000) + attn_prior_end_step = cfg.get('attn_prior_end_step', 11000) + num_cross_attention_heads = cfg.get('num_cross_attention_heads', 12) + self.lm_vocab_size = cfg.get('lm_vocab_size', 30000) + self.context_pattern = cfg.data.get('context_pattern', 'parallel') + self.context_conditioning = cfg.get('context_conditioning', "decoder") + self.context_duration_min = cfg.data.get('context_duration_min', 2.9) + self.context_duration_max = cfg.data.get('context_duration_max', 2.9) + self.codebook_fps = cfg.data.get('codebook_fps', 86) + self.decoder_context_len = 0 + if self.context_conditioning == "decoder": + assert self.context_duration_min == self.context_duration_max, "Decoder context duration must be fixed" + self.decoder_context_len = int(self.codebook_fps * self.context_duration_min) + + self.speech_offset = speech_offset + self.speech_codebook_size = speech_codebook_size + self.num_speech_codebooks = num_speech_codebooks + self.codecmodel_type = codecmodel_type + self.enc_output_to_layers = cfg.get('enc_output_to_layers', None) + if self.enc_output_to_layers is not None: + # Convert from listconfig to list + self.enc_output_to_layers = [[l for l in encoder_layer] for encoder_layer in self.enc_output_to_layers] + + self.frozen_model.enc_dec_model.speech_offset = speech_offset + self.frozen_model.enc_dec_model.speech_codebook_size = speech_codebook_size + self.frozen_model.enc_dec_model.num_speech_codebooks = num_speech_codebooks + self.frozen_model.enc_dec_model.seq_pattern = cfg.get('seq_pattern', 'parallel') + self.frozen_model.enc_dec_model.attn_prior_scaledown_start_step = attn_prior_scaledown_start_step + self.frozen_model.enc_dec_model.attn_prior_end_step = attn_prior_end_step + self.frozen_model.enc_dec_model.alignment_decoder_layerids = cfg.get( + 'alignment_decoder_layerids', list(range(0, 12)) + ) + self.frozen_model.enc_dec_model.return_all_crossattention_probs = cfg.get( + 'return_all_crossattention_probs', False + ) + self.frozen_model.enc_dec_model.num_cross_attention_heads = num_cross_attention_heads + self.frozen_model.enc_dec_model.context_conditioning = self.context_conditioning + self.frozen_model.enc_dec_model.decoder_context_len = self.decoder_context_len + self.frozen_model.enc_dec_model.enc_output_to_layers = self.enc_output_to_layers + + self.alignment_loss_start_step = 0 + self.alignment_loss_end_step = float('inf') + self.use_alignment_loss = cfg.get('use_alignment_loss', False) + if self.use_alignment_loss: + alignment_loss_scale = cfg.get('alignment_loss_scale', 1.0) + self.frozen_model.enc_dec_model.use_alignment_loss = True + self.frozen_model.enc_dec_model.forward_sum_loss = ForwardSumLoss(loss_scale=alignment_loss_scale) + self.frozen_model.enc_dec_model.alignment_text_end_offset = cfg.get('alignment_text_end_offset', 0) + self.frozen_model.enc_dec_model.align_every_n_head = cfg.get('align_every_n_head', 1) + self.alignment_loss_start_step = cfg.get('alignment_loss_start_step', 0) + self.alignment_loss_end_step = cfg.get('alignment_loss_end_step', float('inf')) + + # Need to explicitly set this since it is already initialized + self.frozen_model.enc_dec_model.tokens_head.parallel_output = self.frozen_model.enc_dec_model.parallel_output + + list_of_speech_heads = [] + list_of_speech_tokens_embeddings = [] + for _ in range(self.num_speech_codebooks - 1): + # init is NOT used since we overwrite the weight below anyways + _speech_head_embedding = tensor_parallel.VocabParallelEmbedding( + speech_codebook_size, + embedding_dim=self.word_embeddings.embedding_dim, + init_method=lambda x: x.data.fill_(0), + config=self.model_parallel_config, + ) + _speech_head_embedding.weight.data.fill_(0) + _speech_head_embedding.shared = True + list_of_speech_tokens_embeddings.append(_speech_head_embedding) + # Linear layer that maps from hidden size to speech codebook size + hidden_size = self.frozen_model.enc_dec_model.decoder_cfg.hidden_size + init_method_std = self.frozen_model.enc_dec_model.decoder_cfg.init_method_std + # Changing to ColumnParallelLinear instead of Linear to support 3b Tensor Parallelism + _speech_head = tensor_parallel.ColumnParallelLinear( + input_size=hidden_size, + output_size=speech_codebook_size, + bias=True, + gather_output=not self.frozen_model.enc_dec_model.parallel_output, + init_method=init_method_normal(init_method_std), + config=self.model_parallel_config, + ) + list_of_speech_heads.append(_speech_head) + + self.frozen_model.enc_dec_model.speech_tokens_heads = torch.nn.ModuleList(list_of_speech_heads) + self.frozen_model.enc_dec_model.speech_tokens_embeddings = torch.nn.ModuleList( + list_of_speech_tokens_embeddings + ) + + self.sample_rate = 24000 + if codecmodel_type == 'nemo_codec': + codec_model = AudioCodecModel.restore_from(cfg.get('codecmodel_path')) + codec_model.to('cuda') + codec_model.eval() + self.sample_rate = 22050 + else: + raise NotImplementedError() + + self.additional_models = {'codec': codec_model} + self.train_check_interval = self.cfg.get('train_check_interval', 500) + self.plot_alignments_sliced = self.cfg.get('plot_alignments_sliced', True) + app_state = AppState() + self.is_rank_zero = app_state.global_rank == 0 + self.predict_step_outputs = [] + self.phoneme_tokenizer = None + + # classifier-free guidance (CFG) option during training. The probability (0.0 <= ε <= 1.0) is used to trigger the action that the + # text or audio tokens in a batch are replaced by [UNK], such that mimicking the text- or audio-free scenario. + # If a random number is greater than ε, then keep text or audio tokens as-is, otherwise, the text or audio tokens are + # replaced by [UNK]. Default to 0.0, meaning CFG is disabled. + self.train_text_cfg_prob = cfg.get('train_text_cfg_prob', 0.0) + self.train_audio_cfg_prob = cfg.get('train_audio_cfg_prob', 0.0) + self._rng = random.Random() + + # control the strength of the classifier guidance during inference, Logits_cfg = w*Logits_cond + (1-w)*Logits_uncond, + # equivalent to Logits_cfg = Logits_cond + alpha*(Logits_cond - Logits_uncond) where alpha=w-1. + # Default w to 1.O, indicating no interpolation is applied. + self.inference_cfg_interpolation_scale = cfg.get('inference_cfg_interpolation_scale', 1.0) + self.inference_apply_text_cfg = cfg.get('inference_apply_text_cfg', False) + self.inference_apply_audio_cfg = cfg.get('inference_apply_audio_cfg', False) + if self.inference_cfg_interpolation_scale == 1.0: + self.inference_apply_text_cfg = False + self.inference_apply_audio_cfg = False + + # whether to apply cfg filter to address faster speech rate. + self.inference_apply_cfg_filter = cfg.get("inference_apply_cfg_filter", False) + + # this scale is suggested to be smaller than `self.question_guidance_scale` and it is used to balance the weights + # between the conditioned logits after applying cfg filter and the original unconditioned logits. Default to 1.0, + # indicating only conditioned logits are used. + if not self.inference_apply_cfg_filter: + self.inference_cfg_filter_interpolation_scale = None + else: + self.inference_cfg_filter_interpolation_scale = cfg.get('inference_cfg_filter_interpolation_scale', 1.0) + + # whether to estimate MOS in predict_step. + self.estimate_mos = cfg.get('estimate_mos', True) + if self.estimate_mos: + # requires to specify a non-matching high-quality and clean reference audio file. It is used to estimate MOS. + self.non_matching_ref_audio_filepath = cfg.get('non_matching_ref_audio_filepath', None) + if self.non_matching_ref_audio_filepath is None: + raise ValueError( + f"Please provide a high-quality reference audio to estimate the MOS. Alternatively, " + f"set `model.estimate_mos=False` to disable MOS estimation." + ) + if not os.path.exists(self.non_matching_ref_audio_filepath): + raise FileNotFoundError( + f"Please provide a valid file path for a high-quality reference audio to estimate" + f" the MOS. Alternatively, set `model.estimate_mos=False` to disable MOS estimation." + ) + + def decode_wav_from_codec_model(self, codes): + codec_model = self.additional_models['codec'] + if self.codecmodel_type == 'nemo_codec': + codec_len = torch.Tensor([codes.shape[1]]).long().cuda() + if codec_len < 10: + # return a one-second silence + return torch.zeros(24000).cuda() + wav, _ = codec_model.decode(tokens=codes.unsqueeze(0), tokens_len=codec_len) + wav = wav[0] + else: + raise NotImplementedError() + return wav + + def first_stage_of_pipeline(self): + if self.frozen_model.enc_dec_model.pre_process and parallel_state.get_pipeline_model_parallel_rank() == 0: + return True + return False + + def forward( + self, + virtual_tokens, + context_and_question_tokens, + enc_mask, + dec_input, + dec_mask, + position_ids, + taskname_ids, + labels=None, + speech_mask=None, + inference=False, + inference_step=0, + cross_attention_prior=None, + text_limits=None, + decoder_max_sequence_len=None, + encoder_max_sequence_len=None, + ): + """ + Special forward method for p-tuning/prompt-tuning pretrained + T5 style models. + """ + if isinstance(context_and_question_tokens, list): + multi_encoder = True + assert isinstance(enc_mask, list) + assert isinstance(position_ids, list) + if cross_attention_prior is None: + cross_attention_prior = [None for _ in range(len(context_and_question_tokens))] + assert isinstance(cross_attention_prior, list) + assert len(context_and_question_tokens) == len(enc_mask) == len(position_ids) == len(cross_attention_prior) + else: + multi_encoder = False + context_and_question_tokens = [context_and_question_tokens] + enc_mask = [enc_mask] + position_ids = [position_ids] + cross_attention_prior = [cross_attention_prior] + + enc_output = None + logging.debug( + f"self.first_stage_of_pipeline()={self.first_stage_of_pipeline()}\tinference_step={inference_step}" + ) + if self.first_stage_of_pipeline() and inference_step == 0: + # Get embeddings for text tokens and insert virtual token embeddings + encoder_input_list = [] + for ei in range(len(context_and_question_tokens)): + input_embeds = self.get_embeddings_and_combine( + [virtual_tokens, context_and_question_tokens[ei]], taskname_ids, inference + ) + # TODO: This check needs to be revisited with PP support. + if hasattr(self.frozen_model.enc_dec_model.encoder_embedding, 'position_embeddings'): + position_embeddings = self.frozen_model.enc_dec_model.encoder_embedding.position_embeddings( + position_ids[ei] + ) + encoder_input = input_embeds + position_embeddings + else: + encoder_input = input_embeds + encoder_input_list.append(encoder_input) + else: + encoder_input_list = None + encoder_input = None + if inference_step != 0: + enc_output = context_and_question_tokens if multi_encoder else context_and_question_tokens[0] + + # If the decoder input starts with instead of , which is the case for huggingface T5 models, we don't want to mask the first token. + # For NeMo-Megatron, the sequence starts with , which is never masked so we can always set index 0 to be unmasked. + dec_mask[:, 0] = 1 + + if not self.cfg.data.get('use_attention_prior', False): + cross_attention_prior = [None for _ in range(len(cross_attention_prior))] + + _encoder_input = encoder_input_list + if not multi_encoder: + enc_mask = enc_mask[0] + cross_attention_prior = cross_attention_prior[0] + _encoder_input = encoder_input_list[0] if encoder_input_list is not None else None + + # Call forward on T5 model with preprocessed embeddings + if inference and inference_step == 0: + set_inference_key_value_memory = True + else: + set_inference_key_value_memory = False + + if self.autocast_dtype == torch.float32: + output, out_logits = self.frozen_model.enc_dec_model( + enc_input_ids=None, + enc_attn_mask=enc_mask, + dec_input_ids=dec_input, + dec_attn_mask=dec_mask, + token_type_ids=None, + labels=labels, + output_enc_hidden_only=False, + enc_input=_encoder_input, + enc_output=enc_output, + speech_mask=speech_mask, + cross_attention_prior=cross_attention_prior, + text_limits=text_limits, + global_step=self.global_step, + set_inference_key_value_memory=set_inference_key_value_memory, + decoder_max_sequence_len=decoder_max_sequence_len, + encoder_max_sequence_len=encoder_max_sequence_len, + ) + else: + with torch.autocast(device_type="cuda", dtype=self.autocast_dtype): + output, out_logits = self.frozen_model.enc_dec_model( + enc_input_ids=None, + enc_attn_mask=enc_mask, + dec_input_ids=dec_input, + dec_attn_mask=dec_mask, + token_type_ids=None, + labels=labels, + output_enc_hidden_only=False, + enc_input=_encoder_input, + enc_output=enc_output, + speech_mask=speech_mask, + cross_attention_prior=cross_attention_prior, + text_limits=text_limits, + global_step=self.global_step, + set_inference_key_value_memory=set_inference_key_value_memory, + decoder_max_sequence_len=decoder_max_sequence_len, + encoder_max_sequence_len=encoder_max_sequence_len, + ) + + return output, encoder_input, out_logits + + def load_frozen_model(self, cfg, trainer): + self.megatron_amp_O2 = cfg.get('megatron_amp_o2', False) + + # TODO: Fix this once apex patches FusedScaledMaskedSoftmax. + # This is a workaround for the fact that `masked_softmax_fusion` has issues with certain input sizes that may be present while finetuning. + cfg_language_model_path = cfg.get('language_model_path', None) + cfg_frozen_model = cfg.get('frozen_model', None) + if not (bool(cfg_language_model_path) ^ bool(cfg_frozen_model)): + raise ValueError( + "T5-TTS requires either 'language_model_path' or 'frozen_model' in its config, but not both." + ) + + if cfg_language_model_path: + t5_cfg = MegatronT5Model.restore_from(cfg_language_model_path, trainer=trainer, return_config=True) + else: + t5_cfg = cfg_frozen_model + + OmegaConf.set_struct(t5_cfg, True) + with open_dict(t5_cfg): + if hasattr(t5_cfg, 'encoder') and hasattr(t5_cfg, 'decoder'): + t5_cfg.encoder.masked_softmax_fusion = False + t5_cfg.decoder.masked_softmax_fusion = False + else: + t5_cfg.masked_softmax_fusion = False + t5_cfg.megatron_amp_O2 = self.megatron_amp_O2 + # hack to make the _GLOBAL_NUM_MICROBATCHES_CALCULATOR initialize + t5_cfg.micro_batch_size = cfg.get('micro_batch_size', 4) + t5_cfg.global_batch_size = cfg.get('global_batch_size', 4) + t5_cfg.precision = trainer.precision + t5_cfg.tokenizer.num_sentinel_tokens = cfg.get('num_sentinel_tokens', 39184 - 29056) + t5_cfg.seq_length = cfg.data.max_seq_length + if cfg.get('max_position_embeddings', None) is None: + t5_cfg.max_position_embeddings = cfg.data.max_seq_length + else: + t5_cfg.max_position_embeddings = cfg.get('max_position_embeddings') + t5_cfg.use_flash_attention = cfg.get('use_flash_attention', False) + if cfg.get('override_token_model', None): + t5_cfg.tokenizer.model = cfg['override_token_model'] + if cfg.get('override_tokenizer_vocab_file', None): + t5_cfg.tokenizer.vocab_file = cfg['override_tokenizer_vocab_file'] + + if cfg.get('train_from_scratch', False): + print("Training from scratch!") + # Defaults for 220m model + # To override any of these, add +model.override_= to the config file. + # Eg. +model.override_hidden_size=1024 + overide_keys = [ + 'hidden_size', # 768 + 'num_layers', # 12 + 'num_attention_heads', # 12 + 'hidden_dropout', # 0.1 + 'attention_dropout', # 0.1 + 'kv_channels', # 64 + 'ffn_hidden_size', # 2048 + ] + # Defaults for 220m model + for k in overide_keys: + if cfg.get(f'override_{k}') is not None: + t5_cfg[k] = cfg.get(f'override_{k}') + + self.frozen_model = MegatronT5OverrideModel(t5_cfg, trainer=trainer) + num_params = sum(p.numel() for p in self.frozen_model.parameters() if p.requires_grad) + print(f"Number of parameters: {num_params}") + else: + print(f"Loading from pretrained checkpoint: {cfg_language_model_path}") + if cfg_language_model_path is None: + raise ValueError( + "T5-TTS SFT on pretrained model checkpoint requires `langauge_model_path` in its config." + ) + + self.frozen_model = MegatronT5OverrideModel.restore_from( + cfg_language_model_path, + trainer=trainer, + override_config_path=t5_cfg, + save_restore_connector=NLPSaveRestoreConnector(), + ) + + if not cfg.get('english_only_model', False): + self.frozen_model.tokenizer.add_phone_tokens_to_special_tokens() + + logging.info(f"self.frozen_model {self.frozen_model}") + + def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only): + """ + Dataloader produces a global batch which is turned into a list of microbatches. + The list of microbatches is then piped through the pipeline using megatron-core fwd/bwd functions. + """ + # Get seq length of batch + batch = next(dataloader_iter) + _, seq_length = batch[0].shape + if batch[4].dim() > 2: + _, _, dec_seq_length = batch[4].shape + else: + _, dec_seq_length = batch[4].shape + data_iter = get_iterator_k_split(batch, get_num_microbatches()) + + fwd_bwd_function = get_forward_backward_func() + + losses_reduced_per_micro_batch = fwd_bwd_function( + forward_step_func=self.get_forward_output_and_loss_func(forward_only), + data_iterator=data_iter, + model=[self], + num_microbatches=get_num_microbatches(), + forward_only=forward_only, + seq_length=seq_length, + micro_batch_size=get_micro_batch_size(), + decoder_seq_length=dec_seq_length, + ) + + # only the last stages of the pipeline return losses + if losses_reduced_per_micro_batch: + # average loss across micro batches + loss_tensors_list = [loss_reduced['avg'] for loss_reduced in losses_reduced_per_micro_batch] + loss_tensor = torch.concat(loss_tensors_list) + loss_mean = loss_tensor.mean() + else: + # we're not on the last pipeline stage so no losses + loss_mean = torch.tensor(0.0).cuda() + + return loss_mean + + def convert_tokens_to_range(self, tokens, apply_offset_correction=True, pattern=None): + # convert tokens to range [0, 1024] + output_tokens = tokens.clone() + if apply_offset_correction: + output_tokens[0] = output_tokens[0] - self.speech_offset + output_tokens = torch.clamp(output_tokens, min=0, max=self.speech_codebook_size - 1) + if pattern is None: + pattern = self.cfg.get('seq_pattern', 'delay_parallel') + if pattern == "delay_parallel": + output_tokens_new = [] + for _c in range(output_tokens.shape[0]): + si = _c + ei = _c + output_tokens.shape[1] - self.num_speech_codebooks + output_tokens_new.append(output_tokens[_c, si:ei]) + output_tokens_new = torch.stack(output_tokens_new) + output_tokens = output_tokens_new + + return output_tokens + + def get_forward_output_and_loss_func(self, validation_step=False): + def fwd_output_and_loss_func(dataloader_iter, model): + batch = next(dataloader_iter) + _batch = [] + for x in batch: + if isinstance(x, torch.Tensor): + x = x.cuda(non_blocking=True) + elif isinstance(x, list): + if isinstance(x[0], torch.Tensor): + x = [y.cuda(non_blocking=True) for y in x] + _batch.append(x) + batch = _batch + # batch = [x.cuda(non_blocking=True) if isinstance(x, torch.Tensor) else x for x in batch] + ( + virtual_tokens, + context_and_question_tokens, + enc_mask, + dec_input, + dec_input_mask, + labels, + loss_mask, + position_ids, + taskname_ids, + speech_mask, + context_and_question_tokens_lens, + cross_attention_prior, + text_limits, + _, # TODO: text limit and lang not in tarred dataset + _, + ) = batch + + if self.trainer.global_step % self.train_check_interval == 0 and not validation_step and self.is_rank_zero: + self.frozen_model.enc_dec_model.logging_step = True + + _cross_attention_prior = cross_attention_prior + if isinstance(context_and_question_tokens, list): + # None for context and prior for question + _cross_attention_prior = [None, cross_attention_prior] + + output_tensor, encoder_input, out_logits = model( + virtual_tokens, + context_and_question_tokens, + enc_mask, + dec_input, + dec_input_mask, + position_ids, + taskname_ids, + labels=labels, + speech_mask=speech_mask, + cross_attention_prior=_cross_attention_prior, + text_limits=text_limits, + inference=False, + ) + output_tensor = output_tensor.contiguous() + + alignment_loss = out_logits[3] + if alignment_loss is not None: + self.logger.experiment.add_scalar('train_alignment_loss', alignment_loss, self.global_step) + + if self.trainer.global_step % self.train_check_interval == 0 and not validation_step and self.is_rank_zero: + self.frozen_model.enc_dec_model.logging_step = False + with torch.no_grad(): + with torch.cuda.amp.autocast(enabled=False): + if torch.count_nonzero(speech_mask) == 0: + text_labels = labels[:, 0, :] # [B, 8, T] -> [B, T] + token_logits = out_logits[0] * 1 # [T, B, V] + if self.frozen_model.enc_dec_model.parallel_output: + # Gather from tensor parallel region + token_logits = tensor_parallel.gather_from_tensor_model_parallel_region(token_logits) + token_logits = token_logits.argmax(dim=2) # [T, B] + token_logits = token_logits.t() # [B, T] + score = 0 + for i in range(text_labels.size()[0]): + r = text_labels[i].long() + nzm = r != 0 + r = r.tolist() + h = token_logits[i].long() * nzm + h = h.tolist() + score += editdistance.eval(r, h) + score /= text_labels.size()[0] + logging.info(f"wer score : {score}") + self.logger.experiment.add_scalar('WER', score, self.global_step) + else: + audio_len = ( + self.decoder_context_len + (labels[0][0][self.decoder_context_len :] != 0).sum().item() + ) + labels_to_1024 = self.convert_tokens_to_range(labels[0, :, 0:audio_len]) + label_wav = self.decode_wav_from_codec_model(labels_to_1024) + dec_input_to_1024 = self.convert_tokens_to_range(dec_input[0, :, 0:audio_len]) + dec_input_wav = self.decode_wav_from_codec_model(dec_input_to_1024) + self.logger.experiment.add_audio( + "train_label_wav", label_wav, self.global_step, self.sample_rate + ) + self.logger.experiment.add_audio( + "train_dec_input_wav", dec_input_wav, self.global_step, self.sample_rate + ) + if isinstance(context_and_question_tokens, list): + context_tokens = context_and_question_tokens[0] + question_tokens = context_and_question_tokens[1] + input_token_list_all = [ + question_tokens[0, 0, i].item() for i in range(question_tokens.shape[2]) + ] + input_token_list = [ + (ti, t) + for ti, t in enumerate(input_token_list_all) + if t != 0 and t < self.speech_offset + ] + context_end_step = context_and_question_tokens_lens[0][0].item() + _context_tokens = context_tokens[0, :, :context_end_step] + else: + input_token_list_all = [ + context_and_question_tokens[0, 0, i].item() + for i in range(context_and_question_tokens.shape[2]) + ] + input_token_list = [ + (ti, t) + for ti, t in enumerate(input_token_list_all) + if t != 0 and t < self.speech_offset + ] + context_end_step = input_token_list[0][0] + _context_tokens = context_and_question_tokens[0, :, :context_end_step] + + if context_end_step > 1: + is_speech_context = _context_tokens[1, :].sum().item() > 0 + if is_speech_context: + _context_tokens = self.convert_tokens_to_range( + _context_tokens, pattern=self.context_pattern + ) + _context_wav = self.decode_wav_from_codec_model(_context_tokens) + self.logger.experiment.add_audio( + "train_context_wav", _context_wav, self.global_step, self.sample_rate + ) + else: + _context_token_list = [v.item() for v in _context_tokens[0, :]] + _context_text = self.frozen_model.tokenizer.ids_to_text( + [v for v in _context_token_list if v < self.lm_vocab_size] + ) + self.logger.experiment.add_text( + "train_context_text", _context_text, self.global_step + ) + + question_si = text_limits[0, 0].item() - virtual_tokens.shape[1] + question_ei = text_limits[0, 1].item() - virtual_tokens.shape[1] + text_si = text_limits[0, 0].item() + text_ei = text_limits[0, 1].item() + input_text = self.frozen_model.tokenizer.ids_to_text( + [v for v in input_token_list_all[question_si:question_ei] if v < self.lm_vocab_size] + ) + self.logger.experiment.add_text("Train Input Text", input_text, self.global_step) + + input_phoneme_tokens = [ + v - self.lm_vocab_size + for v in input_token_list_all[question_si:question_ei] + if v >= self.lm_vocab_size + ] + + if len(input_phoneme_tokens) > 0: + phoneme_text = self.phoneme_tokenizer.decode(input_phoneme_tokens) + self.logger.experiment.add_text( + "Train Input Phoneme Text", phoneme_text, self.global_step + ) + + token_logits = out_logits[0] + speech_logits_list = out_logits[1] + + attention_probs_list = out_logits[2] # list of (BS, 12, out_length, in_length) + if attention_probs_list is not None: + attention_sliced_list = [] + for lidx in range(len(attention_probs_list)): + attention_probs = attention_probs_list[lidx] + for _i in range(attention_probs.shape[1]): + name = f"Attention Probs Layer {lidx} Head {_i}" + attention_to_plot = attention_probs[0, _i, :audio_len, :text_ei] + if self.plot_alignments_sliced: + attention_to_plot = attention_probs[0, _i, 0:audio_len, text_si:text_ei] + # 4 to offset "Text to Speech this" + name += " Sliced" + alignment_image = plot_alignment_to_numpy_for_speechllm( + attention_to_plot.cpu().float().numpy().T, + phoneme_ver=0 if self.plot_alignments_sliced else 1, + phoneme_seq=None if self.plot_alignments_sliced else [text_si], + ) + self.logger.experiment.add_image( + name, + alignment_image, + self.global_step, + dataformats="HWC", + ) + attention_sliced_list.append( + attention_probs[ + 0, _i, self.decoder_context_len : audio_len, text_si:text_ei + ] + ) + attention_sliced = torch.stack(attention_sliced_list) + attention_sliced = torch.mean(attention_sliced, 0) + text = None + if len(input_text) > 0: + text = self.frozen_model.tokenizer.ids_to_tokens( + [ + v + for v in input_token_list_all[question_si:question_ei] + if v < self.lm_vocab_size + ] + ) + if len(input_phoneme_tokens) > 0: + text = phoneme_text.split("|") + alignment_image_sliced = plot_alignment_to_numpy_for_speechllm( + attention_sliced.cpu().float().numpy().T, + phoneme_seq=text, + phoneme_ver=2, + vmin=0.0, + phone_offset=0, + h_offset=False, + ) + self.logger.experiment.add_image( + f"Attention Probs Average Sliced", + alignment_image_sliced, + self.global_step, + dataformats="HWC", + ) + if self.frozen_model.enc_dec_model.parallel_output: + # Gather from tensor parallel region + token_logits = tensor_parallel.gather_from_tensor_model_parallel_region(token_logits) + for _i in range(len(speech_logits_list)): + speech_logits_list[_i] = tensor_parallel.gather_from_tensor_model_parallel_region( + speech_logits_list[_i] + ) + speech_logits = torch.stack(speech_logits_list, dim=-1) # (t, b, 1024, 7) + token_logits_example = token_logits[:, 0, :] * 1 + speech_logits_example = speech_logits[:, 0, :, :] * 1 + first_layer_tokens = token_logits_example.argmax(dim=1) - self.speech_offset + other_layer_tokens = [] + for _i in range(speech_logits_example.shape[2]): + other_layer_tokens.append(speech_logits_example[:, :, _i].argmax(dim=1)) + + all_layer_tokens = torch.stack([first_layer_tokens] + other_layer_tokens) # (8, t) + all_layer_tokens = self.convert_tokens_to_range( + all_layer_tokens, apply_offset_correction=False + ) + # all_layer_tokens = torch.clip(all_layer_tokens, 0, 1023) + predicted_wav = self.decode_wav_from_codec_model(all_layer_tokens) + self.logger.experiment.add_audio( + "train_tf_pred_wav", predicted_wav, self.global_step, self.sample_rate + ) + + def loss_func(loss_args): + output_tensor, out_logits, curr_step = loss_args + alignment_loss = out_logits[3] + loss = self.frozen_model.loss_func(loss_mask, output_tensor) + if ( + (alignment_loss is not None) + and (curr_step > self.alignment_loss_start_step) + and (curr_step < self.alignment_loss_end_step) + ): + logging.debug(f"Adding alignment loss. cur:{curr_step} start:{self.alignment_loss_start_step}") + loss = loss + alignment_loss + reduced_loss = average_losses_across_data_parallel_group([loss]) + return loss, {'avg': reduced_loss} + + return [output_tensor, out_logits, self.global_step], loss_func + + return fwd_output_and_loss_func + + def get_forward_output_only_func(self): + """Used in inference / predict""" + + def fwd_output_only_func(dataloader_iter, model): + batch = next(dataloader_iter) + _batch = [] + for x in batch: + if isinstance(x, torch.Tensor): + x = x.cuda(non_blocking=True) + elif isinstance(x, list): + if isinstance(x[0], torch.Tensor): + x = [y.cuda(non_blocking=True) for y in x] + _batch.append(x) + batch = _batch + # batch = [x.cuda(non_blocking=True) if isinstance(x, torch.Tensor) else x for x in batch] + ( + decoder_max_sequence_len, + encoder_max_sequence_len, + context_and_question_tokens, + enc_mask, + dec_input, + dec_input_mask, + position_ids, + taskname_ids, + speech_mask, + ) = batch + + output_logits, _, token_and_speech_logits = model( + context_and_question_tokens, + context_and_question_tokens, + enc_mask, + dec_input, + dec_input_mask, + position_ids, + taskname_ids, + labels=None, + speech_mask=speech_mask, + inference=True, + inference_step=1, + decoder_max_sequence_len=decoder_max_sequence_len, + encoder_max_sequence_len=encoder_max_sequence_len, + ) + output_tensor = [output_logits, token_and_speech_logits] + + def id_func(output_tensor): + return 0, {'output_logits': output_tensor[0], 'token_and_speech_logits': output_tensor[1]} + + return output_tensor, id_func + + return fwd_output_only_func + + def backward(self, *args, **kwargs): + """LightningModule hook to do backward. + We want this to do nothing since we run backward in the fwd/bwd functions from megatron-core. + No need to call it here. + """ + return + + def optimizer_zero_grad(self, *args, **kwargs): + """LightningModule hook to zero grad. + We want this to do nothing as we are zeroing grads during the training_step. + """ + return + + def set_input_tensor(self, input_tensor): + """Set input tensor to be used instead of forward()'s input. + When using pipeline parallelism the input from the previous + stage comes from communication, not from the input, so the + model's forward_step_func won't have it. This function is thus + used by internal code to bypass the input provided by the + forward_step_func""" + self.frozen_model.enc_dec_model.set_input_tensor(input_tensor) + + def on_train_epoch_start(self) -> None: + gbs = self.cfg.global_batch_size + mbs = self.cfg.micro_batch_size + self._reconfigure_batch_sizes(gbs, mbs) + return super().on_train_epoch_start() + + def on_validation_epoch_start(self) -> None: + gbs = self.cfg.get('validation_global_batch_size', self.cfg.global_batch_size) + mbs = self.cfg.get('validation_micro_batch_size', self.cfg.micro_batch_size) + self._reconfigure_batch_sizes(gbs, mbs) + return super().on_validation_epoch_start() + + def training_step(self, dataloader_iter, batch_idx): + self._optimizer.zero_grad() + batch = next(dataloader_iter) + + # apply text classifier-free guidance by replacing input question tokens with [UNK]. + if self.train_text_cfg_prob > 0.0: + if self._rng.random() < self.train_text_cfg_prob: + logging.info(f"Text Classifier-Free Guidance is triggered for the {batch_idx}-th batch.") + + # temporally disable computing CTC alignment loss. + if self.use_alignment_loss: + self.frozen_model.enc_dec_model.use_alignment_loss = False + + # make cross-attention prior to None to remove the prior. + batch[11] = None + + # replace question token IDs with [UNK]'s id. No speech offset for Phoneme's [UNK]. Same op as train. + # instruction token IDs are bpe token IDs directly obtained from self.tokenizer without any offset. + # question token IDs are phoneme and grapheme token IDs and are offset by self.lm_vocab_size + # if under "Phoneme TTS" instruction, so existing no overlaps between instruction and question token IDs. + # question token IDs are bpe token IDs without any offset + # if under "Text to speech this" instruction, so existing overlaps between instruction and question token IDs. + context_and_question_tokens = batch[ + 1 + ] # (batch_size, self.num_speech_codebooks, max_context_question_tokens_len) + text_limits = batch[12] + virtual_tokens = batch[0] + question_limits = text_limits - virtual_tokens.size( + 1 + ) # (b, 2), reset question range to start from [pad] context, same start position as context_and_question_tokens. + question_start = question_limits[:, 0].unsqueeze(1) # (b, 1) + question_end = question_limits[:, 1].unsqueeze(1) # (b, 1) + + if isinstance(context_and_question_tokens, list): # indicate self.encoder_type=multi_transformers. + context_tokens, question_tokens = context_and_question_tokens + question_tokens_unconditioned = question_tokens.clone() + time_range = torch.arange( + question_tokens_unconditioned.size(2), device=question_tokens_unconditioned.device + ).unsqueeze(0) + question_mask = (time_range >= question_start) & ( + time_range < question_end + ) # create a mask for question only tokens. + question_tokens_unconditioned[:, 0][ + question_mask + ] = self.tokenizer.unk_id # only the first layer has non-zero IDs. + batch[1] = [context_tokens, question_tokens_unconditioned] + else: + context_and_question_tokens_unconditioned = ( + context_and_question_tokens.clone() + ) # (batch_size, self.num_speech_codebooks, max_context_question_tokens_len) + time_range = torch.arange( + context_and_question_tokens_unconditioned.size(2), + device=context_and_question_tokens_unconditioned.device, + ).unsqueeze( + 0 + ) # (1, max_context_question_tokens_len) + question_mask = (time_range >= question_start) & ( + time_range < question_end + ) # create a mask for question only tokens. + context_and_question_tokens_unconditioned[:, 0][ + question_mask + ] = self.tokenizer.unk_id # only the first layer has non-zero IDs. + batch[1] = context_and_question_tokens_unconditioned + + del question_limits, question_start, question_end, time_range, question_mask + else: + # recover to original alignment loss config. + self.frozen_model.enc_dec_model.use_alignment_loss = self.use_alignment_loss + + # apply audio context classifier-free guidance by replacing audio codec with [UNK] + if self.train_audio_cfg_prob > 0.0: + if self._rng.random() < self.train_audio_cfg_prob: + logging.info(f"Audio Classifier-Free Guidance is triggered for the {batch_idx}-th batch.") + + context_and_question_tokens = batch[ + 1 + ] # (batch_size, self.num_speech_codebooks, max_context_question_tokens_len) + + if isinstance(context_and_question_tokens, list): # indicate self.encoder_type=multi_transformers. + context_tokens, question_tokens = context_and_question_tokens + context_tokens_unconditioned = context_tokens.clone() + context_tokens_unconditioned[:, :, :] = ( + self.tokenizer.unk_id + ) # TODO @xueyang: verify if extra tokens other than audio codec tokens are appended. + batch[1] = [context_tokens_unconditioned, question_tokens] + else: + # dec_input + dec_input = batch[3] + dec_input_unconditioned = dec_input.clone() + dec_input_unconditioned[:, :, 1 : self.decoder_context_len + 1] = ( + self.tokenizer.unk_id + ) # TODO @xueyang: switch to other token id if this one is conflict with text unk. + batch[3] = dec_input_unconditioned + + loss_mean = self.fwd_bwd_step(itertools.chain([batch]), batch_idx, forward_only=False) + self.allreduce_gradients() + + ## logging + # we can only log on one rank if it is rank zero so we broadcast from last rank + # we can avoid this broadcast by updating the PTL log function to accept specific ranks + torch.distributed.broadcast(loss_mean, get_last_rank()) + + if self.cfg.precision == 16 and hasattr(self.trainer.precision_plugin.scaler, "_scale"): + loss_scale = self.trainer.precision_plugin.scaler._scale + if loss_scale is not None: + self.log('loss_scale', loss_scale, batch_size=1) + + self.log('reduced_train_loss', loss_mean, prog_bar=True, rank_zero_only=True, batch_size=1) + lr = self._optimizer.param_groups[0]['lr'] + self.log('lr', lr, rank_zero_only=True, batch_size=1) + self.log('global_step', self.trainer.global_step, prog_bar=True, rank_zero_only=True, batch_size=1) + return loss_mean + + def get_predictions(self, input_ids, enc_mask, encoder_input, labels): + predicted_token_ids, log_probs = self.frozen_model.decode( + tokens_enc=input_ids, + enc_mask=enc_mask, + num_tokens_to_generate=self.decoder_seq_length, + encoder_input=encoder_input, + bos_id=( + self.tokenizer.pad_id if self.cfg.data.get('decoder_starts_with_pad', False) else self.tokenizer.bos_id + ), + ) + # Special ids to text function to handle stripping and special tokens with sentencepiece tokenizers. + preds_text = MegatronT5SFTModel.ids_to_text(predicted_token_ids, self.tokenizer) + labels_text = MegatronT5SFTModel.ids_to_text(labels, self.tokenizer) + input_text = MegatronT5SFTModel.ids_to_text(input_ids, self.tokenizer) + return { + 'predicted_token_ids': preds_text, + 'labels': labels_text, + 'enc_inputs': input_text, + } + + def get_embeddings(self, tokens, taskname_ids, inference=False): + out = None + if tokens.dim() > 2: + for i in range(tokens.size()[1]): # for 8 channels + if i == 0: + # Embed first layer using word embeddings + out = self.embed_input(tokens[:, i, :], taskname_ids, inference) # (B, T, D) + else: + # Embed other layers using speech embeddings + cur = self.frozen_model.enc_dec_model.speech_tokens_embeddings[i - 1](tokens[:, i, :]) + # do not add embeddings of zero tokens of other channels (except the first channel) + non_zero_flag = tokens[:, i, :] != 0 # (B, T) + cur = cur * non_zero_flag.unsqueeze(2) + out = out + cur + else: + out = self.embed_input(tokens, taskname_ids, inference) + return out + + def get_embeddings_and_combine(self, token_list, taskname_ids, inference): + embedding_list = [] + for tokens in token_list: + embedding_list.append(self.get_embeddings(tokens, taskname_ids, inference)) + return torch.cat(embedding_list, dim=1) + + def validation_step(self, batch, batch_idx, dataloader_idx=0): + ( + virtual_tokens, + context_and_question_tokens, + enc_mask, + dec_input, + dec_input_mask, + labels, + loss_mask, + position_ids, + taskname_ids, + speech_mask, + context_and_question_tokens_lens, + cross_attention_prior, + text_limits, + _, + _, + ) = batch + # loss_mask (b, t) + # does not use dataloader_iter due to device placement issues arising from PTL + + mode = self.training + self.eval() + gbs = self.cfg.get('validation_global_batch_size', self.cfg.global_batch_size) + self._reconfigure_and_process_inference_batch(virtual_tokens.size(0), gbs) + + loss_mean = self.fwd_bwd_step( + itertools.chain([batch]), batch_idx, forward_only=True + ) # comment this out and add custom forward function to calculate WER + # # logging.info (f'loss_mean {loss_mean}') + + if batch_idx == 0 and self.is_rank_zero: + self.frozen_model.enc_dec_model.logging_step = True + self.predict_step_outputs = [] + # log_scalars=False avoids logging scalar TTS metrics in the predict_step + # Images, audio and texts will still be logged + self.predict_step(batch=batch, batch_idx=batch_idx, log_scalars=False, global_step=self.global_step) + for inf_key in self.predict_step_outputs[0]: + if self.predict_step_outputs[0][inf_key] is not None: + self.logger.experiment.add_scalar( + f'Val_{inf_key}', self.predict_step_outputs[0][inf_key], self.global_step + ) + + labels_original = labels.clone() # (b, 8, t) + + _cross_attention_prior = cross_attention_prior + if isinstance(context_and_question_tokens, list): + _cross_attention_prior = [None, cross_attention_prior] + + output_loss, _, output_logits = self.forward( + virtual_tokens, + context_and_question_tokens, + enc_mask, + dec_input, + dec_input_mask, + position_ids, + taskname_ids, + labels=labels, + speech_mask=speech_mask, + cross_attention_prior=_cross_attention_prior, + text_limits=text_limits, + inference=False, + ) + + if batch_idx == 0 and self.is_rank_zero: + self.frozen_model.enc_dec_model.logging_step = False + with torch.cuda.amp.autocast(enabled=False): + if torch.count_nonzero(speech_mask) == 0: + text_labels = labels[:, 0, :] # [B, 8, T] -> [B, T] + token_logits = output_logits[0] * 1 # [T, B, V] + if self.frozen_model.enc_dec_model.parallel_output: + # Gather from tensor parallel region + token_logits = tensor_parallel.gather_from_tensor_model_parallel_region(token_logits) + token_logits = token_logits.argmax(dim=2) # [T, B] + token_logits = token_logits.t() # [B, T] + score = 0 + for i in range(text_labels.size()[0]): + r = text_labels[i].long() + nzm = r != 0 + r = r.tolist() + h = token_logits[i].long() * nzm + h = h.tolist() + score += editdistance.eval(r, h) + score /= text_labels.size()[0] + logging.info(f"wer score : {score}") + self.logger.experiment.add_scalar('WER', score, self.global_step) + else: + audio_len = self.decoder_context_len + (labels[0][0][self.decoder_context_len :] != 0).sum().item() + labels_to_1024 = self.convert_tokens_to_range(labels[0, :, 0:audio_len]) + label_wav = self.decode_wav_from_codec_model(labels_to_1024) + dec_input_to_1024 = self.convert_tokens_to_range(dec_input[0, :, 0:audio_len]) + dec_input_wav = self.decode_wav_from_codec_model(dec_input_to_1024) + self.logger.experiment.add_audio("val_label_wav", label_wav, self.global_step, self.sample_rate) + self.logger.experiment.add_audio( + "val_dec_input_wav", dec_input_wav, self.global_step, self.sample_rate + ) + + if isinstance(context_and_question_tokens, list): + context_tokens = context_and_question_tokens[0] + question_tokens = context_and_question_tokens[1] + input_token_list_all = [ + question_tokens[0, 0, i].item() for i in range(question_tokens.shape[2]) + ] + input_token_list = [ + (ti, t) for ti, t in enumerate(input_token_list_all) if t != 0 and t < self.speech_offset + ] + context_end_step = context_and_question_tokens_lens[0][0].item() + _context_tokens = context_tokens[0, :, :context_end_step] + + else: + input_token_list_all = [ + context_and_question_tokens[0, 0, i].item() + for i in range(context_and_question_tokens.shape[2]) + ] + input_token_list = [ + (ti, t) for ti, t in enumerate(input_token_list_all) if t != 0 and t < self.speech_offset + ] + context_end_step = input_token_list[0][0] + _context_tokens = context_and_question_tokens[0, :, :context_end_step] + if context_end_step > 1: + is_speech_context = _context_tokens[1, :].sum().item() > 0 + if is_speech_context: + _context_tokens = self.convert_tokens_to_range( + _context_tokens, pattern=self.context_pattern + ) + _context_wav = self.decode_wav_from_codec_model(_context_tokens) + self.logger.experiment.add_audio( + "val_context_wav", _context_wav, self.global_step, self.sample_rate + ) + else: + _context_token_list = [v.item() for v in _context_tokens[0, :]] + _context_text = self.frozen_model.tokenizer.ids_to_text( + [v for v in _context_token_list if v < self.lm_vocab_size] + ) + self.logger.experiment.add_text("val_context_text", _context_text, self.global_step) + + question_si = text_limits[0, 0].item() - virtual_tokens.shape[1] + question_ei = text_limits[0, 1].item() - virtual_tokens.shape[1] + + text_si = text_limits[0, 0].item() + text_ei = text_limits[0, 1].item() + + input_text = self.frozen_model.tokenizer.ids_to_text( + [v for v in input_token_list_all[question_si:question_ei] if v < self.lm_vocab_size] + ) + self.logger.experiment.add_text("Val Input Text", input_text, self.global_step) + + input_phoneme_tokens = [ + v - self.lm_vocab_size + for v in input_token_list_all[question_si:question_ei] + if v >= self.lm_vocab_size + ] + if len(input_phoneme_tokens) > 0: + phoneme_text = self.phoneme_tokenizer.decode(input_phoneme_tokens) + self.logger.experiment.add_text("Val Input Phoneme Text", phoneme_text, self.global_step) + + token_logits = output_logits[0] + speech_logits_list = output_logits[1] + + # if self.trainer.global_step % 500 == 0: + attention_probs_list = output_logits[2] # list of (BS, 12, out_length, in_length) + if attention_probs_list is not None: + attention_sliced_list = [] + for lidx in range(len(attention_probs_list)): + attention_probs = attention_probs_list[lidx] + for _i in range(attention_probs.shape[1]): + attention_sliced_list.append( + attention_probs[0, _i, self.decoder_context_len : audio_len, text_si:text_ei] + ) + attention_sliced = torch.stack(attention_sliced_list) + attention_sliced = torch.mean(attention_sliced, 0) + text = None + if len(input_text) > 0: + text = self.frozen_model.tokenizer.ids_to_tokens( + [v for v in input_token_list_all[question_si:question_ei] if v < self.lm_vocab_size] + ) + if len(input_phoneme_tokens) > 0: + text = phoneme_text.split("|") + alignment_image_sliced = plot_alignment_to_numpy_for_speechllm( + attention_sliced.cpu().float().numpy().T, + phoneme_seq=text, + phoneme_ver=2, + vmin=0.0, + phone_offset=0, + h_offset=False, + ) + self.logger.experiment.add_image( + f"Val Attention Probs Average Sliced", + alignment_image_sliced, + self.global_step, + dataformats="HWC", + ) + if self.frozen_model.enc_dec_model.parallel_output: + # Gather from tensor parallel region + token_logits = tensor_parallel.gather_from_tensor_model_parallel_region(token_logits) + for _i in range(len(speech_logits_list)): + speech_logits_list[_i] = tensor_parallel.gather_from_tensor_model_parallel_region( + speech_logits_list[_i] + ) + speech_logits = torch.stack(speech_logits_list, dim=-1) # (t, b, 1024, 7) + token_logits_example = token_logits[:, 0, :] * 1 + speech_logits_example = speech_logits[:, 0, :, :] * 1 + first_layer_tokens = token_logits_example.argmax(dim=1) - self.speech_offset + other_layer_tokens = [] + for _i in range(speech_logits_example.shape[2]): + other_layer_tokens.append(speech_logits_example[:, :, _i].argmax(dim=1)) + + all_layer_tokens = torch.stack([first_layer_tokens] + other_layer_tokens) # (8, t) + all_layer_tokens = self.convert_tokens_to_range(all_layer_tokens, apply_offset_correction=False) + all_layer_tokens = torch.clip(all_layer_tokens, 0, self.speech_codebook_size - 1) + predicted_wav = self.decode_wav_from_codec_model(all_layer_tokens) + self.logger.experiment.add_audio( + "val_tf_pred_wav", predicted_wav, self.global_step, self.sample_rate + ) + + first_layer_logits = output_logits[0] + speech_logits_list = output_logits[1] + + if self.frozen_model.enc_dec_model.parallel_output: + # Gather from tensor parallel region + first_layer_logits = tensor_parallel.gather_from_tensor_model_parallel_region(first_layer_logits) + if torch.count_nonzero(speech_mask) > 0: + for _i in range(len(speech_logits_list)): + speech_logits_list[_i] = tensor_parallel.gather_from_tensor_model_parallel_region( + speech_logits_list[_i] + ) + speech_logits = torch.stack(speech_logits_list, dim=-1) # (t, b, 1024, 7) + first_layer_preds = first_layer_logits.argmax(dim=2) # (t,bs) + first_layer_preds = first_layer_preds.transpose(0, 1) # (bs,t) + labels_first_layer = labels_original[:, 0, :] # (bs,t) + correct_predictions = first_layer_preds == labels_first_layer # (bs,t) + correct_predictions = correct_predictions * loss_mask # (bs,t) + total_correct_predictions = torch.sum(correct_predictions) + total_predictions = torch.sum(loss_mask) + first_layer_accuracy = total_correct_predictions / total_predictions + first_layer_loss = torch.nn.functional.cross_entropy( + first_layer_logits.permute(1, 2, 0), labels_first_layer, reduction='none' + ) # (bs,t) + first_layer_loss = torch.sum(first_layer_loss * loss_mask) / total_predictions + + metrics = { + 'loss': loss_mean, + 'first_layer_accuracy': first_layer_accuracy, + 'first_layer_loss': first_layer_loss, + } + loss_total = first_layer_loss + for i in range(self.num_speech_codebooks - 1): + if torch.count_nonzero(speech_mask) > 0: + speech_logits_i = speech_logits[:, :, :, i] + speech_preds_i = speech_logits_i.argmax(dim=2) # (t,bs) + speech_preds_i = speech_preds_i.transpose(0, 1) # (bs,t) + labels_i = labels_original[:, i + 1, :] # (bs,t) + correct_predictions_i = speech_preds_i == labels_i # (bs,t) + correct_predictions_i = correct_predictions_i * loss_mask * speech_mask # (bs,t) + total_correct_predictions_i = torch.sum(correct_predictions_i) + total_predictions_i = torch.sum(loss_mask * speech_mask) + speech_accuracy_i = total_correct_predictions_i / total_predictions_i + loss_i = torch.nn.functional.cross_entropy( + speech_logits_i.permute(1, 2, 0), labels_i, reduction='none' + ) # (bs,t) + loss_i = torch.sum(loss_i * loss_mask * speech_mask) / total_predictions_i + else: + speech_accuracy_i = torch.tensor(0.0) + loss_i = torch.tensor(0.0) + metrics[f'speech_accuracy_{i+1}'] = speech_accuracy_i + metrics[f'speech_loss_{i+1}'] = loss_i + loss_total += loss_i + + metrics['loss_total_check'] = loss_total + self.validation_step_outputs.append(metrics) + self.train(mode=mode) + self.frozen_model.train() + return metrics['loss'] + + def on_validation_epoch_end(self): + outputs = self.validation_step_outputs + if self.cfg.get('pipeline_model_parallel_size', 1) > 1: + if parallel_state.is_pipeline_last_stage(): + # only the last pipeline parallel stages return loss + averaged_loss = torch.stack([item['loss'] for item in outputs]).mean() + averaged_loss_total_check = torch.stack([item['loss_total_check'] for item in outputs]).mean() + averaged_first_layer_accuracy = torch.stack([item['first_layer_accuracy'] for item in outputs]).mean() + + self.log( + 'val_loss_total_check', + averaged_loss_total_check, + prog_bar=False, + rank_zero_only=True, + batch_size=1, + ) + self.log( + 'val_first_layer_accuracy', + averaged_first_layer_accuracy, + prog_bar=True, + rank_zero_only=True, + batch_size=1, + ) + logging.info(f'Validation first_layer_accuracy: {averaged_first_layer_accuracy}') + logging.info(f'Validation loss_total_check: {averaged_loss_total_check}') + + for i in range(1, self.num_speech_codebooks): + averaged_speech_accuracy = torch.stack([item[f'speech_accuracy_{i}'] for item in outputs]).mean() + averaged_speech_loss = torch.stack([item[f'speech_loss_{i}'] for item in outputs]).mean() + self.log( + f'val_speech_accuracy_{i}', + averaged_speech_accuracy, + prog_bar=True, + rank_zero_only=True, + batch_size=1, + ) + self.log( + f'val_speech_loss_{i}', averaged_speech_loss, prog_bar=True, rank_zero_only=True, batch_size=1 + ) + logging.info(f'Validation speech_accuracy_{i}: {averaged_speech_accuracy}') + logging.info(f'Validation speech_loss_{i}: {averaged_speech_loss}') + else: + averaged_loss = torch.tensor(0.0).cuda() + + # we can only log on one rank if it is rank zero so we broadcast from last rank + torch.distributed.broadcast(averaged_loss, get_last_rank()) + + self.log('val_loss', averaged_loss, prog_bar=True, rank_zero_only=True, batch_size=1) + logging.info(f'Validation loss: {averaged_loss}') + + else: + if len(outputs) > 0: + averaged_loss = torch.stack([item['loss'] for item in outputs]).mean() + averaged_loss_total_check = torch.stack([item['loss_total_check'] for item in outputs]).mean() + logging.info(f'Validation loss: {averaged_loss}') + self.log('val_loss', averaged_loss, prog_bar=True, rank_zero_only=True, batch_size=1) + self.log( + 'val_loss_total_check', + averaged_loss_total_check, + prog_bar=False, + rank_zero_only=True, + batch_size=1, + ) + + averaged_first_layer_accuracy = torch.stack([item['first_layer_accuracy'] for item in outputs]).mean() + logging.info(f'Validation first_layer_accuracy: {averaged_first_layer_accuracy}') + self.log( + 'val_first_layer_accuracy', + averaged_first_layer_accuracy, + prog_bar=True, + rank_zero_only=True, + batch_size=1, + ) + + for i in range(1, self.num_speech_codebooks): + averaged_speech_accuracy = torch.stack([item[f'speech_accuracy_{i}'] for item in outputs]).mean() + averaged_speech_loss = torch.stack([item[f'speech_loss_{i}'] for item in outputs]).mean() + logging.info(f'Validation speech_accuracy_{i}: {averaged_speech_accuracy}') + logging.info(f'Validation speech_loss_{i}: {averaged_speech_loss}') + self.log( + f'val_speech_accuracy_{i}', + averaged_speech_accuracy, + prog_bar=True, + rank_zero_only=True, + batch_size=1, + ) + self.log( + f'val_speech_loss_{i}', averaged_speech_loss, prog_bar=True, rank_zero_only=True, batch_size=1 + ) + + if self.cfg.get("report_validation_metric", False): + gather_results = [None for _ in range(parallel_state.get_data_parallel_world_size())] + + all_preds = list(itertools.chain(*[item['predicted_token_ids'] for item in outputs])) + all_labels = list(itertools.chain(*[item['labels'] for item in outputs])) + all_inputs = list(itertools.chain(*[item['enc_inputs'] for item in outputs])) + + assert len(all_preds) == len(all_labels) + assert len(all_preds) == len(all_inputs) + + # Gather inputs, preds, labels from all workers + torch.distributed.all_gather_object( + gather_results, + [(input, pred, label) for (input, pred, label) in zip(all_inputs, all_preds, all_labels)], + group=parallel_state.get_data_parallel_group(), + ) + + # Deduplicate sentences that may have been distributed across multiple data parallel ranks. + if parallel_state.get_data_parallel_rank() == 0: + + gather_results_dedup = list(set(itertools.chain(*gather_results))) + + val_metric_dict = self.validation_metric.get_score( + [i[2] for i in gather_results_dedup], + [i[1] for i in gather_results_dedup], + ) + + for metric, val in val_metric_dict.items(): + logging.info(f'Validation {metric}: {val}') + val_metric = list(val_metric_dict.items())[0][1] + metric_name = list(val_metric_dict.items())[0][0] + else: + val_metric = torch.tensor(0.0).cuda() + metric_name = '' + + self.log(f'val_{metric_name}', val_metric, prog_bar=True, rank_zero_only=True, batch_size=1) + + gbs = self.cfg.global_batch_size + mbs = self.cfg.micro_batch_size + self._reconfigure_batch_sizes(gbs, mbs) + self.validation_step_outputs.clear() + + def test_step(self, batch, batch_idx): + result = self.predict_step(batch, batch_idx) + return result + + def on_test_epoch_end(self): + """ + This might still be broken for lightning 2.0. to fix: see + https://github.com/NVIDIA/NeMo/blob/9bdf4d12276ee8f95a340cf2f7f340e9b5b74a7e/docs/source/starthere/migration-guide.rst + """ + outputs = self.predict_step_outputs + average_metrics = {} + for output in outputs: + for key in output: + if key not in average_metrics: + average_metrics[key] = [] + if isinstance(output[key], torch.Tensor): + average_metrics[key].append(output[key].item()) + elif output[key] is None: + continue + else: + average_metrics[key].append(output[key]) + + for key in average_metrics: + average_metrics[key] = np.mean(average_metrics[key]).item() + logging.info(f'Test {key}: {average_metrics[key]}') + self.log(f'test_{key}', average_metrics[key], prog_bar=True, rank_zero_only=True, batch_size=1) + self.logger.experiment.add_scalar(f'Inf Cumulative {key}', average_metrics[key], 0) + + # save average metrics into json file + with open(os.path.join(self.logger.log_dir, 'output_metrics.json'), 'w') as f: + json.dump(average_metrics, f) + + def build_virtual_prompt_dataset( + self, dataset_paths, batch_size, for_train, drop_last, shuffle, num_workers, pin_memory + ): + dataset = T5SpeechLMDataset( + datasets=dataset_paths, + tokenizer=self.tokenizer, + sample_rate=self.cfg.data.get('sample_rate', 24000), + virtual_prompt_source=self.virtual_prompt_source, + task_templates=self.task_templates, + pseudo_tokens=self.pseudo_tokens, + pad_token_id=self.pad_token_id, + max_seq_length=self.cfg.data.get('max_seq_length', self.frozen_model.cfg.max_position_embeddings), + min_seq_length=self.cfg.data.get('min_seq_length', 1), + add_bos=self.cfg.data.get('add_bos', False), + add_eos=self.cfg.data.get('add_eos', True), + decoder_starts_with_pad=self.cfg.data.get('decoder_starts_with_pad', False), + add_eos_to_decoder_output=self.cfg.data.get('add_eos_to_decoder_output', True), + add_sentinel_to_input=self.cfg.data.get('add_sentinel_to_input', True), + ul2_prompt_token=self.cfg.data.get('ul2_prompt_token', None), + for_train=for_train, + segment_max_duration=self.cfg.data.get('segment_max_duration', None), + trim=self.cfg.data.get('trim', None), + trim_ref=self.cfg.data.get('trim_ref', None), + trim_top_db=self.cfg.data.get('trim_top_db', None), + trim_frame_length=self.cfg.data.get('trim_frame_length', None), + trim_hop_length=self.cfg.data.get('trim_hop_length', None), + pad_multiple=self.cfg.data.get('pad_multiple', 1), + pitch_augment=self.cfg.data.get('pitch_augment', None), + sup_data_path=self.cfg.data.get('sup_data_path', None), + codec_folder=self.cfg.data.get('codec_folder', None), + speech_offset=self.cfg.data.get('speech_offset', None), + train_task=self.cfg.data.get('train_task', "tts"), + seq_pattern=self.cfg.get('seq_pattern', 'delay_parallel'), + use_attention_prior=self.cfg.data.get('use_attention_prior', False), + attention_prior_scaling_factor=self.cfg.data.get('attention_prior_scaling_factor', 1.0), + cross_attention_epsilon=self.cfg.data.get('cross_attention_epsilon', 0.0), + lm_vocab_size=self.lm_vocab_size, + num_speech_codebooks=self.num_speech_codebooks, + codebook_fps=self.cfg.data.get('codebook_fps', 86), + add_special_tokens_to_only_first_codebook=self.cfg.data.get( + 'add_special_tokens_to_only_first_codebook', False + ), + context_pattern=self.cfg.data.get('context_pattern', 'parallel'), + context_duration_min=self.cfg.data.get('context_duration_min', 3.0), + context_duration_max=self.cfg.data.get('context_duration_max', 5.0), + g2p=self.cfg.data.get('g2p', None), + skip_datasets=self.cfg.data.get('skip_datasets', []), + english_only_model=self.cfg.get('english_only_model', False), + use_ipa=self.cfg.data.get('use_ipa', False), + context_conditioning=self.cfg.get('context_conditioning', "decoder"), + use_beta_binomial_interpolator=self.cfg.get('use_beta_binomial_interpolator', False), + context_slice_method=self.cfg.data.get('context_slice_method', 'random'), + phoneme_probability=self.cfg.data.get('phoneme_probability', 0.5), + encoder_type=self.cfg.data.get('encoder_type', 'single_transformer'), + ) + + rank = parallel_state.get_data_parallel_rank() + world_size = parallel_state.get_data_parallel_world_size() + sampler = torch.utils.data.distributed.DistributedSampler( + dataset, num_replicas=world_size, rank=rank, shuffle=shuffle, seed=self.cfg.seed + ) + + dataloader = torch.utils.data.DataLoader( + dataset, + collate_fn=dataset.collate_fn, + sampler=sampler, + batch_size=batch_size // world_size, + drop_last=drop_last, + num_workers=num_workers, + pin_memory=pin_memory, + persistent_workers=( + True if num_workers > 0 else False + ), # (@adithyare and @eharper) We need to set this to True to get around issues with spawn=True + ) + logging.info(f'build success: {len(dataloader)} {dataset_paths}') + if self.phoneme_tokenizer is None: + self.phoneme_tokenizer = dataset.phoneme_tokenizer + return dataset, dataloader + + def build_virtual_prompt_tarred_dataset( + self, dataset_paths, audio_path, batch_size, for_train, drop_last, shuffle, num_workers, pin_memory + ): + dataset = T5SpeechLMTarredDataset( + audio_tar_filepaths=audio_path, + manifest_filepath=dataset_paths, + tokenizer=self.tokenizer, + sample_rate=self.cfg.data.get('sample_rate', 24000), + virtual_prompt_source=self.virtual_prompt_source, + task_templates=self.task_templates, + pseudo_tokens=self.pseudo_tokens, + pad_token_id=self.pad_token_id, + max_seq_length=self.cfg.data.get('max_seq_length', self.frozen_model.cfg.max_position_embeddings), + min_seq_length=self.cfg.data.get('min_seq_length', 1), + shuffle_n=shuffle, + add_bos=self.cfg.data.get('add_bos', False), + add_eos=self.cfg.data.get('add_eos', True), + decoder_starts_with_pad=self.cfg.data.get('decoder_starts_with_pad', False), + add_eos_to_decoder_output=self.cfg.data.get('add_eos_to_decoder_output', True), + add_sentinel_to_input=self.cfg.data.get('add_sentinel_to_input', True), + ul2_prompt_token=self.cfg.data.get('ul2_prompt_token', None), + for_train=for_train, + segment_max_duration=self.cfg.data.get('segment_max_duration', None), + trim=self.cfg.data.get('trim', None), + trim_ref=self.cfg.data.get('trim_ref', None), + trim_top_db=self.cfg.data.get('trim_top_db', None), + trim_frame_length=self.cfg.data.get('trim_frame_length', None), + trim_hop_length=self.cfg.data.get('trim_hop_length', None), + pad_multiple=self.cfg.data.get('pad_multiple', 1), + pitch_augment=self.cfg.data.get('pitch_augment', None), + speech_offset=self.cfg.data.get('speech_offset', None), + train_task=self.cfg.data.get('train_task', "tts"), + seq_pattern=self.cfg.get('seq_pattern', 'delay_parallel'), + use_attention_prior=self.cfg.data.get('use_attention_prior', False), + attention_prior_scaling_factor=self.cfg.data.get('attention_prior_scaling_factor', 1.0), + cross_attention_epsilon=self.cfg.data.get('cross_attention_epsilon', 0.0), + lm_vocab_size=self.lm_vocab_size, + num_speech_codebooks=self.num_speech_codebooks, + ) + rank = parallel_state.get_data_parallel_rank() + world_size = parallel_state.get_data_parallel_world_size() + dataloader = torch.utils.data.DataLoader( + dataset, + collate_fn=dataset.collate_fn, + batch_size=batch_size // world_size, + drop_last=drop_last, + num_workers=num_workers, + pin_memory=pin_memory, + persistent_workers=( + True if num_workers > 0 else False + ), # (@adithyare and @eharper) We need to set this to True to get around issues with spawn=True + ) + logging.info(f'build success: {len(dataloader)} {dataset_paths}') + + return dataset, dataloader + + def process_text(self, input_text): + """ + Normalizes text for CER/WER calculation. + Taken from hallucination_eval.py + """ + # Convert text to lowercase + lower_case_text = input_text.lower() + + # Remove commas from text + no_comma_text = lower_case_text.replace(",", "") + + # Replace "-" with spaces + no_dash_text = no_comma_text.replace("-", " ") + + # Replace double spaces with single space + single_space_text = " ".join(no_dash_text.split()) + + single_space_text = single_space_text.translate(str.maketrans('', '', string.punctuation)) + + return single_space_text + + def predict_step( + self, batch: Any, batch_idx: int, dataloader_idx: int = 0, log_scalars=True, global_step=None + ) -> Any: + + with torch.no_grad(): + ( + virtual_tokens, + context_and_question_tokens, + enc_mask, + dec_input_raw, + dec_input_mask_raw, + labels, + loss_mask, + position_ids, + taskname_ids, + speech_mask, + context_and_question_tokens_lens, + cross_attention_prior, + text_limits, # [start of question token, question token len) in [0, enc_mask.size(1)) + lang, + question_texts, + ) = batch + + batch_size = virtual_tokens.size(0) + dec_input = ( + dec_input_raw * 1 + ) # (B, 8, T) # TODO @xueyang: apply clone() method bypasses this unnecessary computation. + dec_input_mask = dec_input_mask_raw * 1 # (B, T) + dec_input_mask[:, :] = 1 # Does not really matter + output_token_list = [] + + end_indices = {} + # pad dec_input (B, 8, T) to 1000 timesteps + max_inference_timesteps = self.cfg.get('max_inference_timesteps', 2000) + # TODO @xueyang: potential bug when max_inference_timesteps < dec_input.shape[2], then dec_input is clipped. + dec_input = torch.nn.functional.pad(dec_input, (0, max_inference_timesteps - dec_input.shape[2]), value=0) + dec_input[:, :, self.decoder_context_len + 1 :].zero_() + # TODO @xueyang: why not just declare torch.ones(dec_input_raw.size(0), max_inference_timesteps)? + dec_input_mask = torch.nn.functional.pad( + dec_input_mask, (0, max_inference_timesteps - dec_input_mask.shape[1]), value=1 + ) + + if self.inference_apply_text_cfg and self.inference_apply_audio_cfg: + question_limits = text_limits - virtual_tokens.size( + 1 + ) # (b, 2), reset question range to start from [pad] context, same start position as context_and_question_tokens. + question_start = question_limits[:, 0].unsqueeze(1) # (b, 1) + question_end = question_limits[:, 1].unsqueeze(1) # (b, 1) + + # duplicate and glue two batches into a single one. + virtual_tokens = torch.cat((virtual_tokens, virtual_tokens), dim=0) + taskname_ids = torch.cat((taskname_ids, taskname_ids), dim=0) + speech_mask = torch.cat((speech_mask, speech_mask), dim=0) + dec_input_mask = torch.cat((dec_input_mask, dec_input_mask), dim=0) + + if isinstance(context_and_question_tokens, list): # indicate self.encoder_type = "multi_transformers". + context_tokens, question_tokens = context_and_question_tokens + + # text + question_tokens_unconditioned = question_tokens.clone() + time_range = torch.arange( + question_tokens_unconditioned.size(2), device=question_tokens_unconditioned.device + ).unsqueeze(0) + question_mask = (time_range >= question_start) & ( + time_range < question_end + ) # create a mask for question only tokens. + question_tokens_unconditioned[:, 0][ + question_mask + ] = self.tokenizer.unk_id # only the first layer has non-zero IDs. + + # audio + context_tokens_unconditioned = context_tokens.clone() + context_tokens_unconditioned[:, :, :] = self.tokenizer.unk_id + + # concatenate both conditioned and unconditioned batches as a single one. + context_and_question_tokens = [ + torch.cat((context_tokens, context_tokens_unconditioned), dim=0), + torch.cat((question_tokens, question_tokens_unconditioned), dim=0), + ] + enc_mask = [torch.cat((mask, mask), dim=0) for mask in enc_mask] + dec_input = torch.cat((dec_input, dec_input), dim=0) + position_ids = [torch.cat((pos_ids, pos_ids), dim=0) for pos_ids in position_ids] + else: + assert ( + self.context_conditioning == "decoder" + ), f"The encoder_type is single_transformer. We expect context_condition is decoder: context_condition={self.context_conditioning}" + + # text + context_and_question_tokens_unconditioned = context_and_question_tokens.clone() + time_range = torch.arange( + context_and_question_tokens_unconditioned.size(2), + device=context_and_question_tokens_unconditioned.device, + ).unsqueeze( + 0 + ) # (1, max_context_question_tokens_len) + question_mask = (time_range >= question_start) & ( + time_range < question_end + ) # create a mask for question only tokens. + context_and_question_tokens_unconditioned[:, 0][ + question_mask + ] = self.tokenizer.unk_id # only the first layer has non-zero IDs. + + # audio + dec_input_unconditioned = dec_input.clone() + dec_input_unconditioned[:, :, 1 : self.decoder_context_len + 1] = ( + self.tokenizer.unk_id + ) # TODO @xueyang: switch to other token id if this one is conflict with text unk. + + # concatenate both conditioned and unconditioned batches as a single one. + context_and_question_tokens = torch.cat( + (context_and_question_tokens, context_and_question_tokens_unconditioned), dim=0 + ) + enc_mask = torch.cat((enc_mask, enc_mask), dim=0) + dec_input = torch.cat((dec_input, dec_input_unconditioned), dim=0) + position_ids = torch.cat((position_ids, position_ids), dim=0) + + # clean up useless variables. + del question_limits, question_start, question_end, time_range, question_mask + elif self.inference_apply_text_cfg: + # replace question token IDs with [UNK]'s id. No speech offset for Phoneme's [UNK]. Same op as train. + # instruction token IDs are bpe token IDs directly obtained from self.tokenizer without any offset. + # question token IDs are phoneme and grapheme token IDs and are offset by self.lm_vocab_size + # if under "Phoneme TTS" instruction, so exising no overlaps between instruction and question token IDs. + # question token IDs are bpe token IDs without any offset + # if under "Text to speech this" instruction, so existing overlaps between instruction and question token IDs. + question_limits = text_limits - virtual_tokens.size( + 1 + ) # (b, 2), reset question range to start from [pad] context, same start position as context_and_question_tokens. + question_start = question_limits[:, 0].unsqueeze(1) # (b, 1) + question_end = question_limits[:, 1].unsqueeze(1) # (b, 1) + + # duplicate and glue two batches into a single one. + virtual_tokens = torch.cat((virtual_tokens, virtual_tokens), dim=0) + taskname_ids = torch.cat((taskname_ids, taskname_ids), dim=0) + speech_mask = torch.cat((speech_mask, speech_mask), dim=0) + dec_input_mask = torch.cat((dec_input_mask, dec_input_mask), dim=0) + + if isinstance(context_and_question_tokens, list): # indicate self.encoder_type = "multi_transformers". + context_tokens, question_tokens = context_and_question_tokens + question_tokens_unconditioned = question_tokens.clone() + + time_range = torch.arange( + question_tokens_unconditioned.size(2), device=question_tokens_unconditioned.device + ).unsqueeze(0) + question_mask = (time_range >= question_start) & ( + time_range < question_end + ) # create a mask for question only tokens. + question_tokens_unconditioned[:, 0][ + question_mask + ] = self.tokenizer.unk_id # only the first layer has non-zero IDs. + + # concatenate both conditioned and unconditioned batches as a single one. + context_and_question_tokens = [ + torch.cat((context_tokens, context_tokens), dim=0), + torch.cat((question_tokens, question_tokens_unconditioned), dim=0), + ] + enc_mask = [torch.cat((mask, mask), dim=0) for mask in enc_mask] + dec_input = torch.cat((dec_input, dec_input), dim=0) + position_ids = [torch.cat((pos_ids, pos_ids), dim=0) for pos_ids in position_ids] + else: + assert ( + self.context_conditioning == "decoder" + ), f"The encoder_type is single_transformer. We expect context_condition is decoder: context_condition={self.context_conditioning}" + context_and_question_tokens_unconditioned = context_and_question_tokens.clone() + time_range = torch.arange( + context_and_question_tokens_unconditioned.size(2), + device=context_and_question_tokens_unconditioned.device, + ).unsqueeze( + 0 + ) # (1, max_context_question_tokens_len) + question_mask = (time_range >= question_start) & ( + time_range < question_end + ) # create a mask for question only tokens. + context_and_question_tokens_unconditioned[:, 0][ + question_mask + ] = self.tokenizer.unk_id # only the first layer has non-zero IDs. + + # concatenate both conditioned and unconditioned batches as a single one. + context_and_question_tokens = torch.cat( + (context_and_question_tokens, context_and_question_tokens_unconditioned), dim=0 + ) + enc_mask = torch.cat((enc_mask, enc_mask), dim=0) + dec_input = torch.cat((dec_input, dec_input), dim=0) + position_ids = torch.cat((position_ids, position_ids), dim=0) + + # clean up useless variables. + del question_limits, question_start, question_end, time_range, question_mask + elif self.inference_apply_audio_cfg: + # duplicate and glue two batches into a single one. + virtual_tokens = torch.cat((virtual_tokens, virtual_tokens), dim=0) + taskname_ids = torch.cat((taskname_ids, taskname_ids), dim=0) + speech_mask = torch.cat((speech_mask, speech_mask), dim=0) + dec_input_mask = torch.cat((dec_input_mask, dec_input_mask), dim=0) + + if isinstance( + context_and_question_tokens, list + ): # indicate that self.encoder_type = "multi_transformers" + context_tokens, question_tokens = context_and_question_tokens + context_tokens_unconditioned = context_tokens.clone() + context_tokens_unconditioned[:, :, :] = ( + self.tokenizer.unk_id + ) # TODO @xueyang: verify if extra tokens other than audio codec tokens are appended. + + # concatenate both conditioned and unconditioned batches as a single one. + context_and_question_tokens = [ + torch.cat((context_tokens, context_tokens_unconditioned), dim=0), + torch.cat((question_tokens, question_tokens), dim=0), + ] + enc_mask = [torch.cat((mask, mask), dim=0) for mask in enc_mask] + dec_input = torch.cat((dec_input, dec_input), dim=0) + position_ids = [torch.cat((pos_ids, pos_ids), dim=0) for pos_ids in position_ids] + else: + assert ( + self.context_conditioning == "decoder" + ), f"The encoder_type is single_transformer. We expect context_condition is decoder: context_condition={self.context_conditioning}" + dec_input_unconditioned = dec_input.clone() + dec_input_unconditioned[:, :, 1 : self.decoder_context_len + 1] = ( + self.tokenizer.unk_id + ) # TODO @xueyang: switch to other token id if this one is conflict with text unk. + + # concatenate both conditioned and unconditioned batches as a single one. + context_and_question_tokens = torch.cat( + (context_and_question_tokens, context_and_question_tokens), dim=0 + ) + enc_mask = torch.cat((enc_mask, enc_mask), dim=0) + dec_input = torch.cat((dec_input, dec_input_unconditioned), dim=0) + position_ids = torch.cat((position_ids, position_ids), dim=0) + else: + logging.debug( + f"Neither text or audio cfg logits are applied:" + f" self.inference_apply_text_cfg={self.inference_apply_text_cfg}," + f" self.inference_apply_audio_cfg={self.inference_apply_audio_cfg}" + ) + + end_inference_loop_at = None + fwd_bwd_function = get_forward_backward_func() + encoder_output = None + attention_probs_all = [] + start_time = time.time() + for t in range(self.decoder_context_len + 1, dec_input.shape[2] - 1): + # Start at 0 if encoder context, else context_len + if t % 100 == 0: + logging.info("Timestep {}".format(t)) + if t == end_inference_loop_at: + print("All ends detected") + break + + if isinstance(enc_mask, list): + encoder_max_sequence_len = [e.size(1) for e in enc_mask] + else: + encoder_max_sequence_len = enc_mask.size(1) + + # if context_condition is decoder, then t starts at [PAD] token represented as [0] * 8. + # if context_condition is encoder, then t starts at [CLS]. + if t == self.decoder_context_len + 1: + # Run first step manually + output_logits, _, token_and_speech_logits = self.forward( + virtual_tokens, + context_and_question_tokens, + enc_mask, + dec_input[ + :, :, : t + 1 + ], # tensors representing [CLS] + context audio tokens + [PAD] if context_condition is decoder, otherwise, tensors representing [CLS]. + dec_input_mask[:, : t + 1], # doesn't matter because of all ones. + position_ids, + taskname_ids, + labels=None, + speech_mask=speech_mask, + inference=True, + inference_step=0, + decoder_max_sequence_len=max_inference_timesteps, + encoder_max_sequence_len=encoder_max_sequence_len, + ) + encoder_output = token_and_speech_logits[-1] + + if isinstance(encoder_output, list): + encoder_output = [e.transpose(0, 1) for e in encoder_output] + else: + encoder_output = encoder_output.transpose(0, 1) + + else: + # Prepare batch + batch = [ + max_inference_timesteps, + encoder_max_sequence_len, + encoder_output, + enc_mask, + dec_input[:, :, : t + 1], + dec_input_mask[:, : t + 1], + position_ids, + taskname_ids, + speech_mask, + ] + + output_tensor = fwd_bwd_function( + forward_step_func=self.get_forward_output_only_func(), + data_iterator=iter( + [ + batch, + ] + ), + model=[self], + num_microbatches=get_num_microbatches(), + forward_only=True, + seq_length=t, + micro_batch_size=dec_input.shape[0], + ) + output_logits = output_tensor[0]['output_logits'] # (B, T, V, 8) or (2B, T, V, 8) + token_and_speech_logits = output_tensor[0]['token_and_speech_logits'] + + # when return_all_crossattention is False, attention_probs is None. + if self.frozen_model.enc_dec_model.return_all_crossattention_probs: + attention_probs = token_and_speech_logits[2] + attention_probs_mean = torch.stack(attention_probs).mean(dim=0) # B, 12, 1, enc_timesteps + attention_probs_all.append(attention_probs_mean) + + if self.inference_apply_text_cfg or self.inference_apply_audio_cfg: + # interpolate conditioned and unconditioned logits + token_logits = ( + self.inference_cfg_interpolation_scale * token_and_speech_logits[0][:batch_size] + + (1 - self.inference_cfg_interpolation_scale) * token_and_speech_logits[0][batch_size:] + ) + output_speech_logits = ( + self.inference_cfg_interpolation_scale * output_logits[:batch_size] + + (1 - self.inference_cfg_interpolation_scale) * output_logits[batch_size:] + ) + else: + token_logits = token_and_speech_logits[0] # (B, T, V) + output_speech_logits = output_logits + + token_logits_currtimestep = token_logits[:, -1, :] # (B, V) + token_preds = token_logits_currtimestep.argmax(dim=1) # (B,) + + if torch.count_nonzero(speech_mask) > 0: + output_logits_currtimestep = ( + output_speech_logits[:, -1, :, :] + .permute(0, 2, 1) + .contiguous() + .view(-1, self.speech_codebook_size) + ) # (B*8, V) + output_logits_currtimestep_conditioned = ( + output_logits[:batch_size][:, -1, :, :] + .permute(0, 2, 1) + .contiguous() + .view(-1, self.speech_codebook_size) + ) + output_logits_currtimestep_unconditioned = ( + output_logits[batch_size:][:, -1, :, :] + .permute(0, 2, 1) + .contiguous() + .view(-1, self.speech_codebook_size) + ) + else: + output_logits_currtimestep = token_logits_currtimestep # (B, V) + output_logits_currtimestep_conditioned = token_logits_currtimestep + output_logits_currtimestep_unconditioned = token_logits_currtimestep + + top_k = self.cfg.get('top_k', 80) + + # (B*8, 80) or (B, 80) + output_logits_currtimestep_topk = torch.topk(output_logits_currtimestep, top_k, dim=1)[0] + + # find indices which are not top k + indices_to_remove = output_logits_currtimestep < output_logits_currtimestep_topk[:, -1].unsqueeze(1) + # (B*8, 1024) or (B, 1024) + + if self.inference_apply_cfg_filter: + output_logits_currtimestep_rescored = output_logits_currtimestep_conditioned.clone() + else: + output_logits_currtimestep_rescored = output_logits_currtimestep.clone() + + output_logits_currtimestep_rescored[indices_to_remove] = -float('Inf') + + # logits interpolation between conditioned and unconditioned logits. + if ( + self.inference_apply_text_cfg or self.inference_apply_audio_cfg + ) and self.inference_apply_cfg_filter: + output_logits_currtimestep_rescored = ( + self.inference_cfg_filter_interpolation_scale * output_logits_currtimestep_rescored + + (1 - self.inference_cfg_filter_interpolation_scale) + * output_logits_currtimestep_unconditioned + ) + + temperature = self.cfg.get('temperature', 0.85) # Set temp 0.01 for greedy decoding + output_logits_currtimestep_rescored = output_logits_currtimestep_rescored / temperature + output_logits_currtimestep_rescored = torch.nn.functional.softmax( + output_logits_currtimestep_rescored, dim=1 + ) + + output_tokens_curr_timestep = torch.multinomial( + output_logits_currtimestep_rescored, num_samples=1 + ) # (B*8, 1) + + if torch.count_nonzero(speech_mask) > 0: + # Convert back to (B, 8) + output_tokens_curr_timestep = output_tokens_curr_timestep.view( + batch_size, self.num_speech_codebooks + ) + + for _b in range(token_preds.shape[0]): + if t > self.decoder_context_len + 10 and token_preds[_b] == self.tokenizer.eos_id: + if _b not in end_indices: + logging.info("End detected for item {}".format(_b) + " at timestep {}".format(t)) + end_indices[_b] = t + if len(end_indices) == token_preds.shape[0]: + end_inference_loop_at = t + self.num_speech_codebooks + + output_token_list.append(output_tokens_curr_timestep) + + # duplicate to 2b dim as input for the next iteration if enabling cfg. + if self.inference_apply_text_cfg or self.inference_apply_audio_cfg: + output_tokens_curr_timestep = torch.cat( + (output_tokens_curr_timestep, output_tokens_curr_timestep), dim=0 + ) + + if torch.count_nonzero(speech_mask) > 0: + dec_input_next_timestep = output_tokens_curr_timestep * 1 # (B,8) + dec_input_next_timestep[:, 0] = ( + dec_input_next_timestep[:, 0] + self.speech_offset + ) # add offset to first codebook + dec_input[:, :, t + 1] = dec_input_next_timestep * 1 + else: + dec_input[:, 0, t + 1] = output_tokens_curr_timestep.squeeze(1) + + # end of for loop + output_tokens_combined = torch.stack(output_token_list) # (T, B, 8) if speech else (T, B) + if torch.count_nonzero(speech_mask) > 0: + output_tokens_combined = output_tokens_combined.permute(1, 2, 0) # (B, 8, T) + else: + output_tokens_combined = output_tokens_combined.squeeze(2) + output_tokens_combined = output_tokens_combined.permute(1, 0) # (B, T) + + # consider only autoregressive time, disconsider loading eval models for RTF time + total_process_time = time.time() - start_time + + # Layerwise token error rate + ter_dict = {} + for i in range(self.num_speech_codebooks): + ter_dict[i] = {'hypothesis': [], 'gt': []} + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + if 'nemo_sv_model' not in self.additional_models: + nemo_sv_model = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained(model_name='titanet_large') + nemo_sv_model = nemo_sv_model.to(device) + nemo_sv_model.encoder.disable_torch_distributed = True # For multi-gpu training validation + nemo_sv_model.eval() + self.additional_models['nemo_sv_model'] = nemo_sv_model + logging.info(f"Loaded SV Model: {nemo_sv_model}") + else: + nemo_sv_model = self.additional_models['nemo_sv_model'] + + if 'asr_model' not in self.additional_models: + asr_model = self.cfg.get("asr_model_name", "stt_multilingual_fastconformer_hybrid_large_pc_blend_eu") + + if "hybrid" in asr_model: + model = nemo_asr.models.EncDecHybridRNNTCTCBPEModel + else: + model = nemo_asr.models.EncDecRNNTBPEModel + asr_model = model.from_pretrained(model_name=asr_model) + asr_model = asr_model.to(device) + asr_model.encoder.disable_torch_distributed = True # For multi-gpu training validation + asr_model.eval() + self.additional_models['asr_model'] = asr_model + logging.info(f"Loaded ASR Model: {asr_model}") + else: + asr_model = self.additional_models['asr_model'] + + asr_model_zh = None + if Lang.zh.value in lang: + if 'asr_model_zh' not in self.additional_models: + asr_model_zh = nemo_asr.models.EncDecRNNTModel.from_pretrained( + model_name="stt_zh_conformer_transducer_large" + ) + asr_model_zh = asr_model_zh.to(device) + asr_model_zh.eval() + self.additional_models['asr_model_zh'] = asr_model_zh + else: + asr_model_zh = self.additional_models['asr_model_zh'] + + if 'wavlm_sv_model' not in self.additional_models: + wavlm_sv_extractor = Wav2Vec2FeatureExtractor.from_pretrained('microsoft/wavlm-base-plus-sv') + wavlm_sv_model = WavLMForXVector.from_pretrained('microsoft/wavlm-base-plus-sv') + wavlm_sv_model = wavlm_sv_model.to(device) + wavlm_sv_model = wavlm_sv_model.eval() + self.additional_models['wavlm_sv_model'] = wavlm_sv_model + self.additional_models['wavlm_sv_extractor'] = wavlm_sv_extractor + logging.info(f"Loaded SV Model: {wavlm_sv_model}") + else: + wavlm_sv_model = self.additional_models['wavlm_sv_model'] + wavlm_sv_extractor = self.additional_models['wavlm_sv_extractor'] + + # load MOS estimator model only if True. + if self.estimate_mos: + # load mos estimator. + if 'squim_mos_model' not in self.additional_models: + squim_mos_model_full = SQUIM_SUBJECTIVE.get_model().to(device) + self.additional_models['squim_mos_model'] = squim_mos_model_full + else: + squim_mos_model_full = self.additional_models['squim_mos_model'] + + # load non-matching reference clean audio. + ref_16khz_wav, _ = librosa.load(self.non_matching_ref_audio_filepath, sr=16000) + + # prepare MOS estimator by taking a single audio example as an input. + squim_mos_model = partial( + squim_mos_model_full, reference=torch.from_numpy(ref_16khz_wav).to(device).unsqueeze(0) + ) + + _exp_dir_path = self.logger.log_dir + _exp_dir_path = _exp_dir_path + '/Sample_Audios' + if not os.path.exists(_exp_dir_path): + os.mkdir(_exp_dir_path) + + squim_mos_list_pred = [] + squim_mos_list_context = [] + squim_mos_list_gt = [] + similarity_list = [] + similarity_list_wavlm = [] + pred_context_similarity_list = [] + pred_context_similarity_list_wavlm = [] + gt_context_similarity_list = [] + gt_context_similarity_list_wavlm = [] + question_type = [] + + # predicting audio + batch_size = output_tokens_combined.shape[0] + test_dataloader_batch_size = batch_size + # self.test_dataloader() is not defined during validation + if isinstance(self.test_dataloader(), torch.utils.data.DataLoader): + test_dataloader_batch_size = self.test_dataloader().batch_size + + # logging attention maps. + # empty attention_probs_all indicates self.frozen_model.enc_dec_model.return_all_crossattention_probs is False. + if len(attention_probs_all) != 0: + attention_probs_all = torch.cat(attention_probs_all, dim=2) # B, 12, dec_timesteps, enc_timesteps + attention_probs_all = attention_probs_all.mean(dim=1) # B, dec_timesteps, enc_timesteps + + for i in range(batch_size): + text_end_step = text_limits[i, 1].item() + text_start_step = text_limits[i, 0].item() + end_index = end_indices.get(i, output_tokens_combined.shape[2]) + if len(attention_probs_all) != 0: + attention_probs_example = attention_probs_all[i][ + : end_index - (1 + self.decoder_context_len), text_start_step:text_end_step + ] # T, enc_timesteps + attention_map = attention_probs_example.float().cpu().numpy().T + alignment_image = plot_alignment_to_numpy_for_speechllm( + attention_map, + phoneme_ver=1, + phoneme_seq=None, + ) + + if global_step is not None: + # During validation, step is simply global_step + i + step = global_step + i + else: + # During inference, step is the index of the sample + step = batch_idx * test_dataloader_batch_size + i + + self.logger.experiment.add_image( + "Inf Attention Map", + alignment_image, + step, + dataformats="HWC", + ) + # Save attention image to file + alignment_fp = os.path.join(_exp_dir_path, f'attention_map_{step}.png') + imageio.imwrite(alignment_fp, alignment_image) + + wer_score = 0 + audio_to_pred = [] + audio_to_pred_zh = [] + total_audio_seconds = 0 + for i in range(batch_size): + if global_step is not None: + # During validation, step is simply global_step + i + step = global_step + i + else: + # During inference, step is the index of the sample + step = batch_idx * test_dataloader_batch_size + i + + audio_len = self.decoder_context_len + (labels[i][0][self.decoder_context_len :] != 0).sum().item() + + if torch.count_nonzero(speech_mask) > 0: + dec_input_to_1024 = self.convert_tokens_to_range(dec_input_raw[i, :, 0:audio_len]) + dec_input_to_1024_answer = dec_input_to_1024[:, self.decoder_context_len + 1 :] + dec_input_wav = self.decode_wav_from_codec_model(dec_input_to_1024_answer) + self.logger.experiment.add_audio("Inf Dec Input Wav", dec_input_wav, step, self.sample_rate) + + predicted_tokens = output_tokens_combined[i] # Should not contain context even if decoder context + if i in end_indices: + logging.info(f"Clipping until end index for audio {i}") + if self.cfg.get('seq_pattern', 'parallel') == 'delay_parallel': + predicted_tokens = predicted_tokens[ + :, 0 : end_indices[i] - (1 + self.decoder_context_len) + self.num_speech_codebooks + ] # trim to audio length + else: + predicted_tokens = predicted_tokens[ + :, 0 : end_indices[i] - (1 + self.decoder_context_len) + ] # trim to audio length + + pred_img = predicted_tokens.data.cpu().float().numpy() + dec_inp_img = dec_input_to_1024.data.cpu().float().numpy() + start_time = time.time() + predicted_tokens = self.convert_tokens_to_range(predicted_tokens, apply_offset_correction=False) + predicted_wav = self.decode_wav_from_codec_model(predicted_tokens) + # accumulate audio length in seconds and process time in seconds to the RTF + total_process_time = total_process_time + (time.time() - start_time) + total_audio_seconds = total_audio_seconds + predicted_wav.size(-1) / self.sample_rate + + self.logger.experiment.add_audio("Inf Pred Wav", predicted_wav, step, self.sample_rate) + self.logger.experiment.add_image( + "Inf Pred Tokens", + plot_codec_to_numpy(pred_img), + step, + dataformats="HWC", + ) + self.logger.experiment.add_image( + "Inf Dec Input Tokens", + plot_codec_to_numpy(dec_inp_img), + step, + dataformats="HWC", + ) + + # save predicted_wav and gt_wav to a wav files in dir_path + if global_step is not None: + # During training, overwrite the wav file from the previous validation + wav_num = i + else: + wav_num = step + + audio_fp_pred = os.path.join(_exp_dir_path, f'predicted_wav_{wav_num}.wav') + sf.write(audio_fp_pred, predicted_wav.cpu().numpy(), self.sample_rate) + audio_fp_gt = os.path.join(_exp_dir_path, f'dec_input_wav_{wav_num}.wav') + sf.write(audio_fp_gt, dec_input_wav.cpu().numpy(), self.sample_rate) + + # speaker verification evaluation using nemo model + spk_embedding_pred = nemo_sv_model.get_embedding(audio_fp_pred) + spk_embedding_pred = spk_embedding_pred.cpu().detach().numpy().flatten() + spk_embedding_gt = nemo_sv_model.get_embedding(audio_fp_gt) + spk_embedding_gt = spk_embedding_gt.cpu().detach().numpy().flatten() + similarity = np.dot(spk_embedding_pred, spk_embedding_gt) / ( + np.linalg.norm(spk_embedding_pred) * np.linalg.norm(spk_embedding_gt) + ) + + if log_scalars: + self.logger.experiment.add_scalar(f'Inf SV Cossim Individual Sample', similarity, step) + similarity_list.append(similarity) + + # speaker verification evaluation using wavlm model + gt_16khz_wav, _ = librosa.load(audio_fp_gt, sr=16000) + pred_16khz_wav, _ = librosa.load(audio_fp_pred, sr=16000) + inputs_wavlm = wavlm_sv_extractor( + [pred_16khz_wav, gt_16khz_wav], padding=True, return_tensors="pt", sampling_rate=16000 + ) + for key in inputs_wavlm.keys(): + inputs_wavlm[key] = inputs_wavlm[key].to(device) + + with torch.no_grad(): + wavlm_embeddings = wavlm_sv_model(**inputs_wavlm).embeddings + wavlm_embeddings = torch.nn.functional.normalize(wavlm_embeddings, dim=-1).cpu() + + spk_embedding_pred_wavlm = wavlm_embeddings[0].cpu().detach().numpy().flatten() + spk_embedding_gt_wavlm = wavlm_embeddings[1].cpu().detach().numpy().flatten() + similarity_wavlm = np.dot(spk_embedding_pred_wavlm, spk_embedding_gt_wavlm) / ( + np.linalg.norm(spk_embedding_pred_wavlm) * np.linalg.norm(spk_embedding_gt_wavlm) + ) + similarity_list_wavlm.append(similarity_wavlm) + + if lang[i] == Lang.zh.value: + audio_to_pred_zh.append({"step": i, "audio": audio_fp_pred}) + audio_to_pred_zh.append({"step": i, "audio": audio_fp_gt}) + else: + audio_to_pred.append({"step": i, "audio": audio_fp_pred}) + audio_to_pred.append({"step": i, "audio": audio_fp_gt}) + + if isinstance(context_and_question_tokens, list): + context_tokens, question_tokens = context_and_question_tokens + input_token_list = [ + question_tokens[i, 0, j].item() + for j in range(context_and_question_tokens_lens[1][i].item()) + ] + input_token_list = [ + (ti, t) for ti, t in enumerate(input_token_list) if t != 0 and t < self.speech_offset + ] + context_end_step = context_and_question_tokens_lens[0][i] + context_tokens = context_tokens[i][:, :context_end_step] + else: + input_token_list = [ + context_and_question_tokens[i, 0, j].item() + for j in range(context_and_question_tokens.shape[2]) + ] + input_token_list = [ + (ti, t) for ti, t in enumerate(input_token_list) if t != 0 and t < self.speech_offset + ] + context_end_step = input_token_list[0][0] + context_tokens = context_and_question_tokens[i][:, :context_end_step] + + spk_embedding_context = spk_embedding_gt + spk_embedding_context_wavlm = spk_embedding_gt_wavlm + if self.decoder_context_len > 0: + context_tokens = dec_input_to_1024[:, : self.decoder_context_len + 1] + context_wav = self.decode_wav_from_codec_model(context_tokens) + elif context_end_step > 1: + is_speech_context = context_tokens[1, :].sum().item() > 0 + if is_speech_context: + context_tokens = self.convert_tokens_to_range(context_tokens, pattern=self.context_pattern) + context_wav = self.decode_wav_from_codec_model(context_tokens) + else: + context_wav = None + _context_token_list = [v.item() for v in context_tokens[0, :]] + _context_text = self.frozen_model.tokenizer.ids_to_text( + [v for v in _context_token_list if v < self.lm_vocab_size] + ) + self.logger.experiment.add_text("Context Text", _context_text, self.global_step) + + else: + context_wav = None + + if context_wav is not None: + self.logger.experiment.add_audio("Context Wav", context_wav, step, self.sample_rate) + context_wav_fp = os.path.join(_exp_dir_path, f'context_wav_{wav_num}.wav') + sf.write(context_wav_fp, context_wav.cpu().numpy(), self.sample_rate) + # titanet + spk_embedding_context = nemo_sv_model.get_embedding(context_wav_fp) + spk_embedding_context = spk_embedding_context.cpu().detach().numpy().flatten() + # wavlm + context_wavlm_wav, _ = librosa.load(context_wav_fp, sr=16000) + inputs_wavlm = wavlm_sv_extractor( + [context_wavlm_wav], padding=True, return_tensors="pt", sampling_rate=16000 + ) + for key in inputs_wavlm.keys(): + inputs_wavlm[key] = inputs_wavlm[key].to(device) + + with torch.no_grad(): + wavlm_embeddings = wavlm_sv_model(**inputs_wavlm).embeddings + wavlm_embeddings = torch.nn.functional.normalize(wavlm_embeddings, dim=-1).cpu() + + spk_embedding_context_wavlm = wavlm_embeddings[0].cpu().detach().numpy().flatten() + + pred_similarity_context = np.dot(spk_embedding_context, spk_embedding_pred) / ( + np.linalg.norm(spk_embedding_context) * np.linalg.norm(spk_embedding_pred) + ) + gt_similarity_context = np.dot(spk_embedding_context, spk_embedding_gt) / ( + np.linalg.norm(spk_embedding_context) * np.linalg.norm(spk_embedding_gt) + ) + + pred_similarity_context_wavlm = np.dot(spk_embedding_context_wavlm, spk_embedding_pred_wavlm) / ( + np.linalg.norm(spk_embedding_context_wavlm) * np.linalg.norm(spk_embedding_pred_wavlm) + ) + gt_similarity_context_wavlm = np.dot(spk_embedding_context_wavlm, spk_embedding_gt_wavlm) / ( + np.linalg.norm(spk_embedding_context_wavlm) * np.linalg.norm(spk_embedding_gt_wavlm) + ) + + if log_scalars: + self.logger.experiment.add_scalar(f'Inf SV Cossim Context Pred', pred_similarity_context, step) + self.logger.experiment.add_scalar(f'Inf SV Cossim Context GT', gt_similarity_context, step) + pred_context_similarity_list.append(pred_similarity_context) + gt_context_similarity_list.append(gt_similarity_context) + pred_context_similarity_list_wavlm.append(pred_similarity_context_wavlm) + gt_context_similarity_list_wavlm.append(gt_similarity_context_wavlm) + + task_question = self.frozen_model.tokenizer.ids_to_text( + [v[1] for v in input_token_list if v[1] < self.lm_vocab_size] + ) + self.logger.experiment.add_text("Inf Task Question", task_question, step) + if "Phoneme TTS" in task_question: + question_type.append("Phoneme TTS") + elif "Text to speech this" in task_question: + question_type.append("Text to speech this") + else: + question_type.append("Other") + + task_question_phoneme_tokens = [ + v[1] - self.lm_vocab_size for v in input_token_list if v[1] >= self.lm_vocab_size + ] + if len(task_question_phoneme_tokens) > 0: + phoneme_text = self.phoneme_tokenizer.decode(task_question_phoneme_tokens) + self.logger.experiment.add_text("Inf Task Question Phoneme Text", phoneme_text, step) + + # store predicted_tokens for each layer to compute token error rate + for layer_idx in range(self.num_speech_codebooks): + ter_dict[layer_idx]['hypothesis'].append(predicted_tokens[layer_idx].cpu().numpy().tolist()) + ter_dict[layer_idx]['gt'].append(dec_input_to_1024_answer[layer_idx].cpu().numpy().tolist()) + + # estimate MOS scores. + if self.estimate_mos: + squim_mos_score_pred = squim_mos_model( + torch.from_numpy(pred_16khz_wav).to(device).unsqueeze(0) + ).item() + squim_mos_score_gt = squim_mos_model( + torch.from_numpy(gt_16khz_wav).to(device).unsqueeze(0) + ).item() + if context_wav is not None: + squim_mos_score_context = squim_mos_model(context_wav.to(device).unsqueeze(0)).item() + squim_mos_list_context.append(squim_mos_score_context) + squim_mos_list_pred.append(squim_mos_score_pred) + squim_mos_list_gt.append(squim_mos_score_gt) + else: + r = labels[i, 0].long() + nzm = r != 0 + r = r.tolist()[:-1] + nzm = nzm[:-1] + h = output_tokens_combined[i].long() * nzm + h = h.tolist() + cur_wer_score = editdistance.eval(r, h) + if log_scalars: + self.logger.experiment.add_scalar('WER', cur_wer_score, step) + logging.info(f"current wer score : {cur_wer_score}") + wer_score += cur_wer_score + if wer_score > 0: + wer_score /= batch_size + if log_scalars: + self.logger.experiment.add_scalar('AVG WER', wer_score, step) + logging.info(f"average wer score : {wer_score}") + + # compute token error rate for each layer + if log_scalars: + for layer_idx in range(self.num_speech_codebooks): + wer = word_error_rate(ter_dict[layer_idx]['hypothesis'], ter_dict[layer_idx]['gt'], use_cer=True) + self.logger.experiment.add_scalar(f'Inf TER Layer {layer_idx}', wer, 0) + + greedy_transcripts = [] + if len(audio_to_pred) > 0: + greedy_transcripts.extend(asr_model.transcribe([i["audio"] for i in audio_to_pred])[0]) + if len(audio_to_pred_zh) > 0: + greedy_transcripts.extend(asr_model_zh.transcribe([i["audio"] for i in audio_to_pred_zh])[0]) + + all_audio_to_pred = audio_to_pred + audio_to_pred_zh + # Note WER over the batch is not equal to WER(sample) / batch_size, but approx. here + + # These are between ASR outputs of GT audio and predicted audio + wer_batch = [] + cer_batch = [] + cer_phoneme = [] + wer_phoneme = [] + cer_tts = [] + wer_tts = [] + + # These are between ASR output of Pred audio and GT text + wer_batch_gt = [] + cer_batch_gt = [] + cer_phoneme_gt = [] + wer_phoneme_gt = [] + cer_tts_gt = [] + wer_tts_gt = [] + + for i in range(0, len(greedy_transcripts) - 1, 2): + assert all_audio_to_pred[i]["step"] == all_audio_to_pred[i + 1]["step"] + step = batch_idx * test_dataloader_batch_size + all_audio_to_pred[i]["step"] + question_text = question_texts[i // 2] + + # No need to process text since both are ASR outputs + cer_sample = word_error_rate([greedy_transcripts[i]], [greedy_transcripts[i + 1]], use_cer=True) + wer_sample = word_error_rate([greedy_transcripts[i]], [greedy_transcripts[i + 1]], use_cer=False) + + # Processing text since one is ASR output and the other is the GT text + cer_gt = word_error_rate( + [self.process_text(greedy_transcripts[i])], [self.process_text(question_text)], use_cer=True + ) + wer_gt = word_error_rate( + [self.process_text(greedy_transcripts[i])], [self.process_text(question_text)], use_cer=False + ) + + self.logger.experiment.add_text("Inf Predicted Text", greedy_transcripts[i], step) + self.logger.experiment.add_text("Inf GT Text", greedy_transcripts[i + 1], step) + self.logger.experiment.add_text("Inf Question Text", question_text, step) + if log_scalars: + self.logger.experiment.add_scalar(f'Inf CER Transcript', cer_sample, step) + self.logger.experiment.add_scalar(f'Inf WER Transcript', wer_sample, step) + self.logger.experiment.add_scalar(f'Inf CER GT Transcript', cer_gt, step) + cer_batch.append(cer_sample) + wer_batch.append(wer_sample) + cer_batch_gt.append(cer_gt) + wer_batch_gt.append(wer_gt) + if question_type[all_audio_to_pred[i]["step"]] == "Phoneme TTS": + if log_scalars: + self.logger.experiment.add_scalar(f'Inf CER Phoneme Task', cer_sample, step) + self.logger.experiment.add_scalar(f'Inf WER Phoneme Task', wer_sample, step) + self.logger.experiment.add_scalar(f'Inf CER GT Phoneme Task', cer_gt, step) + cer_phoneme.append(cer_sample) + wer_phoneme.append(wer_sample) + cer_phoneme_gt.append(cer_gt) + wer_phoneme_gt.append(wer_gt) + elif question_type[all_audio_to_pred[i]["step"]] == "Text to speech this": + if log_scalars: + self.logger.experiment.add_scalar(f'Inf CER TTS Task', cer_sample, step) + self.logger.experiment.add_scalar(f'Inf WER TTS Task', wer_sample, step) + self.logger.experiment.add_scalar(f'Inf CER GT TTS Task', cer_gt, step) + cer_tts.append(cer_sample) + wer_tts.append(wer_sample) + cer_tts_gt.append(cer_gt) + wer_tts_gt.append(wer_gt) + + # compute average similarity + similarity_avg = np.mean(similarity_list) + pred_context_similarity_avg = np.mean(pred_context_similarity_list) + gt_context_similarity_avg = np.mean(gt_context_similarity_list) + similarity_avg_wavlm = np.mean(similarity_list_wavlm) + pred_context_similarity_avg_wavlm = np.mean(pred_context_similarity_list_wavlm) + gt_context_similarity_avg_wavlm = np.mean(gt_context_similarity_list_wavlm) + + if log_scalars: + self.logger.experiment.add_scalar(f'Inf SV Avg Cossim', similarity_avg, batch_idx) + self.predict_step_outputs.append( + { + 'titanet_avg_cossim': similarity_avg, + 'titanet_avg_cossim_context_pred': pred_context_similarity_avg, + 'titanet_avg_cossim_context_gt': gt_context_similarity_avg, + 'wavlm_avg_cossim': similarity_avg_wavlm, + 'wavlm_avg_cossim_context_pred': pred_context_similarity_avg_wavlm, + 'wavlm_avg_cossim_context_gt': gt_context_similarity_avg_wavlm, + 'squim_mos_pred': np.mean(squim_mos_list_pred) if len(squim_mos_list_pred) > 0 else None, + 'squim_mos_context': np.mean(squim_mos_list_context) if len(squim_mos_list_context) > 0 else None, + 'squim_mos_gt': np.mean(squim_mos_list_gt) if len(squim_mos_list_gt) > 0 else None, + 'cer_transcript': np.mean(cer_batch), + 'wer_transcript': np.mean(wer_batch), + 'cer_phoneme': np.mean(cer_phoneme) if len(cer_phoneme) > 0 else None, + 'wer_phoneme': np.mean(wer_phoneme) if len(wer_phoneme) > 0 else None, + 'cer_tts': np.mean(cer_tts) if len(cer_tts) > 0 else None, + 'wer_tts': np.mean(wer_tts) if len(wer_tts) > 0 else None, + 'cer_transcript_gt': np.mean(cer_batch_gt), + 'wer_transcript_gt': np.mean(wer_batch_gt), + 'cer_phoneme_gt': np.mean(cer_phoneme_gt) if len(cer_phoneme_gt) > 0 else None, + 'wer_phoneme_gt': np.mean(wer_phoneme_gt) if len(wer_phoneme_gt) > 0 else None, + 'cer_tts_gt': np.mean(cer_tts_gt) if len(cer_tts_gt) > 0 else None, + 'wer_tts_gt': np.mean(wer_tts_gt) if len(wer_tts_gt) > 0 else None, + "RTF": total_process_time / total_audio_seconds, + } + ) + + # TODO @xueyang: PTL 2.0+ patch. Signature of method `on_predict_epoch_end` does not match signature of the base method in PTL class 'ModelHooks'. + # Remove the `outputs` param and choose `self.predict_step_output` instead. + def on_predict_epoch_end(self, outputs: List[Any]) -> None: + + gather_results = [None for _ in range(parallel_state.get_data_parallel_world_size())] + all_preds = list(itertools.chain(*[item['preds_text'] for item in outputs[0]])) + all_labels = list(itertools.chain(*[item['labels_text'] for item in outputs[0]])) + all_inputs = list(itertools.chain(*[item['input_text'] for item in outputs[0]])) + + assert len(all_preds) == len(all_labels) + assert len(all_preds) == len(all_inputs) + + # Gather inputs, predictions, and ground truths from all workers + torch.distributed.all_gather_object( + gather_results, + [(input, pred, label) for (input, pred, label) in zip(all_inputs, all_preds, all_labels)], + group=parallel_state.get_data_parallel_group(), + ) + + # Deduplicate sentences that may have been distributed across multiple data parallel ranks. + if parallel_state.get_data_parallel_rank() == 0: + gather_results_dedup = list(set(itertools.chain(*gather_results))) + + input_prediction_pair = [] + correct = 0 + for input, pred, label in gather_results_dedup: + input_prediction_pair.append((input, pred)) + if label: + if pred == label: + correct += 1 + + acc = correct / len(gather_results_dedup) if all_labels[0] else None + logging.info(f'Prediction results: {acc}') + logging.info(f'Test finish') diff --git a/nemo/collections/tts/models/ssl_tts.py b/nemo/collections/tts/models/ssl_tts.py index 298a1a599008..f2cc4f798ec5 100644 --- a/nemo/collections/tts/models/ssl_tts.py +++ b/nemo/collections/tts/models/ssl_tts.py @@ -18,10 +18,10 @@ import librosa import torch from hydra.utils import instantiate +from lightning.pytorch import Trainer +from lightning.pytorch.loggers import TensorBoardLogger +from lightning.pytorch.utilities.combined_loader import CombinedLoader from omegaconf import DictConfig, OmegaConf -from pytorch_lightning import Trainer -from pytorch_lightning.loggers import TensorBoardLogger -from pytorch_lightning.utilities.combined_loader import CombinedLoader from nemo.collections.asr.losses.angularloss import AngularSoftmaxLoss from nemo.collections.tts.data.dataset import TTSDataset @@ -38,10 +38,10 @@ class SSLDisentangler(ModelPT): """ SSLDisentangler is a Conformer based model for extracting disentangled content and speaker embeddings - from an audio waveform. This model uses a pre-trained Conformer SSL model. To extract the linguistic content - and speaker representations using a pre-trained Conformer, two randomly initialized downstream - heads are added and the entire setup is finetuned in multi-task manner for speech recognition and speaker verification. - These representations can be used by FastPitchModel_SSL for voice conversion by swapping the speaker embedding + from an audio waveform. This model uses a pre-trained Conformer SSL model. To extract the linguistic content + and speaker representations using a pre-trained Conformer, two randomly initialized downstream + heads are added and the entire setup is finetuned in multi-task manner for speech recognition and speaker verification. + These representations can be used by FastPitchModel_SSL for voice conversion by swapping the speaker embedding of a given source utterance, with the speaker embedding of a target speaker. """ @@ -92,7 +92,10 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): librosa_mel_filter = librosa.filters.mel( sr=stft_cfg.sample_rate, n_fft=stft_cfg.n_fft, n_mels=stft_cfg.features, fmin=0, fmax=8000 ) - fb = torch.tensor(librosa_mel_filter, dtype=torch.float,).unsqueeze(0) + fb = torch.tensor( + librosa_mel_filter, + dtype=torch.float, + ).unsqueeze(0) self.register_buffer("fb", fb) @@ -212,7 +215,10 @@ def configure_optimizers(self): sched_downstream_config = optim_downstream_config.pop("sched", None) OmegaConf.set_struct(optim_downstream_config, True) - optim_backbone = instantiate(optim_backbone_config, params=self.encoder.parameters(),) + optim_backbone = instantiate( + optim_backbone_config, + params=self.encoder.parameters(), + ) optim_downstream = instantiate( optim_downstream_config, params=itertools.chain( @@ -254,7 +260,8 @@ def configure_optimizers(self): def forward(self, input_signal=None, input_signal_length=None, normalize_content=True): processed_signal, processed_signal_length = self.preprocessor_disentangler( - input_signal=input_signal, length=input_signal_length, + input_signal=input_signal, + length=input_signal_length, ) encoded, encoded_len = self.encoder(audio_signal=processed_signal, length=processed_signal_length) # b,c,t @@ -292,7 +299,9 @@ def forward_for_export(self, input_signal=None, input_signal_length=None, normal # Same as forward right now. Earlier version of encoder had a different forward for export. # This function is still kept for compatibility with older evaluation/inference scripts. return self.forward( - input_signal=input_signal, input_signal_length=input_signal_length, normalize_content=normalize_content, + input_signal=input_signal, + input_signal_length=input_signal_length, + normalize_content=normalize_content, ) def training_step(self, batch, batch_idx): diff --git a/nemo/collections/tts/models/tacotron2.py b/nemo/collections/tts/models/tacotron2.py index 2fb005d80ca6..33d476029011 100644 --- a/nemo/collections/tts/models/tacotron2.py +++ b/nemo/collections/tts/models/tacotron2.py @@ -18,9 +18,9 @@ import torch from hydra.utils import instantiate +from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger from omegaconf import MISSING, DictConfig, OmegaConf, open_dict from omegaconf.errors import ConfigAttributeError -from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger from torch import nn from nemo.collections.common.parts.preprocessing import parsers diff --git a/nemo/collections/tts/models/univnet.py b/nemo/collections/tts/models/univnet.py index 64ee891b0754..12500be8d180 100644 --- a/nemo/collections/tts/models/univnet.py +++ b/nemo/collections/tts/models/univnet.py @@ -18,8 +18,8 @@ import torch import torch.nn.functional as F from hydra.utils import instantiate +from lightning.pytorch.loggers.wandb import WandbLogger from omegaconf import DictConfig, OmegaConf, open_dict -from pytorch_lightning.loggers.wandb import WandbLogger from nemo.collections.tts.losses.hifigan_losses import DiscriminatorLoss, GeneratorLoss from nemo.collections.tts.losses.stftlosses import MultiResolutionSTFTLoss @@ -114,8 +114,14 @@ def configure_optimizers(self): if sched_config is None and 'sched' in self._cfg: sched_config = self._cfg.sched - optim_g = instantiate(optim_config, params=self.generator.parameters(),) - optim_d = instantiate(optim_config, params=itertools.chain(self.mrd.parameters(), self.mpd.parameters()),) + optim_g = instantiate( + optim_config, + params=self.generator.parameters(), + ) + optim_d = instantiate( + optim_config, + params=itertools.chain(self.mrd.parameters(), self.mpd.parameters()), + ) if sched_config is not None: max_steps = self._cfg.get("max_steps", None) @@ -290,7 +296,7 @@ def stft(x): comp = torch.stft(x.squeeze(1), n_fft=1024, hop_length=256, win_length=1024, return_complex=True) comp = torch.view_as_real(comp) real, imag = comp[..., 0], comp[..., 1] - mags = torch.sqrt(real ** 2 + imag ** 2) + mags = torch.sqrt(real**2 + imag**2) phase = torch.atan2(imag, real) return mags, phase diff --git a/nemo/collections/tts/models/vits.py b/nemo/collections/tts/models/vits.py index 4a891fa8823e..3c53442a0863 100644 --- a/nemo/collections/tts/models/vits.py +++ b/nemo/collections/tts/models/vits.py @@ -18,9 +18,9 @@ import omegaconf import torch from hydra.utils import instantiate +from lightning.pytorch import Trainer +from lightning.pytorch.loggers import WandbLogger from omegaconf import DictConfig, OmegaConf -from pytorch_lightning import Trainer -from pytorch_lightning.loggers import WandbLogger from torch.cuda.amp import autocast from torch.nn import functional as F diff --git a/nemo/collections/tts/models/waveglow.py b/nemo/collections/tts/models/waveglow.py index 728b5b94b084..04eec734b26e 100644 --- a/nemo/collections/tts/models/waveglow.py +++ b/nemo/collections/tts/models/waveglow.py @@ -15,8 +15,8 @@ import torch from hydra.utils import instantiate +from lightning.pytorch.loggers import TensorBoardLogger from omegaconf import DictConfig, open_dict -from pytorch_lightning.loggers import TensorBoardLogger from nemo.collections.tts.losses.waveglowloss import WaveGlowLoss from nemo.collections.tts.models.base import GlowVocoder diff --git a/nemo/collections/tts/parts/utils/callbacks.py b/nemo/collections/tts/parts/utils/callbacks.py index c4ec09031cf9..1856dee0ce0f 100644 --- a/nemo/collections/tts/parts/utils/callbacks.py +++ b/nemo/collections/tts/parts/utils/callbacks.py @@ -23,10 +23,10 @@ import soundfile as sf import torch from einops import rearrange -from pytorch_lightning import Callback, LightningModule, Trainer -from pytorch_lightning.loggers import TensorBoardLogger -from pytorch_lightning.loggers.logger import Logger -from pytorch_lightning.loggers.wandb import WandbLogger +from lightning.pytorch import Callback, LightningModule, Trainer +from lightning.pytorch.loggers import TensorBoardLogger +from lightning.pytorch.loggers.logger import Logger +from lightning.pytorch.loggers.wandb import WandbLogger from torch import Tensor from nemo.collections.tts.parts.utils.helpers import create_plot @@ -194,7 +194,10 @@ def _log_audio(self, audio: AudioArtifact, log_dir: Path, step: int): if self.tensorboard_logger: self.tensorboard_logger.add_audio( - tag=audio.id, snd_tensor=audio.data, global_step=step, sample_rate=audio.sample_rate, + tag=audio.id, + snd_tensor=audio.data, + global_step=step, + sample_rate=audio.sample_rate, ) if self.wandb_logger: @@ -212,7 +215,10 @@ def _log_image(self, image: ImageArtifact, log_dir: Path, step: int): if self.tensorboard_logger: self.tensorboard_logger.add_image( - tag=image.id, img_tensor=image_plot, global_step=step, dataformats="HWC", + tag=image.id, + img_tensor=image_plot, + global_step=step, + dataformats="HWC", ) if self.wandb_logger: @@ -220,8 +226,7 @@ def _log_image(self, image: ImageArtifact, log_dir: Path, step: int): self.wandb_logger.log({image.id: wandb_image}) def _log_artifacts(self, audio_list: list, image_list: list, log_dir: Optional[Path] = None, global_step: int = 0): - """Log audio and image artifacts. - """ + """Log audio and image artifacts.""" if log_dir is not None: log_dir.mkdir(parents=True, exist_ok=True) @@ -232,8 +237,7 @@ def _log_artifacts(self, audio_list: list, image_list: list, log_dir: Optional[P self._log_image(image=image, log_dir=log_dir, step=global_step) def on_fit_start(self, trainer: Trainer, model: LightningModule): - """Log initial data artifacts. - """ + """Log initial data artifacts.""" audio_list = [] image_list = [] for batch_dict in self.data_loader: @@ -255,8 +259,7 @@ def on_fit_start(self, trainer: Trainer, model: LightningModule): self._log_artifacts(audio_list=audio_list, image_list=image_list, log_dir=log_dir) def on_train_epoch_end(self, trainer: Trainer, model: LightningModule): - """Log artifacts at the end of an epoch. - """ + """Log artifacts at the end of an epoch.""" epoch = 1 + model.current_epoch if (epoch not in self.log_epochs) and (epoch % self.epoch_frequency != 0): return @@ -306,7 +309,10 @@ def generate_artifacts( audio_gt_path = Path(f"{dataset_name}/{audio_id}_gt.wav") audio_gt_i = audio[i, : audio_len[i]].cpu().numpy() audio_artifact = AudioArtifact( - id=f"audio_gt_{audio_id}", data=audio_gt_i, filepath=audio_gt_path, sample_rate=model.sample_rate, + id=f"audio_gt_{audio_id}", + data=audio_gt_i, + filepath=audio_gt_path, + sample_rate=model.sample_rate, ) audio_artifacts.append(audio_artifact) return audio_artifacts, [] @@ -321,7 +327,10 @@ def generate_artifacts( audio_pred_path = Path(f"{dataset_name}/{audio_id}.wav") audio_pred_i = audio_pred[i, : audio_len[i]].cpu().numpy() audio_artifact = AudioArtifact( - id=f"audio_{audio_id}", data=audio_pred_i, filepath=audio_pred_path, sample_rate=model.sample_rate, + id=f"audio_{audio_id}", + data=audio_pred_i, + filepath=audio_pred_path, + sample_rate=model.sample_rate, ) audio_artifacts.append(audio_artifact) @@ -378,7 +387,10 @@ def _generate_audio( audio_pred_path = Path(f"{dataset_name}/{audio_id}_audio_out.wav") audio_pred_i = audio_pred[i, : audio_pred_len[i]].cpu().numpy() audio_artifact = AudioArtifact( - id=f"audio_out_{audio_id}", data=audio_pred_i, filepath=audio_pred_path, sample_rate=model.sample_rate, + id=f"audio_out_{audio_id}", + data=audio_pred_i, + filepath=audio_pred_path, + sample_rate=model.sample_rate, ) audio_artifacts.append(audio_artifact) @@ -388,7 +400,10 @@ def _generate_audio( audio_in_path = Path(f"{dataset_name}/{audio_id}_audio_in.wav") audio_in_i = audio[i, : audio_len[i]].cpu().numpy() audio_artifact = AudioArtifact( - id=f"audio_in_{audio_id}", data=audio_in_i, filepath=audio_in_path, sample_rate=model.sample_rate, + id=f"audio_in_{audio_id}", + data=audio_in_i, + filepath=audio_in_path, + sample_rate=model.sample_rate, ) audio_artifacts.append(audio_artifact) @@ -538,7 +553,11 @@ def _create_ground_truth_artifacts( spec_gt_path = Path(f"{dataset_name}/{audio_id}_spec_gt.png") spec_gt_i = spec[i, :, : spec_len[i]].cpu().numpy() spec_artifact = ImageArtifact( - id=f"spec_{audio_id}", data=spec_gt_i, filepath=spec_gt_path, x_axis="Audio Frames", y_axis="Channels", + id=f"spec_{audio_id}", + data=spec_gt_i, + filepath=spec_gt_path, + x_axis="Audio Frames", + y_axis="Channels", ) image_artifacts.append(spec_artifact) @@ -565,14 +584,22 @@ def _generate_predictions( with torch.no_grad(): # [B, C, T_spec] - mels_pred, mels_pred_len, *_ = model.forward(text=text, input_lens=text_lens, speaker=speaker,) + mels_pred, mels_pred_len, *_ = model.forward( + text=text, + input_lens=text_lens, + speaker=speaker, + ) if self.log_spectrogram: for i, (dataset_name, audio_id) in enumerate(zip(dataset_names, audio_ids)): spec_path = Path(f"{dataset_name}/{audio_id}_spec.png") spec_i = mels_pred[i, :, : mels_pred_len[i]].cpu().numpy() spec_artifact = ImageArtifact( - id=f"spec_{audio_id}", data=spec_i, filepath=spec_path, x_axis="Audio Frames", y_axis="Channels", + id=f"spec_{audio_id}", + data=spec_i, + filepath=spec_path, + x_axis="Audio Frames", + y_axis="Channels", ) image_artifacts.append(spec_artifact) diff --git a/nemo/collections/tts/parts/utils/helpers.py b/nemo/collections/tts/parts/utils/helpers.py index a4c65f9ed0e5..28be259502c5 100644 --- a/nemo/collections/tts/parts/utils/helpers.py +++ b/nemo/collections/tts/parts/utils/helpers.py @@ -48,8 +48,8 @@ import librosa import matplotlib.pylab as plt import numpy as np +import seaborn as sns import torch -from einops import rearrange from numba import jit, prange from nemo.collections.tts.torch.tts_data_types import DATA_STR2DATA_CLASS, MAIN_DATA_TYPES, WithLens @@ -63,7 +63,7 @@ HAVE_WANDB = False try: - from pytorch_lightning.utilities import rank_zero_only + from lightning.pytorch.utilities import rank_zero_only except ModuleNotFoundError: from functools import wraps @@ -468,6 +468,74 @@ def plot_alignment_to_numpy(alignment, title='', info=None, phoneme_seq=None, vm return data +def plot_alignment_to_numpy_for_speechllm( + alignment, + title='', + info=None, + phoneme_seq=None, + vmin=None, + vmax=None, + phoneme_ver=0, + phone_offset=2, + h_offset=True, +): + alignment = np.clip(alignment, a_min=0, a_max=None) + fig, ax = plt.subplots(figsize=(8, 6)) + im = ax.imshow(alignment, aspect='auto', origin='lower', interpolation='none', vmin=vmin, vmax=vmax) + ax.set_title(title) + fig.colorbar(im, ax=ax) + xlabel = 'Decoder timestep' + if info is not None: + xlabel += '\n\n' + info + plt.xlabel(xlabel) + plt.ylabel('Encoder timestep') + + if phoneme_seq is not None: + if phoneme_ver == 0: + # for debugging of phonemes and durs in maps. Not used by def in training code + ax.set_yticks(np.arange(len(phoneme_seq))) + ax.set_yticklabels(phoneme_seq) + ax.hlines(np.arange(len(phoneme_seq)), xmin=0.0, xmax=max(ax.get_xticks())) + elif phoneme_ver == 1: + yticks = ax.get_yticks() + new_yticks = [] + for tick in yticks: + if tick < 0 or tick > alignment.shape[0]: + continue + new_yticks.append(tick) + new_yticks += phoneme_seq + ax.set_yticks(new_yticks) + elif phoneme_ver == 2: + phones = phoneme_seq[phone_offset:] + ax.set_yticks(np.arange(len(phones))) + ax.set_yticklabels(phones) + ax.hlines(np.arange(0.5, len(phones) - 0.5, 1.0), xmin=0.0, xmax=alignment.shape[1] - 0.5, colors="black") + + if h_offset: + xticks = ax.get_xticks() + new_xticks = [] + for tick in xticks: + new_xticks.append(f"{tick+phoneme_seq[1]:.0f}") + ax.set_xticklabels(new_xticks) + + plt.tight_layout() + fig.canvas.draw() + data = save_figure_to_numpy(fig) + plt.close() + return data + + +def plot_codec_to_numpy(codes, title=''): + fig, ax = plt.subplots(figsize=(10, 3)) + sns.heatmap(codes, ax=ax) + + plt.tight_layout() + fig.canvas.draw() + data = save_figure_to_numpy(fig) + plt.close() + return data + + def plot_pitch_to_numpy(pitch, ylim_range=None): fig, ax = plt.subplots(figsize=(12, 3)) plt.plot(pitch) diff --git a/nemo/collections/tts/parts/utils/tts_dataset_utils.py b/nemo/collections/tts/parts/utils/tts_dataset_utils.py index 5f1185c2c399..96806f633a54 100644 --- a/nemo/collections/tts/parts/utils/tts_dataset_utils.py +++ b/nemo/collections/tts/parts/utils/tts_dataset_utils.py @@ -67,8 +67,7 @@ def get_audio_filepaths(manifest_entry: Dict[str, Any], audio_dir: Path) -> Tupl def normalize_volume(audio: np.array, volume_level: float = 0.95) -> np.array: - """Apply peak normalization to the input audio. - """ + """Apply peak normalization to the input audio.""" if not (0.0 <= volume_level <= 1.0): raise ValueError(f"Volume must be in range [0.0, 1.0], received {volume_level}") @@ -88,10 +87,11 @@ class BetaBinomialInterpolator: The implementation is taken from https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechSynthesis/FastPitch/fastpitch/data_function.py """ - def __init__(self, round_mel_len_to=50, round_text_len_to=10, cache_size=500): + def __init__(self, round_mel_len_to=50, round_text_len_to=10, cache_size=500, scaling_factor: float = 1.0): self.round_mel_len_to = round_mel_len_to self.round_text_len_to = round_text_len_to - self.bank = functools.lru_cache(maxsize=cache_size)(beta_binomial_prior_distribution) + cached_func = lambda x, y: beta_binomial_prior_distribution(x, y, scaling_factor=scaling_factor) + self.bank = functools.lru_cache(maxsize=cache_size)(cached_func) @staticmethod def round(val, to): @@ -315,7 +315,11 @@ def load_audio( def sample_audio( - manifest_entry: Dict[str, Any], audio_dir: Path, sample_rate: int, n_samples: int, volume_norm: bool = False, + manifest_entry: Dict[str, Any], + audio_dir: Path, + sample_rate: int, + n_samples: int, + volume_norm: bool = False, ) -> Tuple[np.ndarray, Path, Path]: """ Randomly sample an audio segment from a manifest entry. diff --git a/nemo/collections/vision/models/megatron_vit_classification_models.py b/nemo/collections/vision/models/megatron_vit_classification_models.py index 5cffdd6d12a3..c4024a5a47a7 100644 --- a/nemo/collections/vision/models/megatron_vit_classification_models.py +++ b/nemo/collections/vision/models/megatron_vit_classification_models.py @@ -17,9 +17,9 @@ from typing import Any, Optional import torch +from lightning.pytorch.accelerators import CPUAccelerator +from lightning.pytorch.trainer.trainer import Trainer from omegaconf.dictconfig import DictConfig -from pytorch_lightning.accelerators import CPUAccelerator -from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.nlp.data.language_modeling.megatron.data_samplers import MegatronPretrainingSampler from nemo.collections.nlp.models.language_modeling.megatron_base_model import MegatronBaseModel diff --git a/nemo/collections/vision/modules/common/megatron/vision_transformer.py b/nemo/collections/vision/modules/common/megatron/vision_transformer.py index 80793067128c..2abaf6dfe224 100644 --- a/nemo/collections/vision/modules/common/megatron/vision_transformer.py +++ b/nemo/collections/vision/modules/common/megatron/vision_transformer.py @@ -169,6 +169,8 @@ def forward( self_attention_relative_position_bias=None, cross_attention_relative_position_bias=None, checkpoint_core_attention=False, + decoder_max_sequence_len=None, + encoder_max_sequence_len=None, ): # Self attention. if rotary_pos_emb is not None: @@ -373,6 +375,8 @@ def forward( self_attention_relative_position_bias=None, cross_attention_relative_position_bias=None, checkpoint_core_attention=False, + decoder_max_sequence_len=None, + encoder_max_sequence_len=None, ): kwargs = locals() for key in ["self", "__class__"]: diff --git a/nemo/collections/vlm/__init__.py b/nemo/collections/vlm/__init__.py index 7d8cc2c94247..b5e693830fa5 100644 --- a/nemo/collections/vlm/__init__.py +++ b/nemo/collections/vlm/__init__.py @@ -12,6 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from nemo.collections.vlm.llava_next.data import LlavaNextMockDataModule, LlavaNextTaskEncoder +from nemo.collections.vlm.llava_next.model.base import LlavaNextConfig +from nemo.collections.vlm.llava_next.model.llava_next import LlavaNextConfig7B, LlavaNextConfig13B, LlavaNextModel from nemo.collections.vlm.mllama.data import MLlamaLazyDataModule, MLlamaMockDataModule from nemo.collections.vlm.mllama.model.base import ( CrossAttentionTextConfig, @@ -42,7 +45,8 @@ NevaConfig, NevaModel, ) -from nemo.collections.vlm.neva.model.llava import Llava1_5Config7B, Llava1_5Config13B, LlavaConfig, LlavaModel +from nemo.collections.vlm.neva.model.llava import Llava15Config7B, Llava15Config13B, LlavaConfig, LlavaModel +from nemo.collections.vlm.neva.model.vit_config import CLIPViTL_14_336_Config, SigLIPViT400M_14_384_Config from nemo.collections.vlm.peft import LoRA from nemo.collections.vlm.recipes import * @@ -59,13 +63,16 @@ "VideoToken", "CLIPViTConfig", "HFCLIPVisionConfig", + "CLIPViTL_14_336_Config", + "SigLIPViT400M_14_384_Config", "MultimodalProjectorConfig", "NevaConfig", "NevaModel", "LlavaConfig", - "Llava1_5Config7B", - "Llava1_5Config13B", + "Llava15Config7B", + "Llava15Config13B", "LlavaModel", + "LlavaNextTaskEncoder", "MLlamaModel", "MLlamaModelConfig", "CrossAttentionTextConfig", @@ -76,4 +83,10 @@ "MLlamaConfig90BInstruct", "mllama_11b", "mllama_90b", + "llava_next_7b", + "LlavaNextConfig7B", + "LlavaNextConfig13B", + "LlavaNextModel", + "LlavaNextMockDataModule", + "LlavaNextTaskEncoder", ] diff --git a/nemo/collections/vlm/layer_specs.py b/nemo/collections/vlm/layer_specs.py new file mode 100644 index 000000000000..11c4d697a5aa --- /dev/null +++ b/nemo/collections/vlm/layer_specs.py @@ -0,0 +1,131 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add +from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear +from megatron.core.transformer.attention import SelfAttention, SelfAttentionSubmodules +from megatron.core.transformer.dot_product_attention import DotProductAttention +from megatron.core.transformer.enums import AttnMaskType +from megatron.core.transformer.identity_op import IdentityOp +from megatron.core.transformer.mlp import MLP, MLPSubmodules +from megatron.core.transformer.spec_utils import ModuleSpec +from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules + +try: + from megatron.core.extensions.transformer_engine import ( + TEColumnParallelLinear, + TEDotProductAttention, + TELayerNormColumnParallelLinear, + TENorm, + TERowParallelLinear, + ) + + HAVE_TE = True +except ImportError: + HAVE_TE = False + +try: + from megatron.core.fusions.fused_layer_norm import FusedLayerNorm + + HAVE_APEX = True + LNImpl = FusedLayerNorm +except ImportError: + import warnings + + from megatron.core.transformer.torch_layer_norm import WrappedTorchLayerNorm + + warnings.warn(f'Apex is not installed. Falling back to Torch LayerNorm') + LNImpl = WrappedTorchLayerNorm + + +def get_layer_spec(is_vit, normalization) -> ModuleSpec: + """Transformer Layer Spec""" + attn_mask_type = AttnMaskType.no_mask if is_vit else AttnMaskType.causal + if normalization == "LayerNorm": + norm = LNImpl + elif normalization == "RMSNorm": + norm = TENorm + else: + raise RuntimeError("unknown normalization", normalization) + + mlp = get_mlp_module_spec(use_te=False) # doesn't include norm. + + return ModuleSpec( + module=TransformerLayer, + submodules=TransformerLayerSubmodules( + input_layernorm=norm, + self_attention=ModuleSpec( + module=SelfAttention, + params={"attn_mask_type": attn_mask_type}, + submodules=SelfAttentionSubmodules( + linear_qkv=ColumnParallelLinear, + core_attention=DotProductAttention, + linear_proj=RowParallelLinear, + q_layernorm=IdentityOp, + k_layernorm=IdentityOp, + ), + ), + self_attn_bda=get_bias_dropout_add, + pre_mlp_layernorm=norm, + mlp=mlp, + mlp_bda=get_bias_dropout_add, + ), + ) + + +def get_layer_spec_te(is_vit=False) -> ModuleSpec: + """Transformer Layer Spec w/ TE Modules""" + attn_mask_type = AttnMaskType.no_mask if is_vit else AttnMaskType.causal + + mlp = get_norm_mlp_module_spec_te() + return ModuleSpec( + module=TransformerLayer, + submodules=TransformerLayerSubmodules( + self_attention=ModuleSpec( + module=SelfAttention, + params={"attn_mask_type": attn_mask_type}, + submodules=SelfAttentionSubmodules( + linear_qkv=TELayerNormColumnParallelLinear, + core_attention=TEDotProductAttention, + linear_proj=TERowParallelLinear, + q_layernorm=IdentityOp, + k_layernorm=IdentityOp, + ), + ), + self_attn_bda=get_bias_dropout_add, + pre_mlp_layernorm=IdentityOp, + mlp=mlp, + mlp_bda=get_bias_dropout_add, + ), + ) + + +def get_mlp_module_spec(use_te: bool = True) -> ModuleSpec: + """MLP Submodule Spec""" + # Dense MLP w/ or w/o TE modules. + return ModuleSpec( + module=MLP, + submodules=MLPSubmodules( + linear_fc1=TEColumnParallelLinear if use_te else ColumnParallelLinear, + linear_fc2=TERowParallelLinear if use_te else RowParallelLinear, + ), + ) + + +def get_norm_mlp_module_spec_te() -> ModuleSpec: + """Norm + MLP Submodule Spec""" + return ModuleSpec( + module=MLP, + submodules=MLPSubmodules(linear_fc1=TELayerNormColumnParallelLinear, linear_fc2=TERowParallelLinear), + ) diff --git a/nemo/collections/vlm/llava_next/__init__.py b/nemo/collections/vlm/llava_next/__init__.py new file mode 100644 index 000000000000..d9155f923f18 --- /dev/null +++ b/nemo/collections/vlm/llava_next/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/vlm/llava_next/data/__init__.py b/nemo/collections/vlm/llava_next/data/__init__.py new file mode 100644 index 000000000000..1c71e5355f4b --- /dev/null +++ b/nemo/collections/vlm/llava_next/data/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from nemo.collections.vlm.llava_next.data.energon import LlavaNextTaskEncoder +from nemo.collections.vlm.llava_next.data.mock import MockDataModule as LlavaNextMockDataModule + +__all__ = [ + "LlavaNextMockDataModule", + "LlavaNextTaskEncoder", +] diff --git a/nemo/collections/vlm/llava_next/data/energon.py b/nemo/collections/vlm/llava_next/data/energon.py new file mode 100644 index 000000000000..effa3236ade7 --- /dev/null +++ b/nemo/collections/vlm/llava_next/data/energon.py @@ -0,0 +1,212 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import Dict, List, Optional + +import torch +from megatron.energon import VQASample, batch_list, batch_pad_stack +from torch.nn.utils.rnn import pad_sequence + +from nemo.collections.multimodal.data.energon.config import ImageTextRawBatch, ImageTextSample, MultiModalSampleConfig +from nemo.collections.multimodal.data.energon.sample_encoder import SampleEncoder, VQASampleEncoder +from nemo.collections.multimodal.data.energon.task_encoder import MultiModalTaskEncoder +from nemo.utils import logging + + +@dataclass +class LlavaNextTextSample(ImageTextSample): + ''' + Sample type for LLaVA-Next, extending ImageTextSample to support tiled image data. + + This class adds additional attributes for handling high-resolution images processed as tiles, + along with metadata about the tiled images. + + Attributes: + num_media_tiles (int): The number of tiles used to represent the high-resolution image. + image_sizes (torch.Tensor): A tensor representing the sizes of the tiled images. + attention_mask (Optional[torch.Tensor]): An optional attention mask for the sample, + used to determine which tokens or tiles are attended to during processing. Defaults to None. + ''' + + num_media_tiles: int = 0 + image_sizes: torch.tensor = field(default_factory=lambda: torch.tensor([])) + attention_mask: Optional[torch.tensor] = None + + +@dataclass +class LlavaNextTextRawBatch(ImageTextRawBatch): + """ + Batch type for raw LLaVA-Next samples, supporting tiled image data. + + This class aggregates multiple `LlavaNextTextSample` instances into a batch for processing. + It includes attributes for managing tiled images and associated metadata for each sample in the batch. + + Attributes: + num_media_tiles (List[int]): A list containing the number of tiles for each image in the batch. + image_sizes (torch.Tensor): A tensor containing the sizes of all tiled images in the batch. + attention_mask (Optional[torch.Tensor]): Attention mask. Defaults to None. + """ + + num_media_tiles: List[int] = field(default_factory=list) + image_sizes: torch.tensor = field(default_factory=lambda: torch.tensor([])) + attention_mask: Optional[torch.tensor] = None + + +class LlavaNextSampleEncoder(VQASampleEncoder): + """LlavaNextSampleEncoder""" + + def __init__(self, tokenizer, image_processor, multimodal_sample_config=MultiModalSampleConfig()): + """ + Initialize the LlavaNextSampleEncoder, inherited from VQASampleEncoder for multimodal samples + focused on VQA-style data to support LLaVANeXT + + Parameters: + tokenizer (Tokenizer): The HF tokenizer used for processing text. + image_processor (ImageProcessor): The HF image processor used for preprocessing images. + multimodal_sample_config (MultiModalSampleConfig, optional): Configuration object for multimodal samples. + Defaults to MultiModalSampleConfig(). + """ + super().__init__(tokenizer, image_processor, multimodal_sample_config) + + def process_image(self, image): + """ + Process and prepare an image sample for encoding. + + This method preprocesses the image using the HF image_processor, converting it to + a tensor. + + Parameters: + image: The input image to be processed. + + Returns: + torch.Tensor: The processed image tensor. + """ + image_array = self.image_processor.preprocess(image, return_tensors='pt', do_rescale=False)['pixel_values'][0] + return image_array + + def encode(self, input_sample: VQASample, output_sample: LlavaNextTextSample): + """ + Encode a single sample into a format suitable for model input. + + This method prepares the conversation prompt, tokenizes it, and processes + the associated image. It fills the output sample with tokens, labels, loss mask, + and other required fields for multimodal processing. + + Parameters: + input_sample (VQASample): The input VQA sample containing an image and conversation text. + output_sample (LlavaNextTextSample): The output sample structure where encoded results are stored. + + Returns: + LlavaNextTextSample: The encoded output sample, containing processed tokens, labels, + images, loss masks, and metadata. + """ + conversation_prompt = self.apply_prompt_template(input_sample) + logging.debug(f"[Energon] task encoder encode_sample conversation_prompt {conversation_prompt}") + # tokenize prompt + tokens = self.tokenize(conversation_prompt) + labels = self.compute_labels(tokens, input_sample) + tokens = tokens[:-1].contiguous() + labels = labels[1:].contiguous() + logging.debug(f"[Energon] task encoder encode_sample after tokenize prompt tokens {tokens}") + logging.debug(f"[Energon] task encoder encode_sample lables {labels}") + loss_mask = self.compute_loss_mask(labels) + processed_image = self.process_image(input_sample.image) + output_sample.__key__ = input_sample.__key__ + output_sample.images = processed_image + output_sample.tokens = tokens + output_sample.labels = labels + output_sample.loss_mask = loss_mask + output_sample.num_media_tiles = processed_image.shape[0] + output_sample.attention_mask = torch.ones(len(tokens), dtype=torch.long) + height = input_sample.image.shape[1] + width = input_sample.image.shape[2] + output_sample.image_sizes = torch.tensor([[height, width]], dtype=torch.long) + return output_sample + + +class LlavaNextTaskEncoder(MultiModalTaskEncoder): + """LlavaNextTaskEncoder""" + + def __init__(self, tokenizer, image_processor, multimodal_sample_config): + """ + Initialize the LlavaNextTaskEncoder. + + This encoder extends MultiModalTaskEncoder to specifically handle LlavaNeXT, + overriding encoders for VQA sample type. + + Parameters: + tokenizer (Tokenizer): The tokenizer for processing text data across sample types. + image_processor (ImageProcessor): The image processor for preprocessing images. + multimodal_sample_config (MultiModalSampleConfig): Configuration settings for multimodal samples. + """ + super().__init__(tokenizer, image_processor, multimodal_sample_config) + self.encoders: Dict[str, SampleEncoder] = { + VQASample.__name__: LlavaNextSampleEncoder(tokenizer, image_processor, multimodal_sample_config) + } + + def batch(self, samples: List[LlavaNextTextSample]) -> LlavaNextTextRawBatch: + """ + Batch multiple encoded samples into a single batch structure for model input. + + This method combines individual sample fields (keys, images, tokens, labels, etc.) and + pads or stacks them as needed to create a unified batch. + + Parameters: + samples (List[LlavaNextTextSample]): A list of LlavaNextTextSample instances to be batched. + + Returns: + LlavaNextTextRawBatch: A batch containing all input samples' images, tokens, labels, + loss masks, and other metadata prepared for model processing. + """ + keys, images, tokens, labels, loss_mask, num_media_tiles, image_sizes, attention_mask = ( + [], + [], + [], + [], + [], + [], + [], + [], + ) + for sample in samples: + keys.append(sample.__key__) + images.append(sample.images) + tokens.append(sample.tokens) + labels.append(sample.labels) + loss_mask.append(sample.loss_mask) + num_media_tiles.append(sample.num_media_tiles) + image_sizes.append(sample.image_sizes) + attention_mask.append(sample.attention_mask) + + batch_keys = batch_list(keys) + + batch_images = torch.cat(images, dim=0) + + batch_tokens = pad_sequence(tokens, batch_first=True) + batch_labels = pad_sequence(labels, batch_first=True) + image_sizes = torch.cat(image_sizes, dim=0) + batch_loss_mask = batch_pad_stack(loss_mask) + batch_attention_mask = batch_pad_stack(attention_mask) + batch_num_media_tiles = torch.tensor(batch_list(num_media_tiles), dtype=torch.int) + return LlavaNextTextRawBatch( + __keys__=batch_keys, + images=batch_images, + tokens=batch_tokens, + labels=batch_labels, + loss_mask=batch_loss_mask, + num_media_tiles=batch_num_media_tiles, + image_sizes=image_sizes, + attention_mask=batch_attention_mask, + ) diff --git a/nemo/collections/vlm/llava_next/data/mock.py b/nemo/collections/vlm/llava_next/data/mock.py new file mode 100644 index 000000000000..f61df7336e6f --- /dev/null +++ b/nemo/collections/vlm/llava_next/data/mock.py @@ -0,0 +1,311 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, List, Optional + +import lightning.pytorch as pl +import numpy as np +import torch +from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS +from torch.utils import data +from torch.utils.data import DataLoader, Dataset + +from nemo.collections.vlm.neva.data.multimodal_tokens import IMAGE_TOKEN_INDEX +from nemo.lightning.pytorch.plugins import MegatronDataSampler +from nemo.utils import logging + + +class MockDataModule(pl.LightningDataModule): + """ + A mock data module for LLaVA-Next training, validation, and testing. + + Provides datasets and data loaders for training, validation, and testing phases. + Includes data sampling and preprocessing for multimodal tasks. + """ + + def __init__( + self, + seq_length: int = 2048, + decoder_seq_length: Optional[int] = None, + tokenizer=None, + image_processor=None, + micro_batch_size: int = 4, + global_batch_size: int = 8, + rampup_batch_size: Optional[List[int]] = None, + num_train_samples: int = 10_000_000, + num_val_samples: int = 10_000_000, + num_test_samples: int = 10_000_000, + num_workers: int = 8, + pin_memory: bool = True, + persistent_workers: bool = False, + ): + """ + Initializes the mock data module with data sampling and preprocessing configurations. + + Args: + seq_length (int): Maximum sequence length for tokens. + decoder_seq_length (Optional[int]): Sequence length for the decoder. + tokenizer: Tokenizer for text processing. + image_processor: Processor for image preprocessing. + micro_batch_size (int): Batch size per GPU. + global_batch_size (int): Total batch size across GPUs. + rampup_batch_size (Optional[List[int]]): Batch size ramp-up schedule. + num_train_samples (int): Number of training samples. + num_val_samples (int): Number of validation samples. + num_test_samples (int): Number of testing samples. + num_workers (int): Number of workers for data loading. + pin_memory (bool): Whether to pin memory for data loaders. + persistent_workers (bool): Whether to keep workers alive after the first iteration. + """ + super().__init__() + self.seq_length = seq_length + self.decoder_seq_len = decoder_seq_length + self.num_train_samples = num_train_samples + self.num_val_samples = num_val_samples + self.num_test_samples = num_test_samples + self.num_workers = num_workers + self.pin_memory = pin_memory + self.persistent_workers = persistent_workers + + if tokenizer is None or image_processor is None: + logging.warning( + f"Processor or tokenizer are not provided! Fall back to `llava-hf/llava-v1.6-vicuna-7b-hf`." + ) + from transformers import AutoProcessor + + from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer + + model_name = "llava-hf/llava-v1.6-vicuna-7b-hf" + + processor = AutoProcessor.from_pretrained(model_name) + self.tokenizer = tokenizer or AutoTokenizer(model_name) + self.image_processor = image_processor or processor.image_processor + self.data_sampler = MegatronDataSampler( + seq_len=self.seq_length, + decoder_seq_len=self.decoder_seq_len, + micro_batch_size=micro_batch_size, + global_batch_size=global_batch_size, + rampup_batch_size=rampup_batch_size, + ) + + def setup(self, stage: str = "") -> None: + """ + Sets up the training, validation, and testing datasets. + + Args: + stage (str): Stage of the setup ('train', 'valid', 'test'). + """ + self._train_ds = _MockLlavaNextDataset( + self.tokenizer, self.image_processor, "train", self.num_train_samples, self.seq_length + ) + self._validation_ds = _MockLlavaNextDataset( + self.tokenizer, self.image_processor, "valid", self.num_val_samples, self.seq_length + ) + self._test_ds = _MockLlavaNextDataset( + self.tokenizer, self.image_processor, "test", self.num_test_samples, self.seq_length + ) + + def train_dataloader(self) -> TRAIN_DATALOADERS: + """ + Creates the training data loader. + + Returns: + TRAIN_DATALOADERS: Training data loader. + """ + if not hasattr(self, "_train_ds"): + self.setup() + return self._create_dataloader(self._train_ds) + + def val_dataloader(self) -> EVAL_DATALOADERS: + """ + Creates the validation data loader. + + Returns: + EVAL_DATALOADERS: Validation data loader. + """ + if not hasattr(self, "_validation_ds"): + self.setup() + return self._create_dataloader(self._validation_ds) + + def test_dataloader(self) -> EVAL_DATALOADERS: + """ + Creates the testing data loader. + + Returns: + TEST_DATALOADERS: Testing data loader. + """ + if not hasattr(self, "_test_ds"): + self.setup() + return self._create_dataloader(self._test_ds) + + def _create_dataloader(self, dataset, **kwargs) -> DataLoader: + """ + Creates a generic data loader for the given dataset. + + Args: + dataset: The dataset for which the data loader is created. + **kwargs: Additional arguments for the DataLoader. + + Returns: + DataLoader: The created data loader. + """ + return DataLoader( + dataset, + num_workers=self.num_workers, + pin_memory=self.pin_memory, + persistent_workers=self.persistent_workers, + collate_fn=dataset.collate_fn, + **kwargs, + ) + + +class _MockLlavaNextDataset(Dataset): + """ + A mock dataset for LLaVA-Next, generating synthetic multimodal data. + + Attributes: + tokenizer: Tokenizer for text inputs. + image_processor: Processor for image inputs. + name (str): Name of the dataset ('train', 'valid', 'test'). + num_samples (int): Number of samples in the dataset. + seq_length (int): Sequence length for text tokens. + seed (int): Random seed for reproducibility. + """ + + def __init__( + self, + tokenizer, + image_processor, + name: str, + num_samples: int, + seq_length: int, + seed: int = 42, + ) -> None: + """ + Initializes the mock dataset with synthetic multimodal data. + + Args: + tokenizer: Tokenizer for text inputs. + image_processor: Processor for image inputs. + name (str): Dataset name ('train', 'valid', 'test'). + num_samples (int): Total number of samples in the dataset. + seq_length (int): Sequence length for text tokens. + seed (int): Random seed for data generation. + """ + super().__init__() + self.name = name + self.seq_length = seq_length + + self.vocab_size = tokenizer.vocab_size + + crop_size = image_processor.crop_size + self.image_height, self.image_width = crop_size["height"], crop_size["width"] + + self.length = num_samples + self.seed = seed + + self.loss_mask = torch.ones(self.seq_length, dtype=torch.float) + self.position_ids = torch.arange(self.seq_length, dtype=torch.int64) + self.tokenizer = tokenizer + self.image_processor = image_processor + + def __len__(self) -> int: + """ + Returns the length of the dataset. + + Returns: + int: Number of samples in the dataset. + """ + return self.length + + def _get_text(self, idx: int) -> np.ndarray: + """ + Generates synthetic text data. + + Args: + idx (int): Index of the sample. + + Returns: + np.ndarray: Synthetic text token IDs. + """ + np_gen = np.random.default_rng(seed=(self.seed + idx)) + return np_gen.integers(self.vocab_size, size=[self.seq_length], dtype=np.int64) + + def __getitem__(self, idx) -> Dict[str, torch.Tensor]: + """ + Generates a synthetic multimodal sample. + + Args: + idx (int): Index of the sample. + + Returns: + Dict[str, torch.Tensor]: A dictionary containing synthetic tokens, images, and metadata. + """ + # Generate data of the expected size and datatype (based on GPTDataset). + np_gen = np.random.default_rng(seed=(self.seed + idx)) + tokens = torch.from_numpy(np_gen.integers(self.vocab_size, size=[self.seq_length + 1], dtype=np.int64)) + tokens[2] = IMAGE_TOKEN_INDEX # ImageToken token index + labels = tokens.clone() + images = torch.from_numpy(np_gen.random(size=[3, self.image_height, self.image_width], dtype=np.float32)) + tokens = tokens[:-1] + labels = labels[1:] + + # attention_mask, image_sizes, num_media_tiles required for llava-next. Neva model will ignore these + attention_mask = torch.ones(len(tokens), dtype=torch.long) + image_sizes = torch.tensor([[self.image_height, self.image_width]], dtype=torch.long) + image_array = self.image_processor.preprocess(images, return_tensors='pt', do_rescale=False)['pixel_values'][0] + num_media_tiles = image_array.shape[0] + return { + "media": image_array, + "tokens": tokens, + "labels": labels, + "loss_mask": self.loss_mask, + "position_ids": self.position_ids, + "image_sizes": image_sizes, + "num_media_tiles": num_media_tiles, + "attention_mask": attention_mask, + } + + def _collate_fn(self, batch): + """ + A default implementation of a collation function. + Users should override this method to define custom data loaders. + """ + collated_batch = data.dataloader.default_collate(batch) + + collated_batch['media'] = collated_batch['media'].contiguous().view(-1, *collated_batch['media'].shape[2:]) + collated_batch['image_sizes'] = ( + collated_batch['image_sizes'].contiguous().view(-1, *collated_batch['image_sizes'].shape[2:]) + ) + return collated_batch + + def collate_fn(self, batch): + """Method that user pass as functor to DataLoader. + + The method optionally performs neural type checking and add types to the outputs. + + Please note, subclasses of Dataset should not implement `input_types`. + + # Usage: + dataloader = torch.utils.data.DataLoader( + ...., + collate_fn=dataset.collate_fn, + .... + ) + + Returns + ------- + Collated batch, with or without types. + """ + return self._collate_fn(batch) diff --git a/nemo/collections/vlm/llava_next/model/__init__.py b/nemo/collections/vlm/llava_next/model/__init__.py new file mode 100644 index 000000000000..6d7b02482f62 --- /dev/null +++ b/nemo/collections/vlm/llava_next/model/__init__.py @@ -0,0 +1,24 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from nemo.collections.vlm.llava_next.model.base import LlavaNextConfig +from nemo.collections.vlm.llava_next.model.llava_next import LlavaNextConfig7B, LlavaNextConfig13B, LlavaNextModel + +__all__ = [ + "LlavaNextConfig", + "LlavaNextModel", + "LlavaNextConfig7B", + "LlavaNextConfig13B", +] diff --git a/nemo/collections/vlm/llava_next/model/base.py b/nemo/collections/vlm/llava_next/model/base.py new file mode 100644 index 000000000000..7968c720db0e --- /dev/null +++ b/nemo/collections/vlm/llava_next/model/base.py @@ -0,0 +1,373 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from dataclasses import dataclass, field +from typing import Callable, Dict, List, Optional + +import torch +import torch.distributed +from megatron.core import parallel_state as ps +from megatron.core.inference_params import InferenceParams +from megatron.core.tensor_parallel import gather_from_sequence_parallel_region + +from nemo.collections.llm.gpt.model.base import get_batch_on_this_context_parallel_rank, get_packed_seq_params +from nemo.collections.vlm.llava_next.model.utils import merge_input_ids_with_image_features, pack_image_features +from nemo.collections.vlm.neva.data.multimodal_tokens import IMAGE_TOKEN_INDEX + + +def llava_next_data_step(dataloader_iter) -> Dict[str, torch.Tensor]: + """ + Processes a batch of data from the dataloader for the LLaVA Next model. + + Args: + dataloader_iter (Iterator): An iterator that provides batches of data from the dataloader. + + Returns: + Dict[str, torch.Tensor]: A dictionary containing the processed batch, ready for input into the model. + + Notes: + - Filters and moves required keys to the appropriate device. + - Slices the batch along the sequence dimension for context parallelism. + """ + from megatron.core import parallel_state + + # Based on: https://github.com/NVIDIA/Megatron-LM/blob/main/pretrain_gpt.py#L87 + # https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/nlp/models/language_modeling/ + # megatron_gpt_model.py#L828-L842 + batch = next(dataloader_iter) + _batch: dict + if isinstance(batch, tuple) and len(batch) == 3: + _batch = batch[0] + else: + _batch = batch + + required_keys = set() + required_keys.update( + ( + "tokens", + "attention_mask", + "media", + "num_media_tiles", + "image_sizes", + ) + ) + if parallel_state.is_pipeline_first_stage(): + required_keys.update(("position_ids", "attention_mask")) + if parallel_state.is_pipeline_last_stage(): + required_keys.update(("labels", "loss_mask", "attention_mask")) + + _batch = { + key: val.cuda(non_blocking=True) if key in required_keys and val is not None else None + for key, val in _batch.items() + } + # slice batch along sequence dimension for context parallelism + output = get_batch_on_this_context_parallel_rank(_batch) + + return output + + +def llava_next_forward_step(model, batch) -> torch.Tensor: + """ + Performs the forward step for the LLaVA Next model. + + Args: + model (torch.nn.Module): The LLaVA Next model instance. + batch (Dict[str, torch.Tensor]): A dictionary containing input tensors for the forward step. + + Returns: + torch.Tensor: The output from the model's forward computation. + + Notes: + - Constructs the forward arguments based on the provided batch. + - Includes optional parameters like packed sequence parameters if available. + """ + forward_args = { + "media": batch["media"], + "input_ids": batch["tokens"], + "position_ids": batch["position_ids"], + "attention_mask": batch.get("attention_mask", None), + "loss_mask": batch.get("loss_mask", None), + "labels": batch.get("labels", None), + "image_sizes": batch.get("image_sizes", None), + "num_media_tiles": batch.get("num_media_tiles", None), + } + + if 'cu_seqlens' in batch: + forward_args['packed_seq_params'] = get_packed_seq_params(batch) + return model(**forward_args) + + +from nemo.collections.vlm.neva.model.base import MCoreNevaModel, NevaConfig + + +@dataclass +class LlavaNextConfig(NevaConfig): + """ + Configuration class for the LLaVA Next model. + Overrides NevaConfig and modifies forward and data step fn. + + """ + + forward_step_fn: Callable = field(default=llava_next_forward_step) + data_step_fn: Callable = field(default=llava_next_data_step) + + def configure_model(self, tokenizer) -> "MCoreLlavaNextModel": + """ + Configures the LLaVA Next model with the appropriate settings. + + Args: + tokenizer: Tokenizer instance to be used with the model. + + Returns: + MCoreLlavaNextModel: An instance of the LLaVA Next model. + """ + + self.language_transformer_config.tensor_model_parallel_size = self.tensor_model_parallel_size + self.language_transformer_config.sequence_parallel = self.sequence_parallel + self.vision_transformer_config.tensor_model_parallel_size = self.tensor_model_parallel_size + self.vision_projection_config.tensor_model_parallel_size = self.tensor_model_parallel_size + self.language_transformer_config.pipeline_model_parallel_size = self.pipeline_model_parallel_size + + if self.encoder_pipeline_model_parallel_size > 0: + assert self.encoder_pipeline_model_parallel_size == 1, "ViT can only live on 1 pipeline stage." + self.vision_transformer_config.pipeline_model_parallel_size = self.encoder_pipeline_model_parallel_size + self.vision_projection_config.pipeline_model_parallel_size = self.encoder_pipeline_model_parallel_size + self.language_transformer_config.encoder_pipeline_model_parallel_size = ( + self.encoder_pipeline_model_parallel_size + ) + if self.encoder_tensor_model_parallel_size > 0: + self.vision_transformer_config.tensor_model_parallel_size = self.encoder_tensor_model_parallel_size + self.vision_projection_config.tensor_model_parallel_size = self.encoder_tensor_model_parallel_size + + model = MCoreLlavaNextModel( + config=self, + tokenizer=tokenizer, + pre_process=ps.is_pipeline_first_stage() + or ps.get_pipeline_model_parallel_rank() == self.encoder_pipeline_model_parallel_size, + post_process=ps.is_pipeline_last_stage(), + add_encoder=ps.is_pipeline_first_stage(), + add_decoder=ps.is_pipeline_last_stage() + or ps.get_pipeline_model_parallel_rank() >= self.encoder_pipeline_model_parallel_size, + drop_vision_class_token=self.drop_vision_class_token, + ) + + return model + + +class MCoreLlavaNextModel(MCoreNevaModel): + """ + The LLaVA Next model class, extending MCoreNevaModel. + + Attributes: + image_newline (torch.nn.Parameter): A learnable parameter for handling image newlines. + """ + + def __init__( + self, + config: LlavaNextConfig, + tokenizer=None, + pre_process: bool = True, + post_process: bool = True, + add_encoder: bool = True, + add_decoder: bool = True, + drop_vision_class_token: bool = False, + ) -> None: + """ + Initializes the LLaVA Next model. + Calls the super class init and initialize image_newline parameter + + Args: + config (LlavaNextConfig): Model configuration instance. + tokenizer: Optional tokenizer instance. + pre_process (bool): Whether to enable preprocessing. + post_process (bool): Whether to enable postprocessing. + add_encoder (bool): Whether to add the encoder module. + add_decoder (bool): Whether to add the decoder module. + drop_vision_class_token (bool): Whether to drop the vision class token. + """ + super().__init__( + config=config, + tokenizer=tokenizer, + pre_process=pre_process, + post_process=post_process, + add_encoder=add_encoder, + add_decoder=add_decoder, + drop_vision_class_token=drop_vision_class_token, + ) + # extra image_newline learnable parameter for llava_next + embed_std = 1 / math.sqrt(config.vision_projection_config.hidden_size) + self.image_newline = torch.nn.Parameter(torch.randn(config.vision_projection_config.hidden_size) * embed_std) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + image_sizes: torch.Tensor, + loss_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + media: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + inference_params: Optional[InferenceParams] = None, + num_media_tiles: Optional[List[int]] = None, + media_token_index: Optional[int] = IMAGE_TOKEN_INDEX, + runtime_gather_output: Optional[bool] = None, + ) -> torch.Tensor: + """Forward function of the LLaVA Next model. + + Args: + images (torch.Tensor): input image of shape [num_tiles, img_h, img_w]. + num_tiles means the number of image tiles in this batch. + input_ids (torch.Tensor): input text ids [batch, text_seq_len]. + position_ids (torch.Tensor): input text position ids [batch, text_seq_len]. + image_sizes (torch.Tensor): Raw image sizes before tiling (N,2). + attention_mask (torch.Tensor): Attention mask for the language model [batch, text seq length]. + labels (torch.Tensor): Optional target text labels [batch, combined_seq_len]. + loss_mask (torch.Tensor): Text loss mask [batch, text_seq_len]. + inference_params (InferenceParams): Inference-time parameters including KV cache. + num_media_tiles (list of int): Number of tiles per image. Default None assumes 1 tile per image. + image_token_index (int): ID for input images. + + Returns: + output (torch.Tensor): Loss ([b, s]) if labels are provided; logits ([b, s, vocab_size]) otherwise. + loss_mask (torch.Tensor): Loss mask expanded to combined sequence length. Shape [b, s]. + """ + use_inference_kv_cache = ( + inference_params is not None and "image_tokens_count" in inference_params.key_value_memory_dict + ) + has_images = media.shape[0] > 0 + + # If running inference, we can skip media token computation + # if they were computed already earlier for this sample. + if use_inference_kv_cache: + media_embeddings = None + elif self.add_encoder and not has_images: + # If no images provided, use an empty image embeddings tensor. + media_embeddings = torch.tensor([], dtype=media.dtype, device=media.device).reshape(0, 0, 0) + elif self.add_encoder and has_images: + # media is in shape of (num_images_in_mbs, c, h, w) + # note num_images_in_mbs is not mbs but total images in this mbs. + if self.vision_model_from_hf: + self.vision_model = self.vision_model.eval() + media_embeddings = self.vision_model(media, output_hidden_states=True) + media_embeddings = media_embeddings[-1][ + self.config.vision_feature_layer + ] # [num_images, img_seq_len, h_vision] + else: + # TODO(yuya): MCore Clip path not yet support taking a specific layer hidden states + media = media.to(next(self.vision_model.parameters()).dtype) + media_embeddings = self.vision_model(media, num_unused_layers=-self.config.vision_feature_layer - 1) + if self._drop_vision_class_token: + class_token_len = getattr(self.vision_model, "class_token_len", 1) + media_embeddings = media_embeddings[:, class_token_len:, :] + + # contiguous() required as `permute` can sparsify the tensor and this breaks pipelining + media_embeddings = media_embeddings.contiguous() + # map vision model output size to language model input size. + media_embeddings = self.vision_projection(media_embeddings) # [img_seq_len, num_tiles, h_language] + # TODO: Support batched inference. + # In inference, the language model KV cache will be updated for image token positions. + # Store the image tokens sequence length to be used as an offset to the KV cache later. + if inference_params is not None: + inference_params.key_value_memory_dict["media_tokens_count"] = ( + media_embeddings.shape[0] * media_embeddings.shape[1] + ) + else: + media_embeddings = self.encoder_hidden_state + + if not self.add_decoder: + return media_embeddings + + language_embeddings = None + if self.pre_process: + input_ids_text = input_ids.clone() + # MultiModal Token indices are assumed to be values + input_ids_text[input_ids_text < 0] = 0 + # Note: This adds absolute position embedding but not RoPE. + # Each image is counted as one position. + # RoPE is added in language_model forward. Each image embedding is one position. + if self.sequence_parallel_lm: + # Pad to nearest multiple of TP world size for embedding. + tp_world_size = ps.get_tensor_model_parallel_world_size() + padded_seq_len = ( + int((input_ids_text.shape[1] + tp_world_size - 1) // tp_world_size * tp_world_size) + - input_ids_text.shape[1] + ) + if padded_seq_len != 0: + input_ids_text = torch.nn.functional.pad(input_ids_text, (0, padded_seq_len)) + if position_ids is not None: + position_ids = torch.nn.functional.pad(position_ids, (0, padded_seq_len)) + language_embeddings = self.language_model.embedding( + input_ids=input_ids_text, position_ids=position_ids + ) # [text_seq_len, b, h_language] + if self.sequence_parallel_lm: + # Gather the language embeddings back. + # We use the full embedding to insert image embeddings + # and then scatter to avoid load imbalance. + language_embeddings = gather_from_sequence_parallel_region( + language_embeddings, tensor_parallel_output_grad=False + ) + # Remove the padding done for SP as we'll need new padding calculation + # after image embeddings are inserted. + if padded_seq_len != 0: + language_embeddings = language_embeddings[:-padded_seq_len] + language_embeddings = language_embeddings.transpose(1, 0).contiguous() # [b, text_seq_len, h_language] + + # Assume 1 tile per image if the number of tiles is not provided. + if num_media_tiles is None: + num_media_tiles = torch.ones(media.shape[0], dtype=torch.int, device=input_ids.device) + elif isinstance(num_media_tiles, list): + num_media_tiles = torch.tensor(num_media_tiles, dtype=torch.int, device=input_ids.device) + + media_embeddings = torch.split(media_embeddings, num_media_tiles.tolist(), dim=0) + media_embeddings, feature_lens = pack_image_features( + media_embeddings, + image_sizes, + vision_feature_select_strategy='default', + image_newline=self.image_newline, + ) + + combined_embeddings, attention_mask, position_ids, final_labels, final_input_ids, final_loss_mask = ( + merge_input_ids_with_image_features( + media_embeddings, + feature_lens, + language_embeddings, + input_ids, + attention_mask, + position_ids, + labels=labels, + image_token_index=media_token_index, + ) + ) + combined_embeddings = combined_embeddings.permute(1, 0, 2) + combined_embeddings = combined_embeddings.contiguous() + output = self.language_model( + input_ids=None, + position_ids=None, + attention_mask=attention_mask, + decoder_input=combined_embeddings, + labels=final_labels, + inference_params=inference_params, + runtime_gather_output=runtime_gather_output, + ) + + if labels is None or loss_mask is None: + return output + + return output, final_loss_mask.contiguous() + + +__all__ = [ + "LlavaNextConfig", +] diff --git a/nemo/collections/vlm/llava_next/model/llava_next.py b/nemo/collections/vlm/llava_next/model/llava_next.py new file mode 100644 index 000000000000..fac5d5dd0871 --- /dev/null +++ b/nemo/collections/vlm/llava_next/model/llava_next.py @@ -0,0 +1,267 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from pathlib import Path +from typing import Callable, List, Optional, Union + +import torch +import torch.distributed +from megatron.core.inference_params import InferenceParams +from megatron.core.optimizer import OptimizerConfig +from megatron.core.transformer.transformer_config import TransformerConfig +from transformers import LlavaNextForConditionalGeneration + +from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec +from nemo.collections.llm import Llama2Config7B, Llama2Config13B, LlamaConfig +from nemo.collections.vlm.llava_next.model.base import LlavaNextConfig, MCoreLlavaNextModel +from nemo.collections.vlm.neva.model.base import HFCLIPVisionConfig, MultimodalProjectorConfig, NevaModel +from nemo.collections.vlm.neva.model.llava import HFLlavaImporter +from nemo.lightning import OptimizerModule, io, teardown +from nemo.lightning.pytorch.optim import MegatronOptimizerModule, OptimizerModule + + +@dataclass +class LlavaNextConfig7B(LlavaNextConfig): + """ + Configuration class for the 7B parameter variant of the LLaVA 16 model. + + Inherits all attributes and methods from Llava15Config7B without modification. + """ + + from transformers import PretrainedConfig + + language_transformer_config: TransformerConfig = field(default_factory=lambda: Llama2Config7B()) + vision_transformer_config: Union[TransformerConfig, PretrainedConfig] = field( + default_factory=lambda: HFCLIPVisionConfig(pretrained_model_name_or_path="openai/clip-vit-large-patch14-336") + ) + vision_projection_config: TransformerConfig = field( + default_factory=lambda: MultimodalProjectorConfig(input_size=1024, hidden_size=4096, ffn_hidden_size=4096) + ) + + +@dataclass +class LlavaNextConfig13B(LlavaNextConfig): + """ + Configuration class for the 13B parameter variant of the LLaVA 16 model. + + Inherits all attributes and methods from Llava15Config13B without modification. + """ + + from transformers import PretrainedConfig + + language_transformer_config: TransformerConfig = field(default_factory=lambda: Llama2Config13B()) + vision_transformer_config: Union[TransformerConfig, PretrainedConfig] = field( + default_factory=lambda: HFCLIPVisionConfig(pretrained_model_name_or_path="openai/clip-vit-large-patch14-336") + ) + vision_projection_config: TransformerConfig = field( + default_factory=lambda: MultimodalProjectorConfig(input_size=1024, hidden_size=5120, ffn_hidden_size=5120) + ) + + +class LlavaNextModel(NevaModel): + """ + The LLaVA Next model class, extending NevaModel. + + Attributes: + config (LlavaNextConfig): Configuration object for the model. + optim (Optional[OptimizerModule]): Optimizer module. Defaults to a Megatron optimizer. + tokenizer (Optional[TokenizerSpec]): Tokenizer specification for processing text inputs. + model_transform (Optional[Callable[[torch.nn.Module], torch.nn.Module]]): + Optional transformation applied to the model after initialization. + """ + + def __init__( + self, + config: LlavaNextConfig, + optim: Optional[OptimizerModule] = None, + tokenizer: Optional["TokenizerSpec"] = None, + model_transform: Optional[Callable[[torch.nn.Module], torch.nn.Module]] = None, + ): + """ + Initializes the LlavaNextModel. + + Args: + config (LlavaNextConfig): Configuration object for the model. + optim (Optional[OptimizerModule]): optimizer module. Defaults to Megatron optimizer. + tokenizer (Optional[TokenizerSpec]): Optional tokenizer specification for processing text inputs. + model_transform (Optional[Callable[[torch.nn.Module], torch.nn.Module]]): + Optional transformation function applied to the model after initialization. + """ + super().__init__( + config=config, + optim=optim or MegatronOptimizerModule(config=OptimizerConfig(lr=1e-4, use_distributed_optimizer=True)), + tokenizer=tokenizer, + model_transform=model_transform, + ) + + def configure_model(self) -> MCoreLlavaNextModel: + """ + Configures the underlying model instance if it has not been initialized. + + Returns: + MCoreLlavaNextModel: The configured model instance. + """ + if not hasattr(self, "module"): + self.module = self.config.configure_model(self.tokenizer) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + image_sizes: torch.Tensor, + loss_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + media: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + inference_params: InferenceParams = None, + num_media_tiles: Optional[List[int]] = None, + ) -> torch.Tensor: + """ + Performs the forward pass of the LLaVA Next model. + + Args: + input_ids (torch.Tensor): Input token IDs of shape [batch, text_seq_len]. + position_ids (torch.Tensor): Position IDs of shape [batch, text_seq_len]. + image_sizes (torch.Tensor): Raw image sizes before tiling, of shape [batch, 2]. + loss_mask (Optional[torch.Tensor]): Text loss mask of shape [batch, text_seq_len]. + attention_mask (Optional[torch.Tensor]): Attention mask shape [batch, text_seq_len]. + media (Optional[torch.Tensor]): Input media tensor. + labels (Optional[torch.Tensor]): Target labels of shape [batch, combined_seq_len]. + inference_params (InferenceParams): Inference-time parameters. + num_media_tiles (Optional[List[int]]): Number of tiles per image. Default assumes 1 tile per image. + + Returns: + torch.Tensor: The model output. Shape depends on whether labels are provided. + - If `labels` is provided: Loss tensor of shape [batch, seq_len]. + - If `labels` is not provided: Logits tensor of shape [batch, seq_len, vocab_size]. + """ + output_tensor = self.module( + media=media, + input_ids=input_ids, + position_ids=position_ids, + image_sizes=image_sizes, + loss_mask=loss_mask, + attention_mask=attention_mask, + labels=labels, + inference_params=inference_params, + num_media_tiles=num_media_tiles, + ) + + return output_tensor + + +@io.model_importer(LlavaNextModel, "hf") +class HFLlavaNextImporter( + HFLlavaImporter, + io.ModelConnector["LlavaNextForConditionalGeneration", LlavaNextModel], +): + """ + Importer class for converting HuggingFace LLaVA Next checkpoint to NeMo format. + + Inherits: + HFLlavaImporter: Base class for HuggingFace LLaVA model importers. + io.ModelConnector: Connector interface to handle setup, save, and load using the Lightning framework. + + Methods: + init: Initializes a new LlavaNextModel instance. + apply: Converts the HuggingFace model to NeMo format and saves it. + config: Generates and returns the LlavaNextConfig for the model. + """ + + def init(self) -> LlavaNextModel: + """ + Initializes the LlavaNextModel. + + Returns: + LlavaNextModel: An instance of the LLaVA Next model initialized with the configuration. + """ + return LlavaNextModel(self.config, tokenizer=self.tokenizer) + + def apply(self, output_path: Path) -> Path: + """ + Converts the HuggingFace LLaVA Next model to NeMo format and saves it to the specified path. + + Args: + output_path (Path): The path where the converted NeMo model will be saved. + + Returns: + Path: The output path where the NeMo model was saved. + """ + + source = LlavaNextForConditionalGeneration.from_pretrained(str(self)) + target = self.init() + trainer = self.nemo_setup(target) + self.convert_state(source, target, image_newline=True) + print(f"Converted Llava next model to Nemo, saving to {output_path}") + + self.nemo_save(output_path, trainer) + + print(f"Converted Llava next model saved to {output_path}") + + teardown(trainer, target) + del trainer, target + + return output_path + + @property + def config(self) -> LlavaNextConfig: + """ + Generates the configuration for the LLaVA Next model based on the HuggingFace model. + + Returns: + LlavaNextConfig: A configuration object for the LLaVA Next model. + """ + from transformers import LlavaConfig as HFLlavaConfig + + source = HFLlavaConfig.from_pretrained(str(self)) + text_conifg = source.text_config + + def make_vocab_size_divisible_by(vocab_size): + base = 128 + while vocab_size % base != 0: + base //= 2 + return base + + language_transformer_config = LlamaConfig( + num_layers=text_conifg.num_hidden_layers, + hidden_size=text_conifg.hidden_size, + ffn_hidden_size=text_conifg.intermediate_size, + num_attention_heads=text_conifg.num_attention_heads, + init_method_std=text_conifg.initializer_range, + layernorm_epsilon=text_conifg.rms_norm_eps, + num_query_groups=text_conifg.num_key_value_heads, + rotary_base=text_conifg.rope_theta, + gated_linear_unit=True, + make_vocab_size_divisible_by=make_vocab_size_divisible_by(text_conifg.vocab_size), + share_embeddings_and_output_weights=False, + ) + vision_transformer_config = HFCLIPVisionConfig( + pretrained_model_name_or_path="openai/clip-vit-large-patch14-336" + ) + vision_projection_config = MultimodalProjectorConfig(input_size=1024, hidden_size=4096, ffn_hidden_size=4096) + + output = LlavaNextConfig( + language_transformer_config=language_transformer_config, + vision_transformer_config=vision_transformer_config, + vision_projection_config=vision_projection_config, + vision_feature_layer=source.vision_feature_layer, + ) + + return output + + +__all__ = [ + "LlavaNextModel", +] diff --git a/nemo/collections/vlm/llava_next/model/utils.py b/nemo/collections/vlm/llava_next/model/utils.py new file mode 100644 index 000000000000..2996bc277983 --- /dev/null +++ b/nemo/collections/vlm/llava_next/model/utils.py @@ -0,0 +1,436 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +# 'These functions implementation is adapted from +# https://github.com/huggingface/transformers/blob/ +# 53fad641cfdb5105e2470bcf3ef17ea8e25cc300/src/transformers/models/llava_next/modeling_llava_next.py' + + +def get_image_sequence_length(img_h, img_w, patch_dim, add_class_token, class_token_len): + """Get image sequence length given image size, patch size, and class token.""" + num_patches_per_dim_h = img_h // patch_dim + num_patches_per_dim_w = img_w // patch_dim + num_patches = num_patches_per_dim_h * num_patches_per_dim_w + return num_patches + (class_token_len if add_class_token else 0) + + +def merge_input_ids_with_image_features( + image_features, + feature_lens, + inputs_embeds, + input_ids, + attention_mask, + position_ids=None, + labels=None, + image_token_index=-200, + ignore_index=-100, +): + """ + Merge input_ids with with image features into final embeddings + Args: + image_features (`torch.Tensor` of shape `(all_feature_lens, embed_dim)`): + All vision vectors of all images in the batch + feature_lens (`torch.LongTensor` of shape `(num_images)`): + The length of visual embeddings of each image as stacked in `image_features` + inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, embed_dim)`): + Token embeddings before merging with visual embeddings + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Input_ids of tokens, possibly filled with image token + attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Mask to avoid performing attention on padding token indices. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*) + :abels need to be recalculated to support training (if provided) + image_token_index (`int`, *optional*) + Token id used to indicate the special "image" token. Defaults to `config.image_token_index` + ignore_index (`int`, *optional*) + Value that is used to pad `labels` and will be ignored when calculated loss. Default: -100. + Returns: + final_embedding, final_attention_mask, position_ids, final_labels + Explanation: + each image has variable length embeddings, with length specified by feature_lens + image_features is concatenation of all visual embed vectors + task: fill each with the correct number of visual embeddings + Example: + X (5 patches), Y (3 patches), Z (8) + X, Y are in the same sequence (in-context learning) + if right padding + input_ids: [ + a b c d e f X g h i j k Y l m + o p q r Z s t u v _ _ _ _ _ _ + ] + input_ids should be: [ + a b c d e f X X X X X g h i j k Y Y Y l m + o p q r Z Z Z Z Z Z Z Z s t u v _ _ _ _ _ + ] + labels should be: [ + a b c d e f _ _ _ _ _ g h i j k _ _ _ l m + o p q r _ _ _ _ _ _ _ _ s t u v _ _ _ _ _ + ] + elif left padding + input_ids: [ + a b c d e f X g h i j k Y l m + _ _ _ _ _ _ o p q r Z s t u v + ] + input_ids should be: [ + a b c d e f X X X X X g h i j k Y Y Y l m + _ _ _ _ _ o p q r Z Z Z Z Z Z Z Z s t u v + ] + labels should be: [ + a b c d e f _ _ _ _ _ g h i j k _ _ _ l m + _ _ _ _ _ o p q r _ _ _ _ _ _ _ _ s t u v + ] + Edge cases: + * If tokens are same but image token sizes are different, then cannot infer left or right padding + ```python + cat_img = Image.open(requests.get( + "http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw) + chart_img = Image.open(requests.get( + "https://github.com/haotian-liu/LLaVA/blob/" + "1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true" + , stream=True).raw) + prompts = [ + "[INST] \nWhat is shown in this image? [/INST]", + "[INST] \nWhat is shown in this image? [/INST]", + ] + inputs = processor(prompts, [chart_img, cat_img], return_tensors='pt', padding=True).to("cuda") + chart_img has 2634 tokens, while cat_img has 2340 tokens + ``` + input_ids: [ + a b c d X g h + i j Y k l m n + ] + where X is 3 tokens while Y is 5, this mean after merge + if left-padding (batched generation) + input_ids should be: [ + _ _ a b c d X X X g h + i j Y Y Y Y Y k l m n + ] + elif (right padding) (training) + input_ids should be: [ + a b c d X X X g h _ _ + i j Y Y Y Y Y k l m n + ] + """ + + padding_side = 'right' + pad_token_id = 0 + with torch.no_grad(): + # ! in llava 1.6, number of patches is variable + num_images = feature_lens.size(0) + num_image_features, embed_dim = image_features.shape + if feature_lens.sum() != num_image_features: + raise ValueError(f"{feature_lens=} / {feature_lens.sum()} != {image_features.shape=}") + batch_size = input_ids.shape[0] + _left_padding = torch.any(attention_mask[:, 0] == 0) + _right_padding = torch.any(attention_mask[:, -1] == 0) + + left_padding = padding_side == "left" + if batch_size > 1: + if _left_padding and _right_padding: + raise ValueError(f"both side of attention_mask has zero, invalid. {attention_mask}") + elif _right_padding and left_padding: + left_padding = False + elif _left_padding and not left_padding: + left_padding = True + # Whether to turn off right padding + # 1. Create a mask to know where special image tokens are + special_image_token_mask = input_ids == image_token_index + # special_image_token_mask: [bsz, seqlen] + num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1) + # num_special_image_tokens: [bsz] + # Reserve for padding of num_images + total_num_special_image_tokens = torch.sum(special_image_token_mask) + if total_num_special_image_tokens != num_images: + raise ValueError( + f"Number of image tokens in input_ids ({total_num_special_image_tokens}) " + f"different from num_images ({num_images})." + ) + # Compute the maximum embed dimension + # max_image_feature_lens is max_feature_lens per batch + feature_lens = feature_lens.to(input_ids.device) + feature_lens_batch = feature_lens.split(num_special_image_tokens.tolist(), dim=0) + feature_lens_batch_sum = torch.tensor([x.sum() for x in feature_lens_batch], device=input_ids.device) + embed_sequence_lengths = ( + (attention_mask == 1).long().sum(-1) - num_special_image_tokens + feature_lens_batch_sum + ) + max_embed_dim = embed_sequence_lengths.max() + batch_indices, non_image_indices = torch.where((input_ids != image_token_index) & (attention_mask == 1)) + # batch_indices, non_image_indices = torch.where((input_ids != image_token_index) ) + # 2. Compute the positions where text should be written + # Calculate new positions for text tokens in merged image-text sequence. + # `special_image_token_mask` identifies image tokens. + # Each image token will be replaced by `nb_text_tokens_per_images` text tokens. + # `torch.cumsum` computes how each image token shifts subsequent text token positions. + # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one. + # ! instead of special_image_token_mask * (num_image_patches - 1) + # special_image_token_mask * (num_feature_len - 1) + special_image_token_mask = special_image_token_mask.long() + special_image_token_mask[special_image_token_mask == 1] = feature_lens - 1 + new_token_positions = torch.cumsum((special_image_token_mask + 1), -1) - 1 + if left_padding: + # shift right token positions so that they are ending at the same number + # the below here was incorrect? + # new_token_positions += new_token_positions[:, -1].max() - new_token_positions[:, -1:] + new_token_positions += max_embed_dim - 1 - new_token_positions[:, -1:] + + text_to_overwrite = new_token_positions[batch_indices, non_image_indices] + # 3. Create the full embedding, already padded to the maximum position + final_embedding = torch.zeros( + batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device + ) + final_attention_mask = torch.zeros( + batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device + ) + final_input_ids = torch.full( + (batch_size, max_embed_dim), pad_token_id, dtype=input_ids.dtype, device=inputs_embeds.device + ) + # In case the Vision model or the Language model has been offloaded to CPU, we need to manually + # set the corresponding tensors into their correct target device. + target_device = inputs_embeds.device + batch_indices, non_image_indices, text_to_overwrite = ( + batch_indices.to(target_device), + non_image_indices.to(target_device), + text_to_overwrite.to(target_device), + ) + attention_mask = attention_mask.to(target_device) + input_ids = input_ids.to(target_device) + + # 4. Fill the embeddings based on the mask. If we have ["hey" "", "how", "are"] + # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features + final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices] + final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices] + final_input_ids[batch_indices, text_to_overwrite] = input_ids[batch_indices, non_image_indices] + final_labels = None + if labels is not None: + labels = labels.to(target_device) + final_labels = torch.full_like(final_attention_mask, ignore_index).to(torch.long) + final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices] + + # 5. Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835) + with torch.no_grad(): + image_to_overwrite = torch.full( + (batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device + ) + image_to_overwrite[batch_indices, text_to_overwrite] = False + embed_indices = torch.arange(max_embed_dim).unsqueeze(0).to(target_device) + embed_indices = embed_indices.expand(batch_size, max_embed_dim) + embed_seq_lens = embed_sequence_lengths[:, None].to(target_device) + + if left_padding: + # exclude padding on the left + max_embed_dim = max_embed_dim.to(target_device) + val = (max_embed_dim - embed_indices) <= embed_seq_lens + else: + # exclude padding on the right + val = embed_indices < embed_seq_lens + image_to_overwrite &= val + + if image_to_overwrite.sum() != num_image_features: + raise ValueError( + f"{image_to_overwrite.sum()=} != {num_image_features=} The input provided to the model are wrong. " + f"The number of image tokens is {torch.sum(special_image_token_mask)} while" + f" the number of image given to the model is {num_images}. " + f"This prevents correct indexing and breaks batch generation." + ) + final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device) + final_attention_mask |= image_to_overwrite + position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1) + final_loss_mask = None + if final_labels is not None: + final_loss_mask = (final_labels != ignore_index).long() + + return final_embedding, final_attention_mask, position_ids, final_labels, final_input_ids, final_loss_mask + + +def unpad_image(tensor, original_size): + """ + Unpads a PyTorch tensor of a padded and resized image. + Args: + tensor (`torch.Tensor`): + The image tensor, assumed to be of shape (num_channels, height, width). + original_size (`tuple`): + The original size of the image (height, width). + Returns: + `torch.Tensor`: The unpadded image tensor. + """ + import numpy as np + + if not isinstance(original_size, (list, tuple)): + if not isinstance(original_size, (torch.Tensor, np.ndarray)): + raise TypeError( + f"image_size invalid type: {type(original_size)} not valid ", + "should be either list, tuple, np.ndarray or tensor", + ) + original_size = original_size.tolist() + original_height, original_width = original_size + current_height, current_width = tensor.shape[1:] + + original_aspect_ratio = original_width / original_height + current_aspect_ratio = current_width / current_height + + if original_aspect_ratio > current_aspect_ratio: + scale_factor = current_width / original_width + new_height = int(original_height * scale_factor) + padding = (current_height - new_height) // 2 + unpadded_tensor = tensor[:, padding : current_height - padding, :] + else: + scale_factor = current_height / original_height + new_width = int(original_width * scale_factor) + padding = (current_width - new_width) // 2 + unpadded_tensor = tensor[:, :, padding : current_width - padding] + + return unpadded_tensor + + +def select_best_resolution(original_size: tuple, possible_resolutions: list) -> tuple: + """ + Selects the best resolution from a list of possible resolutions based on the original size. + This is done by calculating the effective and wasted resolution for each possible resolution. + The best fit resolution is the one that maximizes the effective resolution and minimizes the wasted resolution. + Args: + original_size (tuple): + The original size of the image in the format (height, width). + possible_resolutions (list): + A list of possible resolutions in the format [(height1, width1), (height2, width2), ...]. + Returns: + tuple: The best fit resolution in the format (height, width). + """ + original_height, original_width = original_size + best_fit = None + max_effective_resolution = 0 + min_wasted_resolution = float("inf") + + for height, width in possible_resolutions: + scale = min(width / original_width, height / original_height) + downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale) + effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height) + wasted_resolution = (width * height) - effective_resolution + + if effective_resolution > max_effective_resolution or ( + effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution + ): + max_effective_resolution = effective_resolution + min_wasted_resolution = wasted_resolution + best_fit = (height, width) + + return best_fit + + +def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): + """ + Calculate the shape of the image patch grid after the preprocessing for images of any resolution. + Args: + image_size (`tuple`): + The size of the input image in the format (width, height). + grid_pinpoints (`List`): + A list containing possible resolutions. Each item in the list should be a tuple or list + of the form `(height, width)`. + patch_size (`int`): + The size of each image patch. + Returns: + tuple: The shape of the image patch grid in the format (width, height). + """ + import numpy as np + + if not isinstance(grid_pinpoints, list): + raise TypeError("grid_pinpoints should be a list of tuples or lists") + + # ! VERY IMPORTANT if image_size is tensor, must convert to into tuple, otherwise it will cause wrong calculate + if not isinstance(image_size, (list, tuple)): + if not isinstance(image_size, (torch.Tensor, np.ndarray)): + raise TypeError( + f"image_size invalid type: {type(image_size)} not valid, " + "should be either list, tuple, np.ndarray or tensor" + ) + image_size = image_size.tolist() + + height, width = select_best_resolution(image_size, grid_pinpoints) + return height // patch_size, width // patch_size + + +# These functions implementation is adapted from +# https://github.com/huggingface/transformers/blob/ +# 53fad641cfdb5105e2470bcf3ef17ea8e25cc300/src/transformers/models/llava_next/modeling_llava_next.py#L655' + + +def pack_image_features(image_features, image_sizes, vision_feature_select_strategy, image_newline=None): + """ + Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors. + Args: + image_features (`List[torch.Tensor]` of length num_images, + each of shape `(num_patches, image_length, embed_dim)`) + List of image feature tensor, each contains all the visual feature of all patches. + image_sizes (`torch.Tensor` of shape `(num_images, 2)`) + Actual image size of each images (H, W). + vision_feature_select_strategy (`str`) + The feature selection strategy used to select the vision feature from the vision backbone. + image_newline (`torch.Tensor` of shape `(embed_dim)`) + New line embedding vector. + Returns: + image_features (`torch.Tensor` of shape `(all_feat_len, embed_dim)`) + feature_lens (`List[int]`) + token length of each image in image_features + """ + from transformers import LlavaNextConfig + + config = LlavaNextConfig() + new_image_features = [] + feature_lens = [] + + for image_idx, image_feature in enumerate(image_features): + if image_feature.shape[0] > 1: + base_image_feature = image_feature[0] + image_feature = image_feature[1:] + height = width = config.vision_config.image_size // config.vision_config.patch_size + + if vision_feature_select_strategy == "default": + expected_num_patches = height * width + elif vision_feature_select_strategy == "full": + expected_num_patches = height * width + 1 + if expected_num_patches != base_image_feature.shape[0]: + raise ValueError("The number of patches is not consistent with the image size.") + + num_patch_height, num_patch_width = get_anyres_image_grid_shape( + image_sizes[image_idx], + config.image_grid_pinpoints, + config.vision_config.image_size, + ) + image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1) + image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() + image_feature = image_feature.flatten(1, 2).flatten(2, 3) + image_feature = unpad_image(image_feature, image_sizes[image_idx]) + if image_newline is not None: + image_feature = torch.cat( + ( + image_feature, + image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.dtype), + ), + dim=-1, + ) + image_feature = image_feature.flatten(1, 2).transpose(0, 1) + image_feature = torch.cat((base_image_feature, image_feature), dim=0) + else: + image_feature = image_feature[0] + if image_newline is not None: + image_feature = torch.cat((image_feature, image_newline[None].to(image_feature)), dim=0) + new_image_features.append(image_feature) + feature_lens.append(image_feature.size(0)) + image_features = torch.cat(new_image_features, dim=0) + feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features.device) + return image_features, feature_lens diff --git a/nemo/collections/vlm/mllama/data/lazy.py b/nemo/collections/vlm/mllama/data/lazy.py index 30b8b2ea9d9c..5069f8593377 100644 --- a/nemo/collections/vlm/mllama/data/lazy.py +++ b/nemo/collections/vlm/mllama/data/lazy.py @@ -18,10 +18,10 @@ import re from typing import Any, Dict, List, Optional, Sequence -import pytorch_lightning as pl +import lightning.pytorch as pl import torch import torch.nn.functional as F -from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS +from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS from torch.utils import data from torch.utils.data import DataLoader, default_collate diff --git a/nemo/collections/vlm/mllama/data/mock.py b/nemo/collections/vlm/mllama/data/mock.py index bb3afe83ea46..4d078c745492 100644 --- a/nemo/collections/vlm/mllama/data/mock.py +++ b/nemo/collections/vlm/mllama/data/mock.py @@ -14,10 +14,10 @@ from typing import Dict, List, Optional, Tuple +import lightning.pytorch as pl import numpy as np -import pytorch_lightning as pl import torch -from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS +from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS from torch.utils import data from torch.utils.data import DataLoader, Dataset @@ -25,6 +25,26 @@ class MockDataModule(pl.LightningDataModule): + """ + Mock DataModule for testing and development. + Generates synthetic data for training, validation, and testing purposes. + + Args: + seq_length (int): Sequence length for the generated data. + decoder_seq_length (Optional[int]): Decoder sequence length if applicable, used in pp. + vocab_size (int): Size of the vocabulary of tokenizer. + crop_size (Tuple[int, int]): Image crop size (height, width). + micro_batch_size (int): Micro batch size for data loading. + global_batch_size (int): Global batch size across all processes. + rampup_batch_size (Optional[List[int]]): Batch size ramp-up configuration. + num_train_samples (int): Number of training samples to generate. + num_val_samples (int): Number of validation samples to generate. + num_test_samples (int): Number of test samples to generate. + num_workers (int): Number of workers for data loading. + pin_memory (bool): Whether to pin memory for data loading. + persistent_workers (bool): Whether workers should remain persistent. + """ + def __init__( self, seq_length: int = 2048, @@ -34,6 +54,8 @@ def __init__( micro_batch_size: int = 4, global_batch_size: int = 8, rampup_batch_size: Optional[List[int]] = None, + tokenizer: Optional = None, + image_processor: Optional = None, num_train_samples: int = 10_000, num_val_samples: int = 10_000, num_test_samples: int = 10_000, @@ -52,6 +74,8 @@ def __init__( self.persistent_workers = persistent_workers self.vocab_size = vocab_size self.crop_size = crop_size + self.tokenizer = tokenizer + self.image_processor = image_processor self.data_sampler = MegatronDataSampler( seq_len=self.seq_length, @@ -62,6 +86,7 @@ def __init__( ) def setup(self, stage: str = "") -> None: + """Set up datasets for the specified stage.""" self._train_ds = _MockMLlamaDataset( self.vocab_size, self.crop_size, "train", self.num_train_samples, self.decoder_seq_length ) @@ -73,21 +98,25 @@ def setup(self, stage: str = "") -> None: ) def train_dataloader(self) -> TRAIN_DATALOADERS: + """Returns the DataLoader for training.""" if not hasattr(self, "_train_ds"): self.setup() return self._create_dataloader(self._train_ds) def val_dataloader(self) -> EVAL_DATALOADERS: + """Returns the DataLoader for validation.""" if not hasattr(self, "_validation_ds"): self.setup() return self._create_dataloader(self._validation_ds) def test_dataloader(self) -> EVAL_DATALOADERS: + """Returns the DataLoader for testing.""" if not hasattr(self, "_test_ds"): self.setup() return self._create_dataloader(self._test_ds) def _create_dataloader(self, dataset, **kwargs) -> DataLoader: + """Creates a DataLoader for the specified dataset.""" return DataLoader( dataset, num_workers=self.num_workers, @@ -99,6 +128,18 @@ def _create_dataloader(self, dataset, **kwargs) -> DataLoader: class _MockMLlamaDataset(Dataset): + """ + Mock dataset for generating synthetic data with text and image components. + + Args: + vocab_size (int): Vocabulary size for text data. + crop_size (Tuple[int, int]): Image crop size (height, width). + name (str): Name of the dataset split ('train', 'valid', 'test'). + num_samples (int): Number of samples in the dataset. + seq_length (int): Sequence length for the text data. + seed (int): Seed for random number generation. + """ + def __init__( self, vocab_size, @@ -123,13 +164,16 @@ def __init__( self.position_ids = torch.arange(self.seq_length, dtype=torch.int64) def __len__(self) -> int: + """Returns the number of samples in the dataset.""" return self.length def _get_text(self, idx: int) -> np.ndarray: + """Generates a random sequence of integers representing text tokens.""" np_gen = np.random.default_rng(seed=(self.seed + idx)) return np_gen.integers(self.vocab_size, size=[self.seq_length], dtype=np.int64) def __getitem__(self, idx) -> Dict[str, torch.Tensor]: + """Generates a single data sample.""" # Generate data of the expected size and datatype (based on GPTDataset). np_gen = np.random.default_rng(seed=(self.seed + idx)) tokens = torch.from_numpy(np_gen.integers(self.vocab_size, size=[self.seq_length + 1], dtype=np.int64)) @@ -142,8 +186,8 @@ def __getitem__(self, idx) -> Dict[str, torch.Tensor]: return { "images": images, - "masks": [[5, 512]], - "num_chunks": [4], + "masks": torch.tensor([[5, 512]]), + "num_chunks": torch.tensor([4]), "tokens": tokens, "aspect_ratio_ids": aspect_ratio_ids, "loss_mask": self.loss_mask, diff --git a/nemo/collections/vlm/mllama/model/base.py b/nemo/collections/vlm/mllama/model/base.py index f03af078987d..9279936e23d7 100644 --- a/nemo/collections/vlm/mllama/model/base.py +++ b/nemo/collections/vlm/mllama/model/base.py @@ -17,7 +17,7 @@ from dataclasses import dataclass from typing import Callable, Dict, List, Optional, Tuple -import pytorch_lightning as L +import lightning.pytorch as L import torch import torch.distributed from einops import rearrange @@ -40,13 +40,15 @@ from nemo.collections.vlm.mllama.model.language import CrossAttentionTextModel from nemo.collections.vlm.mllama.model.utils import _generate_cross_attention_mask, _pad_attention_masks from nemo.collections.vlm.mllama.model.vision import VisionEncoder +from nemo.collections.vlm.neva.model.base import MODEL_CONFIG_ATTR from nemo.lightning import get_vocab_size, io from nemo.lightning.megatron_parallel import MaskedTokenLossReduction from nemo.lightning.pytorch.optim import MegatronOptimizerModule, OptimizerModule from nemo.utils import logging -def llama_data_step(dataloader_iter) -> Dict[str, torch.Tensor]: +def mllama_data_step(dataloader_iter) -> Dict[str, torch.Tensor]: + """Mllama data step.""" from megatron.core import parallel_state # Based on: https://github.com/NVIDIA/Megatron-LM/blob/main/pretrain_gpt.py#L87 @@ -95,7 +97,8 @@ def llama_data_step(dataloader_iter) -> Dict[str, torch.Tensor]: return output -def llama_forward_step(model, batch) -> torch.Tensor: +def mllama_forward_step(model, batch) -> torch.Tensor: + """Mllama model forward step.""" forward_config = { "batch_images": batch["batch_images"], "batch_masks": batch["batch_masks"], @@ -113,13 +116,15 @@ def llama_forward_step(model, batch) -> torch.Tensor: def set_input_tensor(self, tensor): + """Placeholder for `set_input_tensor` method for PP implementation.""" pass @dataclass class CrossAttentionVisionConfig(TransformerConfig, io.IOMixin): - # core params + """Configuration for llama vision model.""" + # core params bias_activation_fusion: bool = True bias_dropout_add_fusion: bool = True @@ -149,9 +154,11 @@ class CrossAttentionVisionConfig(TransformerConfig, io.IOMixin): @property def max_aspect_ratio_id(self) -> int: + # pylint: disable=C0115,C0116 return len(self.supported_aspect_ratios) def configure_model(self) -> "CrossAttentionVisionModel": + """Configure mllama vision model.""" return CrossAttentionVisionModel( self, ) @@ -159,6 +166,10 @@ def configure_model(self) -> "CrossAttentionVisionModel": @dataclass class CrossAttentionTextConfig(Llama31Config): + """ + Configuration for llama model with cross-attention layers to take in multimodal features. + """ + rotary_base: int = 500_000 seq_length: int = 8192 num_layers: int = 32 @@ -170,12 +181,14 @@ class CrossAttentionTextConfig(Llama31Config): apply_rope_fusion: bool = False def _init_fusion_schedule(self, num_layers: int) -> List[int]: - llama_layers = list(range(self.num_layers)) + """Initialize self-attention layer / cross-attention layer fusion schedule""" + mllama_layers = list(range(self.num_layers)) # uniformly spread the layers - k = math.ceil(len(llama_layers) / num_layers) - return llama_layers[::-1][::k][:num_layers][::-1] + k = math.ceil(len(mllama_layers) / num_layers) + return mllama_layers[::-1][::k][:num_layers][::-1] def configure_model(self, tokenizer, pre_process=True, post_process=True): + """Configure mllama text model.""" self.fusion_schedule = self._init_fusion_schedule(self.num_cross_attention_layers) vp_size = self.virtual_pipeline_model_parallel_size if vp_size: @@ -224,6 +237,8 @@ def configure_model(self, tokenizer, pre_process=True, post_process=True): @dataclass class MLlamaModelConfig(TransformerConfig, io.IOMixin): + """Combined configuration for multimodal vision-language model.""" + language_model_config: Optional[CrossAttentionTextConfig] = None vision_model_config: Optional[CrossAttentionVisionConfig] = None @@ -236,42 +251,16 @@ class MLlamaModelConfig(TransformerConfig, io.IOMixin): language_model_from_pretrained: Optional[str] = None # TODO vision_model_from_pretrained: Optional[str] = None # TODO - forward_step_fn: Callable = llama_forward_step - data_step_fn: Callable = llama_data_step + forward_step_fn: Callable = mllama_forward_step + data_step_fn: Callable = mllama_data_step def __post_init__(self): - model_config_attr = [ - 'num_layers', - 'hidden_size', - 'num_attention_heads', - 'num_query_groups', - 'ffn_hidden_size', - 'kv_channels', - 'hidden_dropout', - 'attention_dropout', - 'fp32_residual_connection', - 'apply_residual_connection_post_layernorm', - 'layernorm_epsilon', - 'layernorm_zero_centered_gamma', - 'add_bias_linear', - 'add_qkv_bias', - 'gated_linear_unit', - 'activation_func', - 'activation_func_fp8_input_store', - 'num_moe_experts', - 'rotary_interleaved', - 'window_size', - 'normalization', - 'qk_layernorm', - 'test_mode', - 'calculate_per_token_loss', - ] - if self.language_model_config is not None: - for attr in model_config_attr: + for attr in MODEL_CONFIG_ATTR: setattr(self, attr, getattr(self.language_model_config, attr)) def configure_model(self, tokenizer) -> "MLlamaBaseModel": + """Configure mllama model.""" from megatron.core import parallel_state as ps self.language_model_config.tensor_model_parallel_size = self.tensor_model_parallel_size @@ -300,6 +289,8 @@ def configure_model(self, tokenizer) -> "MLlamaBaseModel": class CrossAttentionVisionModel(MegatronModule): + """Mllama vision model.""" + def __init__(self, config) -> None: super().__init__(config=config) return_intermediate = "3,7,15,23,30" @@ -329,6 +320,7 @@ def __init__(self, config) -> None: self.vision_projection.encoder.skip_bias_add = False # Temporary fix for a MCore side bug def forward(self, images: torch.Tensor, aspect_ratio_ids: torch.Tensor) -> torch.Tensor: + """Forward.""" # vision_tokens: (B, T, D) # aspect_ratio_ids: (B, 1) # h: (B, T, D) @@ -339,10 +331,13 @@ def forward(self, images: torch.Tensor, aspect_ratio_ids: torch.Tensor) -> torch return vision_tokens def set_input_tensor(self, tensor): + # pylint: disable=C0115,C0116 pass class MLlamaBaseModel(MegatronModule): + """Mllama base model combining vision and text models with cross-attention.""" + def __init__( self, config: MLlamaModelConfig, @@ -382,10 +377,6 @@ def __init__( self.patch_size = 14 self.image_res = vision_model_config.vision_chunk_size self.max_num_chunks = vision_model_config.vision_max_num_chunks - logging.warning("[WARNING] NeMo Mllama will always pad images to max number of tiles. A fix is coming soon!") - - def setup_cache(self, max_batch_size: int, dtype: torch.dtype): - self.language_model.setup_cache(max_batch_size, dtype) def compute_xattn_caches_masks( self, @@ -395,6 +386,7 @@ def compute_xattn_caches_masks( num_chunks: torch.Tensor, total_len: int, ) -> Tuple[List, torch.Tensor, torch.Tensor]: + """Compute xattn caches masks used in text model.""" bsz, nimg, nchunk, ntok, image_token_dim = vision_orig_shape xattn_caches = [ @@ -434,6 +426,7 @@ def forward( full_text_row_masked_out_mask: Optional[torch.Tensor] = None, xattn_caches: Optional[List] = None, ) -> torch.Tensor: + """Forward.""" if xattn_caches is None: bsz, max_num_images = batch_images.size(0), batch_images.size(1) vision_orig_shape = ( @@ -444,8 +437,8 @@ def forward( self.config.hidden_size, ) skip_vision_encoder = False - num_chunks[num_chunks > 0] = self.max_num_chunks if max_num_images == 0: + num_chunks[num_chunks > 0] = self.max_num_chunks skip_vision_encoder = True if self.encoder_hidden_state is not None: @@ -515,6 +508,8 @@ def set_input_tensor(self, input_tensor) -> None: class MLlamaModel(L.LightningModule, io.IOMixin, io.ConnectorMixin, fn.FNMixin): + """Lightning Module for the MLlama model.""" + def __init__( self, config: MLlamaModelConfig, @@ -532,6 +527,7 @@ def __init__( self._validation_loss_reduction = None def configure_model(self) -> None: + """Configure mllama model""" if not hasattr(self, "module"): self.module: MLlamaBaseModel = self.config.configure_model(self.tokenizer) @@ -548,7 +544,7 @@ def forward( full_text_row_masked_out_mask: Optional[torch.Tensor] = None, xattn_caches: Optional[torch.Tensor] = None, ) -> torch.Tensor: - + """Forward.""" output_tensor = self.module( position_ids=position_ids, tokens=tokens, @@ -565,22 +561,26 @@ def forward( return output_tensor def data_step(self, dataloader_iter) -> Dict[str, torch.Tensor]: + # pylint: disable=C0115,C0116 return self.config.data_step_fn(dataloader_iter) def forward_step(self, batch) -> torch.Tensor: + # pylint: disable=C0115,C0116 return self.config.forward_step_fn(self, batch) def training_step(self, batch, batch_idx=None) -> torch.Tensor: + # pylint: disable=C0115,C0116 # In mcore the loss-function is part of the forward-pass (when labels are provided) return self.forward_step(batch) def validation_step(self, batch, batch_idx=None) -> torch.Tensor: + # pylint: disable=C0115,C0116 # In mcore the loss-function is part of the forward-pass (when labels are provided) - return self.forward_step(batch) @property def training_loss_reduction(self) -> MaskedTokenLossReduction: + # pylint: disable=C0115,C0116 if not self._training_loss_reduction: self._training_loss_reduction = MaskedTokenLossReduction() @@ -588,6 +588,7 @@ def training_loss_reduction(self) -> MaskedTokenLossReduction: @property def validation_loss_reduction(self) -> MaskedTokenLossReduction: + # pylint: disable=C0115,C0116 if not self._validation_loss_reduction: self._validation_loss_reduction = MaskedTokenLossReduction(validation_step=True) @@ -599,8 +600,8 @@ def validation_loss_reduction(self) -> MaskedTokenLossReduction: "MLlamaModelConfig", "CrossAttentionTextConfig", "CrossAttentionVisionConfig", - "llama_data_step", - "llama_forward_step", + "mllama_data_step", + "mllama_forward_step", "transformer_engine_layer_spec", "local_layer_spec", ] diff --git a/nemo/collections/vlm/mllama/model/language.py b/nemo/collections/vlm/mllama/model/language.py index b8985e53c54c..5d4cc2e09f21 100644 --- a/nemo/collections/vlm/mllama/model/language.py +++ b/nemo/collections/vlm/mllama/model/language.py @@ -60,6 +60,10 @@ @dataclass class MLlamaCrossAttentionSubmodules: + """ + Defines the submodules required for cross-attention layers in the Llama architecture. + """ + linear_q: Union[ModuleSpec, type] = None linear_kv: Union[ModuleSpec, type] = None core_attention: Union[ModuleSpec, type] = None @@ -69,6 +73,10 @@ class MLlamaCrossAttentionSubmodules: class CrossAttentionTextModel(MCoreGPTModel): + """ + GPT-based model with integrated cross-attention layers for multimodal tasks. + """ + def __init__( self, config: TransformerConfig, @@ -122,6 +130,7 @@ def __init__( self._thresh = self.num_frozen_embeddings - 1 def get_partially_trainable_embedding(self, x): + """Get word embedding w/ few extra learnable tokens.""" xz = torch.zeros_like(x, device=x.device) oz = torch.ones_like(x, device=x.device) x_orig = torch.minimum(x, torch.tensor(self._thresh, device=x.device)) @@ -148,7 +157,7 @@ def forward( packed_seq_params: PackedSeqParams = None, extra_block_kwargs: dict = None, ) -> Tensor: - + """Forward.""" # Decoder embedding. if decoder_input is not None: pass @@ -171,6 +180,9 @@ def forward( ) rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len) + dtype = decoder_input.dtype + cross_attention_bias = cross_attention_masks.to(dtype) * torch.finfo(dtype).min + # Run decoder. hidden_states = self.decoder( hidden_states=decoder_input, @@ -178,9 +190,10 @@ def forward( inference_params=inference_params, rotary_pos_emb=rotary_pos_emb, packed_seq_params=packed_seq_params, - cross_attention_masks=cross_attention_masks, + cross_attention_masks=None, full_text_row_masked_out_mask=full_text_row_masked_out_mask, xattn_caches=xattn_caches, + cross_attention_bias=cross_attention_bias, **(extra_block_kwargs or {}), ) @@ -203,6 +216,10 @@ def forward( class CrossAttentionTransformerBlock(TransformerBlock): + """ + Transformer block with integrated cross-attention layers for multimodal tasks. + """ + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -220,7 +237,7 @@ def __init__(self, *args, **kwargs): submodules=TransformerLayerSubmodules( cross_attention=ModuleSpec( module=MLlamaCrossAttention, - params={"attn_mask_type": AttnMaskType.arbitrary}, + params={"attn_mask_type": AttnMaskType.no_mask}, submodules=MLlamaCrossAttentionSubmodules( linear_q=TELayerNormColumnParallelLinear, # This wraps attention_norm before attention linear_kv=TEColumnParallelLinear, @@ -250,6 +267,7 @@ def __init__(self, *args, **kwargs): assert len(self.xattn_layers) == len(self.layers), 'Check PP implementation for cross attention layers!' def _get_layer_offset(self): + """Get correct layer offset when encoder pipeline parallel size > 0.""" encoder_pipeline_model_parallel_size = getattr(self.config, "encoder_pipeline_model_parallel_size", 0) decoder_pipeline_model_parallel_rank = ( parallel_state.get_pipeline_model_parallel_rank() - encoder_pipeline_model_parallel_size @@ -264,9 +282,12 @@ def forward( cross_attention_masks: Tensor = None, full_text_row_masked_out_mask: Tensor = None, rotary_pos_emb: Tensor = None, + attention_bias: Tensor = None, + cross_attention_bias: Tensor = None, inference_params: InferenceParams = None, packed_seq_params: PackedSeqParams = None, ): + """Forward.""" # hidden_states (float): [s, b, h] # attention_mask (bool): [1, 1, s, s] @@ -324,6 +345,7 @@ def forward( xattn_cache=xattn_caches[l_no], full_text_row_masked_out_mask=full_text_row_masked_out_mask, rotary_pos_emb=rotary_pos_emb, + cross_attention_bias=cross_attention_bias, inference_params=inference_params, packed_seq_params=packed_seq_params, ) @@ -331,6 +353,7 @@ def forward( hidden_states=hidden_states, attention_mask=attention_mask, rotary_pos_emb=rotary_pos_emb, + attention_bias=attention_bias, inference_params=inference_params, packed_seq_params=packed_seq_params, ) @@ -361,6 +384,7 @@ def forward( def sharded_state_dict( self, prefix: str = '', sharded_offsets: tuple = (), metadata: dict = None ) -> ShardedStateDict: + """Update shareded state dict for cross-attention layers""" sharded_state_dict = {} layer_prefix = f'{prefix}layers.' @@ -399,6 +423,10 @@ def sharded_state_dict( class CrossAttentionTransformerLayer(TransformerLayer): + """ + Transformer layer with cross-attention for integration. + """ + def __init__( self, config: TransformerConfig, @@ -417,6 +445,7 @@ def __init__( self.gate_ffn = nn.Parameter(torch.zeros(1, dtype=self.config.params_dtype)) def compute_xattn_kv_cache(self, xattn_tokens: Tensor) -> Tensor: + """Compute cross-attention kv cahce.""" return self.cross_attention._compute_xattn_kv_cache(xattn_tokens) def forward( @@ -426,9 +455,11 @@ def forward( xattn_cache=None, full_text_row_masked_out_mask=None, rotary_pos_emb=None, + cross_attention_bias=None, inference_params=None, packed_seq_params=None, ): + """Forward.""" # hidden_states: [s, b, h] # Residual connection. @@ -444,6 +475,7 @@ def forward( xattn_cache=xattn_cache, full_text_row_masked_out_mask=full_text_row_masked_out_mask, rotary_pos_emb=rotary_pos_emb, + cross_attention_bias=cross_attention_bias, inference_params=inference_params, ) @@ -507,11 +539,13 @@ def __call__( return hidden_states, None def compute_xattn_kv_cache(self, xattn_tokens: Tensor) -> Optional[Tensor]: + # pylint: disable=C0115,C0116 return None class MLlamaCrossAttention(Attention): - """Cross-attention layer class for Llama VLM support + """ + Cross-attention layer for Llama multimodal tasks. Cross-attention layer takes input with size [s, b, h] and context with size [s, b, h] and returns output of the same size. @@ -574,6 +608,7 @@ def __init__( ) def get_key_value_tensors(self, key_value_states): + """Get key value tensors.""" mixed_kv, _ = self.linear_kv(key_value_states) # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn] @@ -590,7 +625,7 @@ def get_key_value_tensors(self, key_value_states): return key, value def get_query_tensor(self, hidden_states): - + """ "Get query tensor.""" # Attention head [sq, b, h] --> [sq, b, hp] query, _ = self.linear_q(hidden_states) @@ -607,6 +642,7 @@ def get_query_tensor(self, hidden_states): return query def get_query_key_value_tensors(self, hidden_states, key_value_states): + """Get query key value tensors.""" query = self.get_query_tensor(hidden_states) key, value = self.get_key_value_tensors(key_value_states) return query, key, value @@ -619,8 +655,17 @@ def forward( full_text_row_masked_out_mask=None, inference_params=None, rotary_pos_emb=None, + rotary_pos_cos=None, + rotary_pos_sin=None, + cross_attention_bias=None, packed_seq_params=None, ): + """Forward.""" + # hidden_states: [sq, b, h] + if self.config.flash_decode: + rotary_pos_emb = None + else: + assert rotary_pos_cos is None and rotary_pos_sin is None # For self attention we just duplicate the rotary_pos_emb if it isn't already if rotary_pos_emb is not None and not isinstance(rotary_pos_emb, tuple): @@ -637,8 +682,8 @@ def forward( # =================================================== # Adjust key, value, and rotary_pos_emb for inference # =================================================== - key, value, rotary_pos_emb, attn_mask_type = self._adjust_key_value_for_inference( - inference_params, key, value, rotary_pos_emb + query, key, value, rotary_pos_emb, attn_mask_type = self._adjust_key_value_for_inference( + inference_params, query, key, value, rotary_pos_emb, rotary_pos_cos, rotary_pos_sin ) if packed_seq_params is not None: @@ -650,9 +695,6 @@ def forward( # core attention computation # ================================== - # In TE "True" means masked out - cross_attention_masks = torch.where(cross_attention_masks == 0, False, True) - if self.checkpoint_core_attention and self.training: core_attn_out = self._checkpointed_attention_forward( query, @@ -660,6 +702,7 @@ def forward( value, cross_attention_masks, attn_mask_type=attn_mask_type, + attention_bias=cross_attention_bias, packed_seq_params=packed_seq_params, ) else: @@ -669,6 +712,7 @@ def forward( value, cross_attention_masks, attn_mask_type=attn_mask_type, + attention_bias=cross_attention_bias, packed_seq_params=packed_seq_params, ) @@ -702,8 +746,22 @@ def apply_rope_scaling( high_freq_factor: int = 4, old_context_len: int = 8192, ): + """ + Apply scaling to rotary embeddings for positional encoding. + + Args: + inv_freq (Tensor): Tensor of inverse frequencies. + factor (int): Scaling factor for medium-to-high frequencies. + low_freq_factor (int): Factor for identifying low frequencies. + high_freq_factor (int): Factor for identifying high frequencies. + old_context_len (int): Original context length for scaling computation. + + Returns: + Tensor: Scaled inverse frequencies. + """ logging.info( - f"Apply rope scaling with factor={factor}, low_freq_factor={low_freq_factor}, high_freq_factor={high_freq_factor}, old_context_len={old_context_len}." + f"Apply rope scaling with factor={factor}, low_freq_factor={low_freq_factor}, " + f"high_freq_factor={high_freq_factor}, old_context_len={old_context_len}." ) low_freq_wavelen = old_context_len / low_freq_factor diff --git a/nemo/collections/vlm/mllama/model/vision.py b/nemo/collections/vlm/mllama/model/vision.py index f662546d21ae..bb58ad093cd6 100644 --- a/nemo/collections/vlm/mllama/model/vision.py +++ b/nemo/collections/vlm/mllama/model/vision.py @@ -59,6 +59,9 @@ def to_2tuple(x): + """ + Convert an input to a 2-tuple. + """ if isinstance(x, collections.abc.Iterable): return x return (x, x) @@ -71,9 +74,16 @@ def _stack_images( max_num_images: int, ) -> Tuple[torch.Tensor, List[int]]: """ - Takes a list of list of images and stacks them into a tensor. - This function is needed since images can be of completely - different resolutions and aspect ratios. + Stack a list of image lists into a tensor while accounting for varying resolutions and aspect ratios. + + Args: + images (List[List[PIL_Image.Image]]): List of image lists for stacking. + max_num_chunks (int): Maximum number of chunks per image. + image_res (int): Target resolution for each image. + max_num_images (int): Maximum number of images to stack. + + Returns: + Tuple[torch.Tensor, List[int]]: Tensor of stacked images and a list of chunk counts for each image. """ out_images, out_num_chunks = [], [] for imgs_sample in images: @@ -97,22 +107,36 @@ def build_encoder_attention_mask( x: torch.Tensor, ar_ids: torch.Tensor, ntok: int, num_chunks: int, supported_aspect_ratios: List[List[int]] ): """ - Build vision encoder attention mask that omits padding tiles and tokens. + Build attention masks for a vision encoder to handle padding and token alignment. + + Args: + x (torch.Tensor): Input tensor of shape (batch_size, sequence_length). + ar_ids (torch.Tensor): Aspect ratio IDs for masking. + ntok (int): Number of tokens. + num_chunks (int): Number of chunks in the data. + supported_aspect_ratios (List[List[int]]): List of supported aspect ratios. + + Returns: + torch.Tensor: Tensor containing the attention mask. """ masks = [] + dtype = x.dtype for ar_id in ar_ids: arx = supported_aspect_ratios[ar_id - 1] mask_i = torch.ones((num_chunks, x.shape[1] // num_chunks), device=x.device) mask_i[: arx[0] * arx[1], :ntok] = 0 mask_i = mask_i.view(num_chunks * x.shape[1] // num_chunks, -1) - mask_i = (mask_i @ mask_i.T).type(torch.bool) + mask_i = mask_i @ mask_i.T mask_i = mask_i.unsqueeze(0) masks.append(mask_i) - masks = torch.stack(masks) + masks = torch.stack(masks).to(dtype) * torch.finfo(dtype).min return masks def apply_scaling(freqs: torch.Tensor): + """ + Scale frequency values based on predefined thresholds and a smoothing factor. + """ # Values obtained from grid search scale_factor = 8 low_freq_factor = 1 @@ -137,6 +161,9 @@ def apply_scaling(freqs: torch.Tensor): # Use this spec for an implementation using modules in TE def get_image_transformer_layer_spec() -> ModuleSpec: + """ + Create a specification for an image transformer layer. + """ image_transformer_submodules = TransformerLayerSubmodules( input_layernorm=TENorm, self_attention=ModuleSpec( @@ -171,10 +198,15 @@ def forward_with_return_intermediate( context: Tensor = None, context_mask: Tensor = None, rotary_pos_emb: Tensor = None, + attention_bias: Tensor = None, inference_params: InferenceParams = None, packed_seq_params: PackedSeqParams = None, return_intermediate: List[int] = None, ): + """ + Perform a forward pass through the transformer layers with optional intermediate outputs. + Override regular MCore transformer layer forward pass. + """ # hidden_states (float): [s, b, h] # attention_mask (bool): [1, 1, s, s] @@ -223,6 +255,7 @@ def forward_with_return_intermediate( context=context, context_mask=context_mask, rotary_pos_emb=rotary_pos_emb, + attention_bias=attention_bias, packed_seq_params=packed_seq_params, ) else: @@ -239,6 +272,7 @@ def forward_with_return_intermediate( context=context, context_mask=context_mask, rotary_pos_emb=rotary_pos_emb, + attention_bias=attention_bias, inference_params=inference_params, packed_seq_params=packed_seq_params, ) @@ -278,16 +312,22 @@ def forward_with_return_intermediate( class ColumnParallelConv2dPatch(MegatronModule): - """Conv2D Patching layer with model parallelism. - Column parallel over unfolded input. - Arguments: - in_channels: Input channels. - out_channels: Output channels. - kernel_size: Size of convolution kernel. - stride (default 1): Stride for convolution. - bias (default False): Use bias in Conv2d. - Input: (bsz, in_channels, width, height) - Output: (bsz, num_tokens, out_channels) + """ + Conv2D Patching layer with model parallelism. Applies convolution in a column-parallel fashion. + + Args: + config (TransformerConfig): Configuration object for the layer. + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + kernel_size (Union[int, Tuple[int, int]]): Size of the convolution kernel. + stride (Union[int, Tuple[int, int]]): Stride of the convolution. + bias (Optional[bool], default=False): Whether to include a bias term. + + Input: + torch.Tensor: Input tensor of shape (batch_size, in_channels, width, height). + + Output: + torch.Tensor: Output tensor of shape (batch_size, num_tokens, out_channels). """ def __init__( @@ -316,6 +356,7 @@ def __init__( ) def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward.""" x = self._unfold(x) x = x.permute(0, 2, 1) x = F.linear(x, self._linear.weight) @@ -324,6 +365,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class PrecomputedTilePositionEmbedding(torch.nn.Module): + """ + Module to compute positional embeddings for tiles with optional gating. + + Args: + config (TransformerConfig): Configuration object. + gated (bool, default=False): Whether to apply gating to the embeddings. + """ + def __init__( self, config: TransformerConfig, @@ -340,6 +389,7 @@ def __init__( self.gate = nn.Parameter(torch.zeros(1)) def forward(self, hidden_states: torch.Tensor, aspect_ratio_ids: torch.Tensor) -> torch.Tensor: + """Forward.""" embeddings = self.embedding(aspect_ratio_ids) embeddings = embeddings.reshape(-1, self.max_num_tiles, 1, self.hidden_size) @@ -351,7 +401,15 @@ def forward(self, hidden_states: torch.Tensor, aspect_ratio_ids: torch.Tensor) - class SelfAttentionNoBias(SelfAttention): - """Self-attention layer class without bias""" + """ + Self-attention layer implementation without bias. + + Args: + config (TransformerConfig): Configuration for the transformer. + submodules (SelfAttentionSubmodules): Submodules required for self-attention. + layer_number (int): The layer number in the transformer stack. + attn_mask_type (AttnMaskType): Type of attention mask to apply. + """ def __init__( self, @@ -396,6 +454,16 @@ def __init__( class ImageTransformerLayer(TransformerLayer): + """ + Transformer layer adapted for processing image data with optional gating. + + Args: + config (TransformerConfig): Transformer configuration object. + submodules (TransformerLayerSubmodules): Submodules to use in the layer. + layer_number (int, default=1): Layer number in the transformer. + hidden_dropout (float, optional): Dropout rate for hidden layers. + """ + def __init__( self, config: TransformerConfig, @@ -423,9 +491,11 @@ def forward( rotary_pos_emb=None, rotary_pos_cos=None, rotary_pos_sin=None, + attention_bias=None, inference_params=None, packed_seq_params=None, ): + """Forward.""" # hidden_states: [s, b, h] # Residual connection. @@ -440,6 +510,7 @@ def forward( attention_mask=attention_mask, inference_params=inference_params, rotary_pos_emb=rotary_pos_emb, + attention_bias=attention_bias, packed_seq_params=packed_seq_params, ) @@ -485,6 +556,19 @@ def forward( class VisionEncoder(MegatronModule): + """ + Vision encoder module for processing image inputs with patch-based embeddings. + + Args: + config ('CrossAttentionVisionConfig'): Configuration object for the encoder. + image_size (int, default=560): Input image size. + patch_size (int, default=14): Size of patches extracted from the image. + in_channels (int, default=3): Number of input channels. + pre_process (bool, default=True): Whether to preprocess input. + post_process (bool, default=True): Whether to postprocess output. + return_intermediate (Optional[bool]): Whether to return intermediate layers. + """ + def __init__( self, config: 'CrossAttentionVisionConfig', @@ -556,7 +640,7 @@ def __init__( self.gated_positional_embedding_gate = nn.Parameter(torch.zeros(1)) def apply_positional_embedding(self, x, aspect_ratio_ids): - # apply regular position embedding + """Apply regular position embedding and tile positonal embedding.""" bsz, num_chunks, num_tokens, dim = x.shape x = x.view(bsz * num_chunks, num_tokens, dim) x = x + self.positional_embedding * (1 - self.gated_positional_embedding_gate.tanh()) @@ -567,6 +651,7 @@ def apply_positional_embedding(self, x, aspect_ratio_ids): return x def apply_class_embedding(self, x): + """Concat class embedding tokens.""" x = torch.cat( [ self.class_embedding.to(x.dtype) @@ -578,6 +663,7 @@ def apply_class_embedding(self, x): return x def forward(self, images: torch.Tensor, ar_ids: torch.Tensor) -> torch.Tensor: + """Forward.""" if images.ndim == 5: num_concurrent_media = 1 bsz, num_chunks, nch, w, h = images.shape @@ -609,15 +695,17 @@ def forward(self, images: torch.Tensor, ar_ids: torch.Tensor) -> torch.Tensor: x = x.view(bsz * num_concurrent_media, -1, dim) npad, attn_mask = 0, None - attn_mask = build_encoder_attention_mask(x, ar_ids, ntok, num_chunks, self.config.supported_aspect_ratios) + attn_bias = build_encoder_attention_mask(x, ar_ids, ntok, num_chunks, self.config.supported_aspect_ratios) x = x.transpose(0, 1).contiguous() x, int_x = self.transformer( hidden_states=x, attention_mask=attn_mask, + attention_bias=attn_bias, return_intermediate=self.return_intermediate, ) - # [ntok * num_concurrent_media * num_chunks, bsz, hidden_size] -> [bsz, ntok * num_concurrent_media * num_chunks, hidden_size] + # [ntok * num_concurrent_media * num_chunks, bsz, hidden_size] + # -> [bsz, ntok * num_concurrent_media * num_chunks, hidden_size] x, int_x = x.transpose(0, 1).contiguous(), int_x.transpose(0, 1).contiguous() x = self.ln_post(x) x = x.reshape(bsz * num_concurrent_media, num_chunks, ntok + npad, dim) @@ -627,6 +715,7 @@ def forward(self, images: torch.Tensor, ar_ids: torch.Tensor) -> torch.Tensor: x = self.global_transformer( hidden_states=x, attention_mask=None, + attention_bias=attn_bias, ) x = x.transpose(0, 1) x = x.reshape(bsz * num_concurrent_media, num_chunks, ntok + npad, dim) diff --git a/nemo/collections/vlm/neva/data/api.py b/nemo/collections/vlm/neva/data/api.py index c2e51e033d8a..15ba45c82fd9 100644 --- a/nemo/collections/vlm/neva/data/api.py +++ b/nemo/collections/vlm/neva/data/api.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl from nemo.collections.vlm.neva.data.lazy import NevaLazyDataModule from nemo.collections.vlm.neva.data.mock import MockDataModule diff --git a/nemo/collections/vlm/neva/data/conversation.py b/nemo/collections/vlm/neva/data/conversation.py index d78d3bd28acb..58953dc53b7a 100644 --- a/nemo/collections/vlm/neva/data/conversation.py +++ b/nemo/collections/vlm/neva/data/conversation.py @@ -77,7 +77,6 @@ def process_chat_template(self, tokenizer_name_or_path, messages): def get_prompt(self): messages = self.messages - messages = self.process_prompt_with_images(messages) if self.sep_style == SeparatorStyle.SINGLE: ret = self.system + self.sep @@ -100,6 +99,8 @@ def get_prompt(self): if type(message) is tuple: message, _, _ = message ret += role + ": " + message + seps[i % 2] + # Add space to make sure the labels can be correctly generated. + self.messages[i][1] = " " + self.messages[i][1] else: ret += role + ":" @@ -155,7 +156,6 @@ def get_prompt(self): ret = self.process_chat_template(tokenizer_name_or_path, messages) elif self.sep_style == SeparatorStyle.MLLAMA: - """ """ tokenizer_name_or_path = self.tokenizer_name_or_path or "meta-llama/Llama-3.2-11B-Vision-Instruct" ret = self.process_chat_template(tokenizer_name_or_path, messages) diff --git a/nemo/collections/vlm/neva/data/lazy.py b/nemo/collections/vlm/neva/data/lazy.py index 57aa5b408835..5bc2cbe0458e 100644 --- a/nemo/collections/vlm/neva/data/lazy.py +++ b/nemo/collections/vlm/neva/data/lazy.py @@ -20,12 +20,12 @@ from typing import Any, Dict, List, Optional, Sequence import decord +import lightning.pytorch as pl import numpy as np -import pytorch_lightning as pl import torch import torch.nn.functional as F +from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS from PIL import Image -from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS from torch.utils import data from torch.utils.data import DataLoader, Dataset, default_collate from transformers import CLIPImageProcessor, SiglipImageProcessor @@ -251,7 +251,7 @@ def __init__( data_config, tokenizer, image_processor, - sequence_length, + sequence_length=None, ): super().__init__() if data_path is not None: @@ -497,6 +497,7 @@ def __init__( weights: Optional[List[float]] = None, data_config: Optional[DataConfig] = ImageDataConfig, seq_length: int = 2048, + decoder_seq_length: Optional[int] = None, tokenizer: Optional = None, image_processor: Optional = None, micro_batch_size: int = 4, @@ -523,6 +524,7 @@ def __init__( self.weights = weights self.data_config = data_config self.seq_length = seq_length + self.decoder_seq_length = decoder_seq_length self.tokenizer = tokenizer self.image_processor = image_processor self.num_train_samples = num_train_samples @@ -538,13 +540,15 @@ def __init__( if tokenizer is None or image_processor is None: logging.warning(f"Processor and tokenizer are not provided! Fall back to `llava-hf/llava-1.5-7b-hf`.") from transformers import AutoProcessor + from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf") - self.tokenizer = tokenizer or processor.tokenizer + self.tokenizer = tokenizer or AutoTokenizer("llava-hf/llava-1.5-7b-hf") self.image_processor = image_processor or processor.image_processor self.data_sampler = MegatronDataSampler( seq_len=self.seq_length, + decoder_seq_len=self.decoder_seq_length, micro_batch_size=micro_batch_size, global_batch_size=global_batch_size, dataloader_type="cyclic", diff --git a/nemo/collections/vlm/neva/data/mock.py b/nemo/collections/vlm/neva/data/mock.py index ac4bc56a068c..9e2308752641 100644 --- a/nemo/collections/vlm/neva/data/mock.py +++ b/nemo/collections/vlm/neva/data/mock.py @@ -14,35 +14,38 @@ from typing import Dict, List, Optional +import lightning.pytorch as pl import numpy as np -import pytorch_lightning as pl import torch -from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS +from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS from torch.utils import data from torch.utils.data import DataLoader, Dataset from nemo.collections.vlm.neva.data.multimodal_tokens import IMAGE_TOKEN_INDEX from nemo.lightning.pytorch.plugins import MegatronDataSampler +from nemo.utils import logging class MockDataModule(pl.LightningDataModule): def __init__( self, seq_length: int = 2048, + decoder_seq_length: Optional[int] = None, tokenizer: Optional = None, image_processor: Optional = None, micro_batch_size: int = 4, global_batch_size: int = 8, rampup_batch_size: Optional[List[int]] = None, - num_train_samples: int = 10_000, - num_val_samples: int = 10_000, - num_test_samples: int = 10_000, + num_train_samples: int = 10_000_000, + num_val_samples: int = 10_000_000, + num_test_samples: int = 10_000_000, num_workers: int = 8, pin_memory: bool = True, persistent_workers: bool = False, ): super().__init__() self.seq_length = seq_length + self.decoder_seq_len = decoder_seq_length self.num_train_samples = num_train_samples self.num_val_samples = num_val_samples self.num_test_samples = num_test_samples @@ -51,13 +54,16 @@ def __init__( self.persistent_workers = persistent_workers if tokenizer is None or image_processor is None: + logging.warning(f"Processor or tokenizer are not provided! Fall back to `llava-hf/llava-1.5-7b-hf`.") from transformers import AutoProcessor + from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf") - self.tokenizer = tokenizer or processor.tokenizer + self.tokenizer = tokenizer or AutoTokenizer("llava-hf/llava-1.5-7b-hf") self.image_processor = image_processor or processor.image_processor self.data_sampler = MegatronDataSampler( seq_len=self.seq_length, + decoder_seq_len=self.decoder_seq_len, micro_batch_size=micro_batch_size, global_batch_size=global_batch_size, rampup_batch_size=rampup_batch_size, diff --git a/nemo/collections/vlm/neva/model/__init__.py b/nemo/collections/vlm/neva/model/__init__.py index 25842186ecfe..99862f97b9ed 100644 --- a/nemo/collections/vlm/neva/model/__init__.py +++ b/nemo/collections/vlm/neva/model/__init__.py @@ -19,16 +19,19 @@ NevaConfig, NevaModel, ) -from nemo.collections.vlm.neva.model.llava import Llava1_5Config7B, Llava1_5Config13B, LlavaConfig, LlavaModel +from nemo.collections.vlm.neva.model.llava import Llava15Config7B, Llava15Config13B, LlavaConfig, LlavaModel +from nemo.collections.vlm.neva.model.vit_config import CLIPViTL_14_336_Config, SigLIPViT400M_14_384_Config __all__ = [ "CLIPViTConfig", + "CLIPViTL_14_336_Config", + "SigLIPViT400M_14_384_Config", "HFCLIPVisionConfig", "MultimodalProjectorConfig", "NevaConfig", "NevaModel", "LlavaConfig", - "Llava1_5Config7B", - "Llava1_5Config13B", + "Llava15Config7B", + "Llava15Config13B", "LlavaModel", ] diff --git a/nemo/collections/vlm/neva/model/api.py b/nemo/collections/vlm/neva/model/api.py index 62374d536712..13444632464e 100644 --- a/nemo/collections/vlm/neva/model/api.py +++ b/nemo/collections/vlm/neva/model/api.py @@ -12,20 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl -from nemo.collections.vlm.neva.model import Llava1_5Config7B, Llava1_5Config13B, LlavaModel +from nemo.collections.vlm.neva.model import Llava15Config7B, Llava15Config13B, LlavaModel -def llava1_5_7b() -> pl.LightningModule: - return LlavaModel(Llava1_5Config7B()) +def llava15_7b() -> pl.LightningModule: + return LlavaModel(Llava15Config7B()) -def llava1_5_13b() -> pl.LightningModule: - return LlavaModel(Llava1_5Config13B()) +def llava15_13b() -> pl.LightningModule: + return LlavaModel(Llava15Config13B()) __all__ = [ - "llava1_5_7b", - "llava1_5_13b", + "llava15_7b", + "llava15_13b", ] diff --git a/nemo/collections/vlm/neva/model/base.py b/nemo/collections/vlm/neva/model/base.py index 260b7e7e0f4a..360874152cf7 100644 --- a/nemo/collections/vlm/neva/model/base.py +++ b/nemo/collections/vlm/neva/model/base.py @@ -17,22 +17,25 @@ from dataclasses import dataclass from typing import Callable, Dict, List, Optional, Union -import pytorch_lightning as L +import lightning.pytorch as L import torch import torch.distributed import torch.nn.functional as F from megatron.core import dist_checkpointing +from megatron.core import parallel_state as ps +from megatron.core.enums import ModelType +from megatron.core.extensions.transformer_engine import TEDotProductAttention from megatron.core.inference_params import InferenceParams from megatron.core.models.multimodal.llava_model import LLaVAModel as MCoreLLaVAModel from megatron.core.models.vision.clip_vit_model import CLIPViTModel as MCoreCLIPViTModel from megatron.core.models.vision.multimodal_projector import MultimodalProjector as MCoreMultimodalProjector from megatron.core.optimizer import OptimizerConfig +from megatron.core.tensor_parallel import gather_from_sequence_parallel_region from megatron.core.transformer.custom_layers.transformer_engine import ( TEColumnParallelLinear, TENorm, TERowParallelLinear, ) -from megatron.core.transformer.enums import ModelType from megatron.core.transformer.mlp import MLP, MLPSubmodules from megatron.core.transformer.spec_utils import ModuleSpec from megatron.core.transformer.transformer_config import TransformerConfig @@ -41,15 +44,43 @@ from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec from nemo.collections.llm import fn -from nemo.collections.llm.gpt.model import local_layer_spec, transformer_engine_layer_spec +from nemo.collections.llm.gpt.model import transformer_engine_layer_spec from nemo.collections.llm.gpt.model.base import get_batch_on_this_context_parallel_rank, get_packed_seq_params -from nemo.collections.nlp.modules.common.megatron.module import MegatronModule -from nemo.collections.vlm.neva.data.multimodal_tokens import IGNORE_INDEX, IMAGE_TOKEN_INDEX +from nemo.collections.vlm.neva.data.multimodal_tokens import IMAGE_TOKEN_INDEX from nemo.lightning import io +from nemo.lightning.io.pl import ckpt_to_weights_subdir from nemo.lightning.megatron_parallel import MaskedTokenLossReductionWithLossMask from nemo.lightning.pytorch.optim import MegatronOptimizerModule, OptimizerModule from nemo.utils import logging +MODEL_CONFIG_ATTR = [ + 'num_layers', + 'hidden_size', + 'num_attention_heads', + 'num_query_groups', + 'ffn_hidden_size', + 'kv_channels', + 'hidden_dropout', + 'attention_dropout', + 'fp32_residual_connection', + 'apply_residual_connection_post_layernorm', + 'layernorm_epsilon', + 'layernorm_zero_centered_gamma', + 'add_bias_linear', + 'add_qkv_bias', + 'gated_linear_unit', + 'activation_func', + 'activation_func_fp8_input_store', + 'num_moe_experts', + 'rotary_interleaved', + 'window_size', + 'normalization', + 'qk_layernorm', + 'test_mode', + 'calculate_per_token_loss', + 'seq_length', +] + def get_image_sequence_length(img_h, img_w, patch_dim, add_class_token, class_token_len): """Get image sequence length given image size, patch size, and class token.""" @@ -64,9 +95,7 @@ def neva_data_step(dataloader_iter) -> Dict[str, torch.Tensor]: # Based on: https://github.com/NVIDIA/Megatron-LM/blob/main/pretrain_gpt.py#L87 # https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py#L828-L842 - batch = next(dataloader_iter) - _batch: dict if isinstance(batch, tuple) and len(batch) == 3: _batch = batch[0] @@ -74,11 +103,23 @@ def neva_data_step(dataloader_iter) -> Dict[str, torch.Tensor]: _batch = batch required_keys = set() - required_keys.add("attention_mask") + required_keys.update( + ( + "tokens", + "attention_mask", + "media", + "num_media_tiles", + ) + ) if parallel_state.is_pipeline_first_stage(): - required_keys.update(("media", "tokens", "position_ids")) + required_keys.update(("position_ids",)) if parallel_state.is_pipeline_last_stage(): - required_keys.update(("labels", "loss_mask")) + required_keys.update( + ( + "labels", + "loss_mask", + ) + ) _batch = { key: val.cuda(non_blocking=True) if key in required_keys and val is not None else None @@ -98,6 +139,7 @@ def neva_forward_step(model, batch) -> torch.Tensor: "attention_mask": batch.get("attention_mask", None), "loss_mask": batch.get("loss_mask", None), "labels": batch.get("labels", None), + "num_media_tiles": batch.get("num_media_tiles", None), } if 'cu_seqlens' in batch: @@ -176,10 +218,11 @@ class HFCLIPVisionConfig(CLIPVisionConfig, io.IOMixin): https://github.com/huggingface/transformers/blob/v4.44.0/src/transformers/models/clip/configuration_clip.py#L261 """ + hidden_size: int = 1024 pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None - def configure_hf_config(self, *args, **kwargs) -> None: - CLIPVisionConfig.__init__(self, *args, **kwargs) + def __post_init__(self, *args, **kwargs) -> None: + CLIPVisionConfig.__init__(self, *args, **kwargs, hidden_size=self.hidden_size) def configure_model(self) -> "CLIPVisionModel": # Monkey patch the method to the vision encoder @@ -198,26 +241,40 @@ def configure_model(self) -> "CLIPVisionModel": @dataclass class CLIPViTConfig(TransformerConfig, io.IOMixin): ln_pre_impl: Union[ModuleSpec, type] = TENorm + ln_post_impl: Union[ModuleSpec, type] = TENorm add_class_token: bool = True class_token_len: int = 1 patch_dim: int = 14 img_h: int = 336 img_w: int = 336 + vision_model_type: str = "clip" # ["clip", "siglip"] transformer_layer_spec: ModuleSpec = transformer_engine_layer_spec - def configure_model(self) -> "MCoreCLIPViTModel": + num_layers: int = 1 # Placeholder, NOT used! + num_attention_heads: int = 8 # Placeholder, NOT used! + + def __post_init__(self): + if self.vision_model_type == "siglip": + self.add_class_token = False + self.class_token_len = 0 + + def configure_model(self) -> "CLIPViTModel": transformer_layer_spec = self.transformer_layer_spec if not isinstance(transformer_layer_spec, ModuleSpec): - transformer_layer_spec = transformer_layer_spec(self) - return MCoreCLIPViTModel( + from nemo.collections.vlm.layer_specs import get_layer_spec_te + + transformer_layer_spec = get_layer_spec_te(is_vit=True) + return CLIPViTModel( self, transformer_layer_spec, ln_pre_impl=self.ln_pre_impl, + ln_post_impl=self.ln_post_impl, add_class_token=self.add_class_token, class_token_len=self.class_token_len, patch_dim=self.patch_dim, img_h=self.img_h, img_w=self.img_w, + model_subtype=self.vision_model_type, ) @@ -226,283 +283,173 @@ class NevaConfig(TransformerConfig, io.IOMixin): language_transformer_config: Optional[TransformerConfig] = None vision_transformer_config: Optional[TransformerConfig] = None vision_projection_config: Optional[TransformerConfig] = None + drop_vision_class_token: bool = True + vision_feature_layer: int = -2 + + encoder_pipeline_model_parallel_size: int = 0 + encoder_tensor_model_parallel_size: int = 1 num_layers: int = 1 # Placeholder, NOT used! num_attention_heads: int = 8 # Placeholder, NOT used! - vision_feature_layer: int = -2 + + seq_length: int = 1024 language_model_from_pretrained: Optional[str] = None vision_model_from_pretrained: Optional[str] = None # TODO vision_projection_from_pretrained: Optional[str] = None # TODO - freeze_language_model: bool = True - freeze_vision_model: bool = True + freeze_language_model: bool = False + freeze_vision_model: bool = False freeze_vision_projection: bool = False forward_step_fn: Callable = neva_forward_step data_step_fn: Callable = neva_data_step - def configure_model(self, tokenizer) -> "MCoreLLaVAModel": - language_model = self.language_transformer_config.configure_model(tokenizer=tokenizer) - vision_model = self.vision_transformer_config.configure_model() - vision_projection = self.vision_projection_config.configure_model() - - if self.language_model_from_pretrained is not None: - sharded_state_dict = dict(state_dict=language_model.sharded_state_dict(prefix="module.")) - loaded_state_dict = dist_checkpointing.load( - sharded_state_dict=sharded_state_dict, checkpoint_dir=self.language_model_from_pretrained + def __post_init__(self): + if self.language_transformer_config is not None: + for attr in MODEL_CONFIG_ATTR: + setattr(self, attr, getattr(self.language_transformer_config, attr)) + + def configure_model(self, tokenizer) -> "MCoreNevaModel": + from megatron.core import parallel_state as ps + + self.language_transformer_config.tensor_model_parallel_size = self.tensor_model_parallel_size + self.language_transformer_config.sequence_parallel = self.sequence_parallel + self.vision_transformer_config.tensor_model_parallel_size = self.tensor_model_parallel_size + self.vision_projection_config.tensor_model_parallel_size = self.tensor_model_parallel_size + self.language_transformer_config.pipeline_model_parallel_size = self.pipeline_model_parallel_size + + if self.encoder_pipeline_model_parallel_size > 0: + assert self.encoder_pipeline_model_parallel_size == 1, "ViT can only live on 1 pipeline stage." + self.vision_transformer_config.pipeline_model_parallel_size = self.encoder_pipeline_model_parallel_size + self.vision_projection_config.pipeline_model_parallel_size = self.encoder_pipeline_model_parallel_size + self.language_transformer_config.encoder_pipeline_model_parallel_size = ( + self.encoder_pipeline_model_parallel_size ) - loaded_state_dict = {k.removeprefix("module."): v for k, v in loaded_state_dict["state_dict"].items()} - language_model.load_state_dict(loaded_state_dict) - logging.info(f"Restored language model weights from {self.language_model_from_pretrained}") + if self.encoder_tensor_model_parallel_size > 0: + self.vision_transformer_config.tensor_model_parallel_size = self.encoder_tensor_model_parallel_size + self.vision_projection_config.tensor_model_parallel_size = self.encoder_tensor_model_parallel_size + model = MCoreNevaModel( - transformer_config=self, - language_model=language_model, - vision_model=vision_model, - vision_projection=vision_projection, + config=self, + tokenizer=tokenizer, + pre_process=ps.is_pipeline_first_stage() + or ps.get_pipeline_model_parallel_rank() == self.encoder_pipeline_model_parallel_size, + post_process=ps.is_pipeline_last_stage(), + add_encoder=ps.is_pipeline_first_stage(), + add_decoder=ps.is_pipeline_last_stage() + or ps.get_pipeline_model_parallel_rank() >= self.encoder_pipeline_model_parallel_size, drop_vision_class_token=self.drop_vision_class_token, ) - model.freeze( - freeze_language_model=self.freeze_language_model, - freeze_vision_model=self.freeze_vision_model, - freeze_vision_projection=self.freeze_vision_projection, - ) + return model +class CLIPViTModel(MCoreCLIPViTModel): + """CLIP ViT vision model.""" + + def forward( + self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, num_unused_layers: int = 0 + ) -> torch.Tensor: + if num_unused_layers > 0: + unused_layers = self.decoder.layers[-num_unused_layers:] + self.decoder.layers = self.decoder.layers[:-num_unused_layers] + x = super().forward(x, attention_mask) + self.decoder.layers.append(unused_layers) + return x + + return super().forward(x, attention_mask) + + class MCoreNevaModel(MCoreLLaVAModel): def __init__( self, - transformer_config: TransformerConfig, - language_model: MegatronModule, - vision_model: MegatronModule, - vision_projection: MegatronModule, + config: NevaConfig, + tokenizer: Optional = None, pre_process: bool = True, post_process: bool = True, + add_encoder: bool = True, + add_decoder: bool = True, drop_vision_class_token: bool = False, ) -> None: - super(MCoreLLaVAModel, self).__init__(config=transformer_config) + super(MCoreLLaVAModel, self).__init__(config=config) - logging.warning("LLaVA model is under development and may be missing features.") + language_transformer_config = config.language_transformer_config + vision_transformer_config = config.vision_transformer_config + vision_projection_config = config.vision_projection_config self.pre_process = pre_process self.post_process = post_process + self.add_encoder = add_encoder + self.add_decoder = add_decoder self.encoder_hidden_state = None - self.vision_model = vision_model - self.vision_projection = vision_projection - self.language_model = language_model - self.model_type = ModelType.encoder_or_decoder - # This attribute is needed to check if an all-reduce is required - # on the word embeddings inside `finalize_model_grads._allreduce_word_embedding_grads`. - self.share_embeddings_and_output_weights = False - if self.language_model is not None: - self.share_embeddings_and_output_weights = self.language_model.share_embeddings_and_output_weights - self._language_max_sequence_length = self.language_model.max_sequence_length + self.vision_model = None + self.vision_projection = None + self.language_model = None - if self.vision_model is not None: - self._drop_vision_class_token = drop_vision_class_token + self.sequence_parallel_lm = language_transformer_config.sequence_parallel + self.tp_comm_overlap_lm = language_transformer_config.tp_comm_overlap - self.add_encoder = self.vision_model is not None - self.add_decoder = self.language_model is not None - self.vision_model_from_hf = str(self.vision_model.__class__.__module__).startswith("transformers.") + self.share_embeddings_and_output_weights = False if self.add_decoder: - vision_config = self.config.vision_transformer_config - if self.vision_model_from_hf: - # img_h, img_w, patch_dim, add_class_token, class_token_len - self._img_seq_len = get_image_sequence_length( - img_h=vision_config.image_size, - img_w=vision_config.image_size, - patch_dim=vision_config.patch_size, - add_class_token=not drop_vision_class_token, - class_token_len=0 if "siglip" in vision_config.model_type else 1, + self.language_model = language_transformer_config.configure_model( + tokenizer=tokenizer, pre_process=pre_process, post_process=post_process + ) + self.share_embeddings_and_output_weights = self.language_model.share_embeddings_and_output_weights + self._language_max_sequence_length = self.language_model.max_sequence_length + self._language_is_pipeline_parallel = language_transformer_config.pipeline_model_parallel_size > 1 + if config.language_model_from_pretrained is not None: + sharded_state_dict = dict(state_dict=self.language_model.sharded_state_dict(prefix="module.")) + loaded_state_dict = dist_checkpointing.load( + sharded_state_dict=sharded_state_dict, + checkpoint_dir=ckpt_to_weights_subdir(config.language_model_from_pretrained, is_saving=False), + validate_access_integrity=False, ) - else: - self._img_seq_len = 576 # TODO(yuya): Fix hardcode + loaded_state_dict = {k.removeprefix("module."): v for k, v in loaded_state_dict["state_dict"].items()} + self.language_model.load_state_dict(loaded_state_dict) + logging.info(f"Restored language model weights from {config.language_model_from_pretrained}") else: - self._img_seq_len = 0 - - def _preprocess_data( - self, - image_embeddings, - language_embeddings, - input_ids, - loss_mask, - labels, - use_inference_kv_cache, - image_token_index, - num_image_tiles, - ): - # TODO (yuya): remove this and use the mcore method - """Preprocess input data before input to language model. - - This function is adopted from - https://github.com/huggingface/transformers/blob/85817d98fb60977c97e3014196a462b732d2ed1a/src/transformers/models/llava_next/modeling_llava_next.py#L409 - for our input data conventions. - - image_token_index = -200 indicates the image position in the input_ids = [0, 1, -200, 2, 3] and labels = [1, -200, 2, 3, 4], for example. - We want to replace the image position (-200) with image_embeddings and return the following: - - final_embeddings = [0, 1, image_embeddings, 2, 3], - - final_labels = [1, -100, 2, 3, 4] - - final_loss_mask = [1, 0, 0, 1, 1] - - This function also handles the case where the input does not contain an image (text-only sample). It also handles the case where a single input - image is split into multiple tiles. - - If pipeline parallelism is not used, then self.pre_process and self.post_process are both True and we update both - input embeddings, labels and loss masks (if available). - - If pipeline parallelism is used, then we do the following - - the first language model chunk has self.pre_process = True and self.post_process = False. We update input embeddings. - - the middle language model chunk(s) has self.pre_process = False and self.post_process = False. We don't need to update anything. - - the last language model chunk has self.pre_process = False and self.post_process = True. We update labels and loss mask. + if config.language_model_from_pretrained is not None: + dist_checkpointing.load( + sharded_state_dict=dict(state_dict={}), + checkpoint_dir=config.language_model_from_pretrained, + validate_access_integrity=False, + ) - TODO: This function should adjust the attention mask too. Currently, we assume the language model uses a causal mask. + if self.add_encoder: + self.vision_model = vision_transformer_config.configure_model() + self.vision_projection = vision_projection_config.configure_model() + self._drop_vision_class_token = drop_vision_class_token - Returns: - final_embedding (torch.Tensor): image and text embeddings concated [combined_seq_len, b, h]. - final_labels (torch.Tensor): labels for image and text positions [b, combined_seq_len]. - final_loss_mask (torch.Tensor): loss mask for image and text positions [b, combined_seq_len]. - """ - assert self.add_decoder, "input text preprocessing is only needed for the language model" + self.freeze( + freeze_language_model=config.freeze_language_model, + freeze_vision_model=config.freeze_vision_model, + freeze_vision_projection=config.freeze_vision_projection, + ) - # No pre- or postprocessing needed. With pipeline parallel > 2, this means a chunk in the middle of the model. - if not self.pre_process and not self.post_process: - return language_embeddings, loss_mask, labels + self.model_type = ModelType.encoder_or_decoder + # This attribute is needed to check if an all-reduce is required + # on the word embeddings inside `finalize_model_grads._allreduce_word_embedding_grads`. - # If using the inference KV cache, the image tokens are already computed. - if use_inference_kv_cache: - return language_embeddings, loss_mask, labels - - img_seq_len = self._img_seq_len - batch_size, text_seq_len = input_ids.shape - - has_labels = labels is not None - if has_labels: - assert ( - labels.shape == loss_mask.shape - ), f"mismatching labels shape {labels.shape} and loss mask shape {loss_mask.shape}" - - # Create indices for new text and label positions. - with torch.no_grad(): - image_token_mask = input_ids == image_token_index - num_image_tokens = torch.sum(image_token_mask, dim=-1) - - # Number of tiles per sample. - num_image_tiles_batch = num_image_tiles.split(num_image_tokens.tolist(), dim=0) - num_image_tiles_batch = torch.tensor([x.sum() for x in num_image_tiles_batch], device=input_ids.device) - - # Sequence length for each sample is the image sequence length multiplied by the number of tiles for that image, minus image token indices, - # plus text sequence length. - seq_lens = num_image_tiles_batch * img_seq_len - num_image_tokens + text_seq_len - max_seq_len = seq_lens.max() - batch_indices, non_image_indices = torch.where(input_ids != image_token_index) - - # New position ids for the text tokens, shifted by the image sequence length. - # E.g. for input_ids = [-200, 1, 2, 3] and img_seq_len = 576, we get new_position_ids = [576, 577, 578, 579]. - # text_position_ids are then [577, 578, 579]. - image_token_mask_lens = image_token_mask.int().clone() - # -1 is for the removed image token index. - image_token_mask_lens[image_token_mask] = num_image_tiles * img_seq_len - 1 - # +1 is needed here for the cumulative sum. -1 is adjusting for zero-based indexing. - new_position_ids = torch.cumsum((image_token_mask_lens + 1), dim=-1) - 1 - text_position_ids = new_position_ids[batch_indices, non_image_indices] - - # Labels are shifted to left by one. So, shift text position ids and non-image indices to left by one. - if has_labels: - label_text_position_ids = text_position_ids - 1 - valid_label_text_position_ids = label_text_position_ids >= 0 - label_text_position_ids = label_text_position_ids[valid_label_text_position_ids] - - label_batch_indices = batch_indices[valid_label_text_position_ids] - - label_non_image_indices = non_image_indices - 1 - valid_label_non_image_indices = label_non_image_indices >= 0 - label_non_image_indices = label_non_image_indices[valid_label_non_image_indices] - - # Create a mask for the image embedding positions. - with torch.no_grad(): - images_mask = torch.full((batch_size, max_seq_len), True, dtype=torch.bool, device=input_ids.device) - # No images in the text positions. - images_mask[batch_indices, text_position_ids] = False - # Samples can have different amount of images tokens. new_position_ids[:, -1] gives the last text position id for each sample. - # Padding is needed when the number of image tokens differs. - first_padding_idx = new_position_ids[:, -1] + 1 - images_mask[ - torch.arange(max_seq_len, device=first_padding_idx.device).repeat(batch_size, 1) - >= first_padding_idx.unsqueeze(1) - ] = False - - # Create the final input embedding (if this is the first language model stage). - final_embedding = None - if self.pre_process: - embed_dim = language_embeddings.shape[-1] - final_embedding = torch.zeros( - batch_size, - max_seq_len, - embed_dim, - dtype=image_embeddings.dtype, - device=image_embeddings.device, + self.vision_model_from_hf = hasattr(vision_transformer_config, "image_size") + if self.vision_model_from_hf: + # img_h, img_w, patch_dim, add_class_token, class_token_len + self._img_seq_len = get_image_sequence_length( + img_h=vision_transformer_config.image_size, + img_w=vision_transformer_config.image_size, + patch_dim=vision_transformer_config.patch_size, + add_class_token=not drop_vision_class_token, + class_token_len=0 if "siglip" in vision_transformer_config.model_type else 1, ) - - # Put text embeddings to the text positions in the result tensor. - final_embedding[batch_indices, text_position_ids] = language_embeddings[batch_indices, non_image_indices] - - # Put image embeddings to image positions. - final_embedding[images_mask] = image_embeddings.reshape(-1, embed_dim).contiguous() - - # Create the final labels and loss mask (if this is the last language model stage). - final_labels, final_loss_mask = None, None - if has_labels: - final_labels = torch.full( - (batch_size, max_seq_len), IGNORE_INDEX, dtype=labels.dtype, device=labels.device + else: + self._img_seq_len = get_image_sequence_length( + img_h=vision_transformer_config.img_h, + img_w=vision_transformer_config.img_w, + patch_dim=vision_transformer_config.patch_dim, + add_class_token=not drop_vision_class_token, + class_token_len=vision_transformer_config.class_token_len, ) - final_loss_mask = torch.full((batch_size, max_seq_len), 0, dtype=loss_mask.dtype, device=loss_mask.device) - - # Put text labels and loss mask to the text positions. - final_labels[label_batch_indices, label_text_position_ids] = labels[ - label_batch_indices, label_non_image_indices - ] - - final_loss_mask[batch_indices, text_position_ids] = loss_mask[batch_indices, non_image_indices] - - # For labels, we need to pick the last label index that got dropped by the shift to left. - label_extra_text_position_ids = seq_lens - 1 - batch_range = torch.arange(len(label_extra_text_position_ids)) - final_labels[batch_range, label_extra_text_position_ids] = labels[batch_range, -1] - - # Loss mask the image positions. - final_loss_mask[images_mask] = 0 - - # Loss mask last text position just before an image so that text token does not need to predict the first image token. - batch_image_indices, image_indices = torch.where(image_token_mask) - # Indices just before image tokens. If it's -1, skip it. - before_image_indices = image_indices - 1 - valid = before_image_indices >= 0 - valid_batch_image_indices = batch_image_indices[valid] - valid_before_image_indices = before_image_indices[valid] - # Map those indices those position ids. - valid_before_image_indices = new_position_ids[valid_batch_image_indices, valid_before_image_indices] - - final_loss_mask[valid_batch_image_indices, valid_before_image_indices] = 0 - - if final_embedding is not None and has_labels: - assert ( - final_embedding.shape[:2] == final_labels.shape == final_loss_mask.shape - ), "unexpected shapes after data preprocessing" - - if final_embedding is not None: - final_embedding = final_embedding.transpose(1, 0).contiguous() - - # Truncate if exceeding the language model's max sequence length. - if final_embedding is not None and final_embedding.shape[0] > self._language_max_sequence_length: - final_embedding = final_embedding[: self._language_max_sequence_length] - - if has_labels and final_labels.shape[1] > self._language_max_sequence_length: - final_labels = final_labels[:, : self._language_max_sequence_length] - final_loss_mask = final_loss_mask[:, : self._language_max_sequence_length] - - return final_embedding, final_labels, final_loss_mask def forward( self, @@ -515,6 +462,7 @@ def forward( inference_params: Optional[InferenceParams] = None, num_media_tiles: Optional[List[int]] = None, media_token_index: Optional[int] = IMAGE_TOKEN_INDEX, + runtime_gather_output: Optional[bool] = None, ) -> torch.Tensor: """Forward function of the LLaVA model. @@ -533,34 +481,44 @@ def forward( output (torch.Tensor): Loss of shape [b, s] if labels are provided, otherwise logits of shape [b, s, vocab_size]. loss_mask (torch.Tensor): Loss mask expanded to combined sequence length. Shape [b, s]. """ + use_inference_kv_cache = ( inference_params is not None and "image_tokens_count" in inference_params.key_value_memory_dict ) + has_images = media.shape[0] > 0 + # If running inference, we can skip media token computation if they were computed already earlier for this sample. - if use_inference_kv_cache or media is None: + if use_inference_kv_cache: media_embeddings = None - elif self.add_encoder: + elif self.add_encoder and not has_images: + # If no images provided, use an empty image embeddings tensor. + media_embeddings = torch.tensor([], dtype=media.dtype, device=media.device).reshape(0, 0, 0) + elif self.add_encoder and has_images: # media is in shape of (num_images_in_mbs, c, h, w) # note num_images_in_mbs is not mbs but total images in this mbs. if self.vision_model_from_hf: - media_embeddings = self.vision_model( - media, output_hidden_states=True - ) # [num_images, img_seq_len, h_vision] + self.vision_model = self.vision_model.eval() + media_embeddings = self.vision_model(media, output_hidden_states=True) media_embeddings = media_embeddings[-1][ self.config.vision_feature_layer - ] # take second from last layer + ] # [num_images, img_seq_len, h_vision] else: # TODO(yuya): MCore Clip path not yet support taking a specific layer hidden states - media_embeddings = self.vision_model(media) + media = media.to(next(self.vision_model.parameters()).dtype) + media_embeddings = self.vision_model(media, num_unused_layers=-self.config.vision_feature_layer - 1) if self._drop_vision_class_token: class_token_len = getattr(self.vision_model, "class_token_len", 1) media_embeddings = media_embeddings[:, class_token_len:, :] + # contiguous() required as `permute` can sparsify the tensor and this breaks pipelining + media_embeddings = media_embeddings.permute(1, 0, 2).contiguous() # [img_seq_len, num_tiles, h_vision] + # map vision model output size to language model input size. - media_embeddings = self.vision_projection(media_embeddings) # [img_seq_len, num_tiles, h_vision] + media_embeddings = self.vision_projection(media_embeddings) # [img_seq_len, num_tiles, h_language] - # If running inference, the language model KV cache will be updated for media token positions. - # Here we store the media tokens sequence length, which can be used as an offset to the KV cache later. + # TODO: Support batched inference. + # In inference, the language model KV cache will be updated for image token positions. + # Store the image tokens sequence length to be used as an offset to the KV cache later. if inference_params is not None: inference_params.key_value_memory_dict["media_tokens_count"] = ( media_embeddings.shape[0] * media_embeddings.shape[1] @@ -569,40 +527,61 @@ def forward( media_embeddings = self.encoder_hidden_state if not self.add_decoder: - return media_embeddings, loss_mask + return media_embeddings language_embeddings = None if self.pre_process: input_ids_text = input_ids.clone() # MultiModal Token indices are assumed to be values input_ids_text[input_ids_text < 0] = 0 - # Note: This adds absolute position embedding but not RoPE. Each image is counted as one position. - # RoPE is added in language_model forward call. Each image embedding is one position. + # Note: This adds absolute position embedding but not RoPE. + # Each image is counted as one position. + # RoPE is added in language_model forward. Each image embedding is one position. + if self.sequence_parallel_lm: + # Pad to nearest multiple of TP world size for embedding. + tp_world_size = ps.get_tensor_model_parallel_world_size() + padded_seq_len = ( + int((input_ids_text.shape[1] + tp_world_size - 1) // tp_world_size * tp_world_size) + - input_ids_text.shape[1] + ) + if padded_seq_len != 0: + input_ids_text = torch.nn.functional.pad(input_ids_text, (0, padded_seq_len)) + if position_ids is not None: + position_ids = torch.nn.functional.pad(position_ids, (0, padded_seq_len)) language_embeddings = self.language_model.embedding( input_ids=input_ids_text, position_ids=position_ids ) # [text_seq_len, b, h_language] + if self.sequence_parallel_lm: + # Gather the language embeddings back. + # We use the full embedding to insert image embeddings + # and then scatter to avoid load imbalance. + language_embeddings = gather_from_sequence_parallel_region( + language_embeddings, tensor_parallel_output_grad=False + ) + # Remove the padding done for SP as we'll need new padding calculation + # after image embeddings are inserted. + if padded_seq_len != 0: + language_embeddings = language_embeddings[:-padded_seq_len] language_embeddings = language_embeddings.transpose(1, 0).contiguous() # [b, text_seq_len, h_language] - if media is None: - combined_embeddings = language_embeddings.transpose(1, 0).contiguous() - final_labels = labels - final_loss_mask = loss_mask - else: - # Assume 1 tile per image if the number of tiles is not provided. - if num_media_tiles is None: - num_media_tiles = torch.ones(media.shape[0], dtype=torch.int, device=input_ids.device) - - # Preprocess input, labels and loss mask. - combined_embeddings, final_labels, final_loss_mask = self._preprocess_data( - media_embeddings, - language_embeddings, - input_ids, - loss_mask, - labels, - use_inference_kv_cache, - media_token_index, - num_media_tiles, - ) # [combined_seq_len, b, h_language], [b, combined_seq_len], [b, combined_seq_len] + # Assume 1 tile per image if the number of tiles is not provided. + if num_media_tiles is None: + num_media_tiles = torch.ones(media.shape[0], dtype=torch.int, device=input_ids.device) + elif isinstance(num_media_tiles, list): + num_media_tiles = torch.tensor(num_media_tiles, dtype=torch.int, device=input_ids.device) + + # Preprocess input, labels and loss mask. + combined_embeddings, final_labels, final_loss_mask, final_attention_mask = self._preprocess_data( + media_embeddings, + language_embeddings, + input_ids, + loss_mask, + labels, + use_inference_kv_cache, + media_token_index, + num_media_tiles, + attention_mask, + ) # [combined_seq_len, b, h_language], [b, combined_seq_len], [b, combined_seq_len] output = self.language_model( input_ids=None, @@ -611,6 +590,7 @@ def forward( decoder_input=combined_embeddings, labels=final_labels, inference_params=inference_params, + runtime_gather_output=runtime_gather_output, ) if labels is None or loss_mask is None: @@ -618,6 +598,23 @@ def forward( return output, final_loss_mask.contiguous() + def set_input_tensor(self, input_tensor) -> None: + """Set model chunk input tensor.""" + # This is usually handled in schedules.py but some inference code still + # gives us non-lists or None + if not isinstance(input_tensor, list): + input_tensor = [input_tensor] + assert len(input_tensor) == 1, 'input_tensor should only be length 1 for llava' + + if self.add_encoder and self.add_decoder: + self.vision_model.set_input_tensor(input_tensor[0]) + elif self.add_encoder: + self.vision_model.set_input_tensor(input_tensor[0]) + elif self.pre_process: + self.encoder_hidden_state = input_tensor[0] + else: + self.language_model.set_input_tensor(input_tensor[0]) + class NevaModel(L.LightningModule, io.IOMixin, io.ConnectorMixin, fn.FNMixin): def __init__( @@ -649,6 +646,7 @@ def forward( media: Optional[torch.Tensor] = None, labels: Optional[torch.Tensor] = None, inference_params: InferenceParams = None, + num_media_tiles: Optional[List[int]] = None, ) -> torch.Tensor: output_tensor = self.module( media=media, @@ -658,6 +656,7 @@ def forward( attention_mask=attention_mask, labels=labels, inference_params=inference_params, + num_media_tiles=num_media_tiles, ) return output_tensor @@ -697,6 +696,4 @@ def validation_loss_reduction(self) -> MaskedTokenLossReductionWithLossMask: "NevaConfig", "neva_data_step", "neva_forward_step", - "transformer_engine_layer_spec", - "local_layer_spec", ] diff --git a/nemo/collections/vlm/neva/model/llava.py b/nemo/collections/vlm/neva/model/llava.py index 52b55b6f9c2d..7f5f46380b29 100644 --- a/nemo/collections/vlm/neva/model/llava.py +++ b/nemo/collections/vlm/neva/model/llava.py @@ -43,7 +43,7 @@ class LlavaConfig(NevaConfig): @dataclass -class Llava1_5Config7B(LlavaConfig): +class Llava15Config7B(LlavaConfig): from transformers import PretrainedConfig language_transformer_config: TransformerConfig = field(default_factory=lambda: Llama2Config7B()) @@ -56,7 +56,7 @@ class Llava1_5Config7B(LlavaConfig): @dataclass -class Llava1_5Config13B(LlavaConfig): +class Llava15Config13B(LlavaConfig): from transformers import PretrainedConfig language_transformer_config: TransformerConfig = field(default_factory=lambda: Llama2Config13B()) @@ -102,7 +102,7 @@ def apply(self, output_path: Path) -> Path: return output_path - def convert_state(self, source, target): + def convert_state(self, source, target, image_newline=False): mapping = { "language_model.model.embed_tokens.weight": "language_model.embedding.word_embeddings.weight", "language_model.model.layers.*.self_attn.o_proj.weight": "language_model.decoder.layers.*.self_attention.linear_proj.weight", @@ -111,7 +111,6 @@ def convert_state(self, source, target): "language_model.model.layers.*.post_attention_layernorm.weight": "language_model.decoder.layers.*.mlp.linear_fc1.layer_norm_weight", "language_model.model.norm.weight": "language_model.decoder.final_layernorm.weight", "language_model.lm_head.weight": "language_model.output_layer.weight", - "vision_tower.vision_model.**": "vision_model.vision_model.**", } if "vision_projection.encoder.linear_fc1.weight" in target.module.state_dict().keys(): mapping.update( @@ -134,7 +133,48 @@ def convert_state(self, source, target): else: raise KeyError("Unable to map vision projection keys.") - return io.apply_transforms(source, target, mapping=mapping, transforms=[_import_qkv, _import_linear_fc1]) + if image_newline: + mapping.update({"image_newline": "image_newline"}) + + if "vision_model.vision_model.embeddings.class_embedding" in target.module.state_dict().keys(): + mapping.update( + { + "vision_tower.vision_model.**": "vision_model.vision_model.**", + } + ) + elif "vision_model.class_token" in target.module.state_dict().keys(): + mapping.update( + { + "vision_tower.vision_model.embeddings.patch_embedding.weight": "vision_model.conv1.weight", + "vision_tower.vision_model.embeddings.position_embedding.weight": "vision_model.position_embeddings.weight", + "vision_tower.vision_model.encoder.layers.*.layer_norm1.weight": "vision_model.decoder.layers.*.self_attention.linear_qkv.layer_norm_weight", + "vision_tower.vision_model.encoder.layers.*.layer_norm1.bias": "vision_model.decoder.layers.*.self_attention.linear_qkv.layer_norm_bias", + "vision_tower.vision_model.encoder.layers.*.layer_norm2.weight": "vision_model.decoder.layers.*.mlp.linear_fc1.layer_norm_weight", + "vision_tower.vision_model.encoder.layers.*.layer_norm2.bias": "vision_model.decoder.layers.*.mlp.linear_fc1.layer_norm_bias", + "vision_tower.vision_model.encoder.layers.*.self_attn.out_proj.weight": "vision_model.decoder.layers.*.self_attention.linear_proj.weight", + "vision_tower.vision_model.encoder.layers.*.self_attn.out_proj.bias": "vision_model.decoder.layers.*.self_attention.linear_proj.bias", + "vision_tower.vision_model.encoder.layers.*.mlp.fc1.weight": "vision_model.decoder.layers.*.mlp.linear_fc1.weight", + "vision_tower.vision_model.encoder.layers.*.mlp.fc1.bias": "vision_model.decoder.layers.*.mlp.linear_fc1.bias", + "vision_tower.vision_model.encoder.layers.*.mlp.fc2.weight": "vision_model.decoder.layers.*.mlp.linear_fc2.weight", + "vision_tower.vision_model.encoder.layers.*.mlp.fc2.bias": "vision_model.decoder.layers.*.mlp.linear_fc2.bias", + "vision_tower.vision_model.pre_layrnorm.weight": "vision_model.ln_pre.weight", + "vision_tower.vision_model.pre_layrnorm.bias": "vision_model.ln_pre.bias", + } + ) + else: + raise KeyError("Unable to map vision encoder keys.") + return io.apply_transforms( + source, + target, + mapping=mapping, + transforms=[ + _import_language_qkv, + _import_vision_qkv, + _import_vision_qkv_bias, + _import_cls_token, + _import_linear_fc1, + ], + ) @property def tokenizer(self) -> "AutoTokenizer": @@ -183,80 +223,7 @@ def make_vocab_size_divisible_by(vocab_size): return output -@io.model_exporter(LlavaModel, "hf") -class HFLlavaExporter(io.ModelConnector[LlavaModel, "LlavaForConditionalGeneration"]): - def init(self) -> "LlavaForConditionalGeneration": - raise NotImplementedError("Neva Exporter hasn't been verified!") - - from transformers import AutoModelForCausalLM - - return AutoModelForCausalLM.from_config(self.config) - - def apply(self, output_path: Path) -> Path: - target = self.init() - source, _ = self.nemo_load(str(self)) - - target = self.convert_state(source, target) - - target = target.cpu() - target.save_pretrained(output_path) - self.tokenizer.save_pretrained(output_path) - - return output_path - - def convert_state(self, source, target): - mapping = { - "language_model.embedding.word_embeddings.weight": "language_model.model.embed_tokens.weight", - "language_model.decoder.layers.*.self_attention.linear_proj.weight": "language_model.model.layers.*.self_attn.o_proj.weight", - "language_model.decoder.layers.*.mlp.linear_fc2.weight": "language_model.model.layers.*.mlp.down_proj.weight", - "language_model.decoder.layers.*.self_attention.linear_qkv.layer_norm_weight": "language_model.model.layers.*.input_layernorm.weight", - "language_model.decoder.layers.*.mlp.linear_fc1.layer_norm_weight": "language_model.model.layers.*.post_attention_layernorm.weight", - "language_model.decoder.final_layernorm.weight": "language_model.model.norm.weight", - "language_model.output_layer.weight": "language_model.lm_head.weight", - } - - return io.apply_transforms(source, target, mapping=mapping, transforms=[_export_qkv, _export_linear_fc1]) - - @property - def tokenizer(self): - return io.load_context(str(self)).model.tokenizer.tokenizer - - @property - def config(self) -> "HFLlavaConfig": - source: LlavaConfig = io.load_context(str(self)).model.config - - from transformers import LlavaConfig as HFLlavaConfig - - return HFLlavaConfig( - num_hidden_layers=source.num_layers, - hidden_size=source.hidden_size, - intermediate_size=source.ffn_hidden_size, - num_attention_heads=source.num_attention_heads, - max_position_embeddings=source.seq_length, - initializer_range=source.init_method_std, - rms_norm_eps=source.layernorm_epsilon, - num_key_value_heads=source.num_query_groups, - rope_theta=source.rotary_base, - vocab_size=self.tokenizer.vocab_size, - ) - - -@io.state_transform( - source_key=( - "language_model.model.layers.*.self_attn.q_proj.weight", - "language_model.model.layers.*.self_attn.k_proj.weight", - "language_model.model.layers.*.self_attn.v_proj.weight", - ), - target_key="language_model.decoder.layers.*.self_attention.linear_qkv.weight", -) -def _import_qkv(ctx: io.TransformCTX, q, k, v): - megatron_config = ctx.target.config.language_transformer_config - head_num = megatron_config.num_attention_heads - num_query_groups = megatron_config.num_query_groups - heads_per_group = head_num // num_query_groups - hidden_size = megatron_config.hidden_size - head_size = megatron_config.kv_channels - +def import_qkv(q, k, v, head_num, num_query_groups, heads_per_group, hidden_size, head_size): old_tensor_shape = q.size() new_q_tensor_shape = (head_num, head_size) + old_tensor_shape[1:] new_kv_tensor_shape = (num_query_groups, head_size) + old_tensor_shape[1:] @@ -282,59 +249,85 @@ def _import_qkv(ctx: io.TransformCTX, q, k, v): @io.state_transform( - source_key="language_model.decoder.layers.*.self_attention.linear_qkv.weight", - target_key=( + source_key=( "language_model.model.layers.*.self_attn.q_proj.weight", "language_model.model.layers.*.self_attn.k_proj.weight", "language_model.model.layers.*.self_attn.v_proj.weight", ), + target_key="language_model.decoder.layers.*.self_attention.linear_qkv.weight", ) -def _export_qkv(ctx: io.TransformCTX, linear_qkv): - megatron_config = ctx.source.config - - head_num = megatron_config.num_attention_heads - num_query_groups = megatron_config.num_query_groups - heads_per_group = head_num // num_query_groups - hidden_size = megatron_config.hidden_size - head_size = megatron_config.kv_channels - qkv_total_dim = head_num + 2 * num_query_groups - - linear_qkv = linear_qkv.reshape([qkv_total_dim, head_size, hidden_size]) - q_slice = torch.cat( - [ - torch.arange((heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group) - for i in range(num_query_groups) - ] +def _import_language_qkv(ctx: io.TransformCTX, q, k, v): + megatron_config = ctx.target.config.language_transformer_config + return import_qkv( + q, + k, + v, + head_num=megatron_config.num_attention_heads, + num_query_groups=megatron_config.num_query_groups, + heads_per_group=megatron_config.num_attention_heads // megatron_config.num_query_groups, + hidden_size=megatron_config.hidden_size, + head_size=megatron_config.kv_channels, ) - k_slice = torch.arange(heads_per_group, qkv_total_dim, (heads_per_group + 2)) - v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, (heads_per_group + 2)) - q_proj = linear_qkv[q_slice].reshape(-1, hidden_size).cpu() - k_proj = linear_qkv[k_slice].reshape(-1, hidden_size).cpu() - v_proj = linear_qkv[v_slice].reshape(-1, hidden_size).cpu() - return q_proj, k_proj, v_proj +@io.state_transform( + source_key=( + "vision_tower.vision_model.encoder.layers.*.self_attn.q_proj.weight", + "vision_tower.vision_model.encoder.layers.*.self_attn.k_proj.weight", + "vision_tower.vision_model.encoder.layers.*.self_attn.v_proj.weight", + ), + target_key="vision_model.decoder.layers.*.self_attention.linear_qkv.weight", +) +def _import_vision_qkv(ctx: io.TransformCTX, q, k, v): + megatron_config = ctx.target.config.vision_transformer_config + return import_qkv( + q, + k, + v, + head_num=megatron_config.num_attention_heads, + num_query_groups=megatron_config.num_query_groups, + heads_per_group=megatron_config.num_attention_heads // megatron_config.num_query_groups, + hidden_size=megatron_config.hidden_size, + head_size=megatron_config.kv_channels, + ) @io.state_transform( source_key=( - "language_model.model.layers.*.mlp.gate_proj.weight", - "language_model.model.layers.*.mlp.up_proj.weight", + "vision_tower.vision_model.encoder.layers.*.self_attn.q_proj.bias", + "vision_tower.vision_model.encoder.layers.*.self_attn.k_proj.bias", + "vision_tower.vision_model.encoder.layers.*.self_attn.v_proj.bias", ), - target_key="language_model.decoder.layers.*.mlp.linear_fc1.weight", + target_key="vision_model.decoder.layers.*.self_attention.linear_qkv.bias", ) -def _import_linear_fc1(down, gate): - return torch.cat((down, gate), axis=0) +def _import_vision_qkv_bias(ctx: io.TransformCTX, q_bias, k_bias, v_bias): + megatron_config = ctx.target.config.vision_transformer_config + return import_qkv( + q_bias.unsqueeze(-1), + k_bias.unsqueeze(-1), + v_bias.unsqueeze(-1), + head_num=megatron_config.num_attention_heads, + num_query_groups=megatron_config.num_query_groups, + heads_per_group=megatron_config.num_attention_heads // megatron_config.num_query_groups, + hidden_size=1, + head_size=megatron_config.kv_channels, + ).squeeze(-1) + + +@io.state_transform( + source_key=("vision_tower.vision_model.embeddings.class_embedding",), + target_key="vision_model.class_token", +) +def _import_cls_token(ctx: io.TransformCTX, cls_token): + return cls_token.reshape(1, 1, -1) @io.state_transform( - source_key="language_model.decoder.layers.*.mlp.linear_fc1.weight", - target_key=( + source_key=( "language_model.model.layers.*.mlp.gate_proj.weight", "language_model.model.layers.*.mlp.up_proj.weight", ), + target_key="language_model.decoder.layers.*.mlp.linear_fc1.weight", ) -def _export_linear_fc1(linear_fc1): - gate_proj, up_proj = torch.chunk(linear_fc1, 2, dim=0) - - return gate_proj, up_proj +def _import_linear_fc1(down, gate): + return torch.cat((down, gate), axis=0) diff --git a/nemo/collections/vlm/neva/model/vit_config.py b/nemo/collections/vlm/neva/model/vit_config.py new file mode 100644 index 000000000000..5d60a84313ca --- /dev/null +++ b/nemo/collections/vlm/neva/model/vit_config.py @@ -0,0 +1,79 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass + +from nemo.collections.llm.fn.activation import openai_gelu, quick_gelu + +from nemo.collections.vlm.neva.model.base import CLIPViTConfig + + +@dataclass +class CLIPViTL_14_336_Config(CLIPViTConfig): + """Clip vit large patch14 config""" + + vision_model_type = "clip" + patch_dim = 14 + img_h = 336 + img_w = 336 + num_layers = 24 + num_attention_heads = 16 + add_bias_linear = True + add_qkv_bias = True + hidden_size = 1024 + hidden_dropout = 0.0 + attention_dropout = 0.0 + ffn_hidden_size = 4096 + gated_linear_unit = False + activation_func = quick_gelu + kv_channels = 64 + num_query_groups = 16 + layernorm_zero_centered_gamma = False + apply_query_key_layer_scaling = False + bias_activation_fusion = False + bias_dropout_fusion = False + attention_softmax_in_fp32 = True + normalization = 'LayerNorm' + apply_rope_fusion = False + + +@dataclass +class SigLIPViT400M_14_384_Config(CLIPViTConfig): + """Siglip so400m patch14 384 config""" + + vision_model_type = "siglip" + patch_dim = 14 + img_h = 384 + img_w = 384 + num_layers = 27 + num_attention_heads = 16 + add_bias_linear = True + add_qkv_bias = True + hidden_size = 1152 + hidden_dropout = 0.0 + attention_dropout = 0.0 + ffn_hidden_size = 4304 + gated_linear_unit = False + activation_func = openai_gelu + kv_channels = 72 + num_query_groups = 16 + layernorm_zero_centered_gamma = False + apply_query_key_layer_scaling = False + bias_activation_fusion = False + bias_dropout_fusion = False + attention_softmax_in_fp32 = True + normalization = 'LayerNorm' + apply_rope_fusion = False + qk_layernorm = False + layernorm_epsilon = 1e-6 diff --git a/nemo/collections/vlm/peft/lora.py b/nemo/collections/vlm/peft/lora.py index 1e394daa8ead..7a80b7e06883 100644 --- a/nemo/collections/vlm/peft/lora.py +++ b/nemo/collections/vlm/peft/lora.py @@ -48,7 +48,7 @@ class LoRA(LLMLoRA): """ freeze_language_model: bool = True - freeze_vision_model: bool = False + freeze_vision_model: bool = True def freeze_model(self, model: nn.Module) -> None: modules = [] diff --git a/nemo/collections/vlm/recipes/__init__.py b/nemo/collections/vlm/recipes/__init__.py index 2b71ecc50f8f..e3225dec8c4f 100644 --- a/nemo/collections/vlm/recipes/__init__.py +++ b/nemo/collections/vlm/recipes/__init__.py @@ -13,9 +13,12 @@ # limitations under the License. -from nemo.collections.vlm.recipes import mllama_11b, mllama_90b +from nemo.collections.vlm.recipes import llava15_7b, llava15_13b, llava_next_7b, mllama_11b, mllama_90b __all__ = [ + "llava15_7b", + "llava15_13b", "mllama_11b", "mllama_90b", + "llava_next_7b", ] diff --git a/nemo/collections/vlm/recipes/llava15_13b.py b/nemo/collections/vlm/recipes/llava15_13b.py new file mode 100644 index 000000000000..97b77b82d3de --- /dev/null +++ b/nemo/collections/vlm/recipes/llava15_13b.py @@ -0,0 +1,150 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Optional + +import nemo_run as run +import pytorch_lightning as pl +import torch + +from nemo import lightning as nl +from nemo.collections import llm, vlm +from nemo.collections.llm.recipes.finetune_default import nemo_resume +from nemo.collections.llm.recipes.log.default import tensorboard_logger +from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing +from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed +from nemo.collections.vlm.neva.data.mock import MockDataModule + +NAME = "llava15_13b" + + +@run.cli.factory(name=NAME) +def model() -> run.Config[pl.LightningModule]: + """ + Factory function to create a Llava 1.5 13B model configuration. + + Returns: + run.Config[pl.LightningModule]: Configuration for the Llava 1.5 13B model model. + + Examples: + CLI usage: + $ nemo llm pretrain model=llava15_13b ... + + Python API usage: + >>> model_config = model() + >>> print(model_config) + """ + return run.Config(vlm.LlavaModel, config=run.Config(vlm.Llava15Config13B)) + + +@run.cli.factory(target=llm.finetune, name=NAME) +def finetune_recipe( + dir: Optional[str] = None, + name: str = "default", + num_nodes: int = 1, + num_gpus_per_node: int = 8, + peft_scheme: Optional[str] = 'lora', +) -> run.Partial: + """ + Create a fine-tuning recipe for Llava1.5 13B model. + + This function sets up a complete configuration for fine-tuning, including + model, trainer, data, logging, optimization, and resumption settings. + The recipe uses LoRA (Low-Rank Adaptation) for efficient fine-tuning, unless peft_scheme is set to None. + + Args: + dir (Optional[str]): Directory for saving logs and checkpoints. + name (str): Name of the fine-tuning run. + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + + Returns: + run.Partial: Partial configuration for fine-tuning. + + Examples: + CLI usage: + $ nemo llm finetune --factory llava15_13b + + Python API usage: + >>> recipe = finetune_recipe(name="llava15_13b_finetune", num_nodes=1) + >>> print(recipe) + + Note: + This recipe uses the SQuAD dataset for fine-tuning. For more information + on fine-tuning LLMs with NeMo, see the fine-tuning guide in the + `examples/llm/finetune/` directory. + """ + + strategy = run.Config( + nl.MegatronStrategy, + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + encoder_pipeline_model_parallel_size=0, + pipeline_dtype=torch.bfloat16, + ) + + trainer = run.Config( + nl.Trainer, + accelerator="gpu", + accumulate_grad_batches=1, + devices=num_gpus_per_node, + limit_val_batches=10, + log_every_n_steps=1, + max_steps=5190, + num_nodes=num_nodes, + plugins=bf16_mixed(), + strategy=strategy, + val_check_interval=1000, + ) + + recipe = run.Partial( + llm.finetune, + model=model(), + trainer=trainer, + data=run.Config( + MockDataModule, + seq_length=4096, + global_batch_size=128, + micro_batch_size=1, + tokenizer=None, + image_processor=None, + num_workers=4, + ), + log=llm.default_log(dir=dir, name=name, tensorboard_logger=tensorboard_logger(name=name)), + optim=distributed_fused_adam_with_cosine_annealing(max_lr=2.0e-05, min_lr=2.0e-07, warmup_steps=150), + resume=nemo_resume("llava-hf/llava-1.5-13b-hf"), + ) + + if peft_scheme is None or peft_scheme.lower() == 'none': + recipe.trainer.strategy.tensor_model_parallel_size = 2 + recipe.optim.config.lr = 2e-05 + elif peft_scheme.lower() == 'lora': + recipe.peft = run.Config( + vlm.LoRA, + freeze_vision_model=False, + target_modules=[ + "*.language_model.*.linear_qkv", + "*.language_model.*.linear_q", + "*.language_model.*.linear_kv", + "*.language_model.*.linear_proj", + "*.language_model.*.linear_fc1", + "*.language_model.*.linear_fc2", + ], + ) + recipe.optim.config.lr = 1e-4 + else: + raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") + + return recipe diff --git a/nemo/collections/vlm/recipes/llava15_7b.py b/nemo/collections/vlm/recipes/llava15_7b.py new file mode 100644 index 000000000000..04e6bd36f4d4 --- /dev/null +++ b/nemo/collections/vlm/recipes/llava15_7b.py @@ -0,0 +1,152 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Optional + +import nemo_run as run +import pytorch_lightning as pl +import torch + +from nemo import lightning as nl +from nemo.collections import llm, vlm +from nemo.collections.llm.recipes.finetune_default import nemo_resume +from nemo.collections.llm.recipes.log.default import tensorboard_logger +from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing +from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed +from nemo.collections.vlm.neva.data.mock import MockDataModule +from nemo.utils.exp_manager import TimingCallback + +NAME = "llava15_7b" + + +@run.cli.factory(name=NAME) +def model() -> run.Config[pl.LightningModule]: + """ + Factory function to create a Llava 1.5 7B model configuration. + + Returns: + run.Config[pl.LightningModule]: Configuration for the Llava 1.5 7B model model. + + Examples: + CLI usage: + $ nemo llm pretrain model=llava15_7b ... + + Python API usage: + >>> model_config = model() + >>> print(model_config) + """ + return run.Config(vlm.LlavaModel, config=run.Config(vlm.Llava15Config7B)) + + +@run.cli.factory(target=llm.finetune, name=NAME) +def finetune_recipe( + dir: Optional[str] = None, + name: str = "default", + num_nodes: int = 1, + num_gpus_per_node: int = 8, + peft_scheme: Optional[str] = 'none', +) -> run.Partial: + """ + Create a fine-tuning recipe for Llava1.5 7B model. + + This function sets up a complete configuration for fine-tuning, including + model, trainer, data, logging, optimization, and resumption settings. + The recipe uses LoRA (Low-Rank Adaptation) for efficient fine-tuning, unless peft_scheme is set to None. + + Args: + dir (Optional[str]): Directory for saving logs and checkpoints. + name (str): Name of the fine-tuning run. + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + + Returns: + run.Partial: Partial configuration for fine-tuning. + + Examples: + CLI usage: + $ nemo llm finetune --factory llava15_7b + + Python API usage: + >>> recipe = finetune_recipe(name="llava15_7b_finetune", num_nodes=1) + >>> print(recipe) + + Note: + This recipe uses the SQuAD dataset for fine-tuning. For more information + on fine-tuning LLMs with NeMo, see the fine-tuning guide in the + `examples/llm/finetune/` directory. + """ + + strategy = run.Config( + nl.MegatronStrategy, + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + encoder_pipeline_model_parallel_size=0, + pipeline_dtype=torch.bfloat16, + ) + + trainer = run.Config( + nl.Trainer, + accelerator="gpu", + accumulate_grad_batches=1, + devices=num_gpus_per_node, + limit_val_batches=10, + log_every_n_steps=1, + max_steps=5190, + num_nodes=num_nodes, + plugins=bf16_mixed(), + strategy=strategy, + val_check_interval=1000, + callbacks=[run.Config(TimingCallback)], + ) + + recipe = run.Partial( + llm.finetune, + model=model(), + trainer=trainer, + data=run.Config( + MockDataModule, + seq_length=4096, + global_batch_size=128, + micro_batch_size=2, + tokenizer=None, + image_processor=None, + num_workers=4, + ), + log=llm.default_log(dir=dir, name=name, tensorboard_logger=tensorboard_logger(name=name)), + optim=distributed_fused_adam_with_cosine_annealing(max_lr=2.0e-05, min_lr=2.0e-07, warmup_steps=150), + resume=nemo_resume("llava-hf/llava-1.5-7b-hf"), + ) + + if peft_scheme is None or peft_scheme.lower() == 'none': + recipe.trainer.strategy.tensor_model_parallel_size = 2 + recipe.optim.config.lr = 2e-05 + elif peft_scheme.lower() == 'lora': + recipe.peft = run.Config( + vlm.LoRA, + freeze_vision_model=False, + target_modules=[ + "*.language_model.*.linear_qkv", + "*.language_model.*.linear_q", + "*.language_model.*.linear_kv", + "*.language_model.*.linear_proj", + "*.language_model.*.linear_fc1", + "*.language_model.*.linear_fc2", + ], + ) + recipe.optim.config.lr = 1e-4 + else: + raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") + + return recipe diff --git a/nemo/collections/vlm/recipes/llava_next_7b.py b/nemo/collections/vlm/recipes/llava_next_7b.py new file mode 100644 index 000000000000..c483ff788f26 --- /dev/null +++ b/nemo/collections/vlm/recipes/llava_next_7b.py @@ -0,0 +1,242 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Optional + +import nemo_run as run +import pytorch_lightning as pl +import torch + +from nemo import lightning as nl +from nemo.collections import llm, vlm +from nemo.collections.llm.recipes.finetune_default import nemo_resume +from nemo.collections.llm.recipes.log.default import tensorboard_logger +from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing +from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed +from nemo.collections.vlm import LlavaNextMockDataModule +from nemo.utils.exp_manager import TimingCallback + +NAME = "llava_next_7b" + + +@run.cli.factory(name=NAME) +def model(config=run.Config(vlm.LlavaNextConfig7B)) -> run.Config[pl.LightningModule]: + """ + Factory function to create a LlavaNext 7B model configuration. + + Returns: + run.Config[pl.LightningModule]: Configuration for the Llava Next 7B model. + + Examples: + CLI usage: + $ nemo llm pretrain model=llava_next_7b ... + + Python API usage: + >>> model_config = model() + >>> print(model_config) + """ + return run.Config(vlm.LlavaNextModel, config=config) + + +@run.cli.factory(target=llm.finetune, name=NAME) +def finetune_recipe( + dir: Optional[str] = None, + name: str = "default", + num_nodes: int = 1, + num_gpus_per_node: int = 8, + peft_scheme: Optional[str] = 'none', +) -> run.Partial: + """ + Create a fine-tuning recipe for LlavaNext 7B model. + + This function sets up a complete configuration for fine-tuning, including + model, trainer, data, logging, optimization, and resumption settings. + The recipe uses LoRA (Low-Rank Adaptation) for efficient fine-tuning, unless peft_scheme is set to None. + + Args: + dir (Optional[str]): Directory for saving logs and checkpoints. + name (str): Name of the fine-tuning run. + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + + Returns: + run.Partial: Partial configuration for fine-tuning. + + Examples: + CLI usage: + $ nemo llm finetune --factory llava_next_7b + + Python API usage: + >>> recipe = finetune_recipe(name="llava_next_7b_finetune", num_nodes=1) + >>> print(recipe) + """ + + strategy = run.Config( + nl.MegatronStrategy, + tensor_model_parallel_size=4, + pipeline_model_parallel_size=1, + encoder_pipeline_model_parallel_size=0, + pipeline_dtype=torch.bfloat16, + ) + + trainer = run.Config( + nl.Trainer, + accelerator="gpu", + accumulate_grad_batches=1, + devices=num_gpus_per_node, + limit_val_batches=10, + log_every_n_steps=1, + max_steps=5190, + num_nodes=num_nodes, + plugins=bf16_mixed(), + strategy=strategy, + val_check_interval=1000, + callbacks=[run.Config(TimingCallback)], + ) + from transformers import AutoProcessor + + from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer + + recipe = run.Partial( + llm.finetune, + model=model( + config=run.Config( + vlm.LlavaNextConfig7B, + freeze_language_model=False, + freeze_vision_model=True, + freeze_vision_projection=False, + ) + ), + trainer=trainer, + data=run.Config( + LlavaNextMockDataModule, + seq_length=4096, + global_batch_size=8, + micro_batch_size=2, + tokenizer=None, + image_processor=None, + num_workers=4, + ), + log=llm.default_log(dir=dir, name=name, tensorboard_logger=tensorboard_logger(name=name)), + optim=distributed_fused_adam_with_cosine_annealing(max_lr=2.0e-05, min_lr=2.0e-07, warmup_steps=150), + resume=nemo_resume("llava-hf/llava-v1.6-vicuna-7b-hf"), + ) + + if peft_scheme is None or peft_scheme.lower() == 'none': + recipe.trainer.strategy.tensor_model_parallel_size = 2 + recipe.optim.config.lr = 2e-05 + elif peft_scheme.lower() == 'lora': + recipe.peft = run.Config( + vlm.LoRA, + freeze_vision_model=False, + target_modules=[ + "*.language_model.*.linear_qkv", + "*.language_model.*.linear_q", + "*.language_model.*.linear_kv", + "*.language_model.*.linear_proj", + "*.language_model.*.linear_fc1", + "*.language_model.*.linear_fc2", + ], + ) + recipe.optim.config.lr = 1e-4 + else: + raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") + + return recipe + + +@run.cli.factory(target=llm.pretrain, name=NAME) +def pretrain_recipe( + dir: Optional[str] = None, + name: str = "default", + num_nodes: int = 1, + num_gpus_per_node: int = 8, + peft_scheme: Optional[str] = 'none', +) -> run.Partial: + """ + Create a Pre-training recipe for Llava1.6 7B model. + + This function sets up a complete configuration for pre-training, including + model, trainer, data, logging, optimization, and resumption settings. + + Args: + dir (Optional[str]): Directory for saving logs and checkpoints. + name (str): Name of the fine-tuning run. + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + + Returns: + run.Partial: Partial configuration for fine-tuning. + + Examples: + CLI usage: + $ nemo llm pretrain --factory llava_next_7b + + Python API usage: + >>> recipe = finetune_recipe(name="llava_next_7b_pretrain", num_nodes=1) + >>> print(recipe) + """ + + strategy = run.Config( + nl.MegatronStrategy, + tensor_model_parallel_size=4, + pipeline_model_parallel_size=1, + encoder_pipeline_model_parallel_size=0, + pipeline_dtype=torch.bfloat16, + ) + + trainer = run.Config( + nl.Trainer, + accelerator="gpu", + accumulate_grad_batches=1, + devices=num_gpus_per_node, + limit_val_batches=10, + log_every_n_steps=1, + max_steps=5190, + num_nodes=num_nodes, + plugins=bf16_mixed(), + strategy=strategy, + val_check_interval=1000, + callbacks=[run.Config(TimingCallback)], + ) + from transformers import AutoProcessor + + from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer + + recipe = run.Partial( + llm.pretrain, + model=model( + config=run.Config( + vlm.LlavaNextConfig7B, + freeze_language_model=True, + freeze_vision_model=True, + freeze_vision_projection=False, + ) + ), + trainer=trainer, + data=run.Config( + LlavaNextMockDataModule, + seq_length=4096, + global_batch_size=8, + micro_batch_size=2, + tokenizer=None, + image_processor=None, + num_workers=4, + ), + log=llm.default_log(dir=dir, name=name, tensorboard_logger=tensorboard_logger(name=name)), + optim=distributed_fused_adam_with_cosine_annealing(max_lr=0.001, min_lr=2.0e-05, warmup_steps=150), + ) + + return recipe diff --git a/nemo/collections/vlm/recipes/mllama_11b.py b/nemo/collections/vlm/recipes/mllama_11b.py index 697be9990faf..4b08606900e3 100644 --- a/nemo/collections/vlm/recipes/mllama_11b.py +++ b/nemo/collections/vlm/recipes/mllama_11b.py @@ -15,8 +15,8 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch from nemo import lightning as nl @@ -26,6 +26,7 @@ from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed from nemo.collections.vlm.mllama.data.mock import MockDataModule +from nemo.utils.exp_manager import TimingCallback NAME = "mllama_11b" @@ -46,7 +47,7 @@ def model() -> run.Config[pl.LightningModule]: >>> model_config = model() >>> print(model_config) """ - return run.Config(vlm.MLlamaModel, config=run.Config(vlm.MLlamaConfig11B)) + return run.Config(vlm.MLlamaModel, config=run.Config(vlm.MLlamaConfig11BInstruct)) @run.cli.factory(target=llm.finetune, name=NAME) @@ -107,6 +108,7 @@ def finetune_recipe( plugins=bf16_mixed(), strategy=strategy, val_check_interval=100, + callbacks=[run.Config(TimingCallback)], ) recipe = run.Partial( @@ -115,34 +117,37 @@ def finetune_recipe( trainer=trainer, data=run.Config( MockDataModule, - seq_length=4100, # encoder (vision) seq length - decoder_seq_length=512, # decoder (llm) seq length - global_batch_size=16, - micro_batch_size=2, + seq_length=6404, # encoder (vision) seq length + decoder_seq_length=2048, # decoder (llm) seq length + global_batch_size=2, + micro_batch_size=1, vocab_size=128256, - crop_size=(448, 448), + crop_size=(560, 560), num_workers=0, ), log=llm.default_log(dir=dir, name=name, tensorboard_logger=tensorboard_logger(name=name)), optim=distributed_fused_adam_with_cosine_annealing(max_lr=1e-4, min_lr=2.0e-07, warmup_steps=150), - resume=nemo_resume("meta-llama/Llama-3.2-11B-Vision"), + resume=nemo_resume("meta-llama/Llama-3.2-11B-Vision-Instruct"), ) if peft_scheme is None or peft_scheme.lower() == 'none': recipe.trainer.strategy.tensor_model_parallel_size = 2 recipe.optim.config.lr = 2e-05 elif peft_scheme.lower() == 'lora': + # pylint: disable=line-too-long + """Adapted from https://github.com/meta-llama/llama-recipes/blob/main/src/llama_recipes/configs/peft.py""" recipe.peft = run.Config( vlm.LoRA, - freeze_vision_model=False, + freeze_vision_model=True, target_modules=[ - "*.language_model.*.linear_qkv", - "*.language_model.*.linear_q", - "*.language_model.*.linear_kv", - "*.language_model.*.linear_proj", - "*.language_model.*.linear_fc1", - "*.language_model.*.linear_fc2", + "linear_qkv", + "linear_q", + "linear_kv", ], + dim=8, + alpha=32, + dropout=0.05, + dropout_position="pre", ) recipe.optim.config.lr = 1e-4 else: diff --git a/nemo/collections/vlm/recipes/mllama_90b.py b/nemo/collections/vlm/recipes/mllama_90b.py index 8822aa9b189f..12e0329fc6dd 100644 --- a/nemo/collections/vlm/recipes/mllama_90b.py +++ b/nemo/collections/vlm/recipes/mllama_90b.py @@ -15,8 +15,8 @@ from typing import Optional +import lightning.pytorch as pl import nemo_run as run -import pytorch_lightning as pl import torch from nemo import lightning as nl @@ -26,6 +26,7 @@ from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed from nemo.collections.vlm.mllama.data.mock import MockDataModule +from nemo.utils.exp_manager import TimingCallback NAME = "mllama_90b" @@ -46,7 +47,7 @@ def model() -> run.Config[pl.LightningModule]: >>> model_config = model() >>> print(model_config) """ - return run.Config(vlm.MLlamaModel, config=run.Config(vlm.MLlamaConfig90B)) + return run.Config(vlm.MLlamaModel, config=run.Config(vlm.MLlamaConfig90BInstruct)) @run.cli.factory(target=llm.finetune, name=NAME) @@ -107,6 +108,7 @@ def finetune_recipe( plugins=bf16_mixed(), strategy=strategy, val_check_interval=100, + callbacks=[run.Config(TimingCallback)], ) recipe = run.Partial( @@ -116,7 +118,7 @@ def finetune_recipe( data=run.Config( MockDataModule, seq_length=6404, # encoder (vision) seq length - decoder_seq_length=512, # decoder (llm) seq length + decoder_seq_length=2048, # decoder (llm) seq length global_batch_size=16, micro_batch_size=2, vocab_size=128256, @@ -125,23 +127,26 @@ def finetune_recipe( ), log=llm.default_log(dir=dir, name=name, tensorboard_logger=tensorboard_logger(name=name)), optim=distributed_fused_adam_with_cosine_annealing(max_lr=1e-4, min_lr=2.0e-07, warmup_steps=150), - resume=nemo_resume("meta-llama/Llama-3.2-90B-Vision"), + resume=nemo_resume("meta-llama/Llama-3.2-90B-Vision-Instruct"), ) if peft_scheme is None or peft_scheme.lower() == 'none': raise ValueError("Full finetuning recipe for Llama-3.2-90B model will be supported soon.") elif peft_scheme.lower() == 'lora': + # pylint: disable=line-too-long + """Adapted from https://github.com/meta-llama/llama-recipes/blob/main/src/llama_recipes/configs/peft.py""" recipe.peft = run.Config( vlm.LoRA, - freeze_vision_model=False, + freeze_vision_model=True, target_modules=[ - "*.language_model.*.linear_qkv", - "*.language_model.*.linear_q", - "*.language_model.*.linear_kv", - "*.language_model.*.linear_proj", - "*.language_model.*.linear_fc1", - "*.language_model.*.linear_fc2", + "linear_qkv", + "linear_q", + "linear_kv", ], + dim=8, + alpha=32, + dropout=0.05, + dropout_position="pre", ) recipe.optim.config.lr = 1e-4 else: diff --git a/nemo/core/classes/__init__.py b/nemo/core/classes/__init__.py index 3a6db2602648..e773972c6d7b 100644 --- a/nemo/core/classes/__init__.py +++ b/nemo/core/classes/__init__.py @@ -14,8 +14,8 @@ import hydra +import lightning.pytorch import omegaconf -import pytorch_lightning from nemo.core.classes.common import ( FileIO, diff --git a/nemo/core/classes/exportable.py b/nemo/core/classes/exportable.py index aab09d42d907..ba284e7c28cd 100644 --- a/nemo/core/classes/exportable.py +++ b/nemo/core/classes/exportable.py @@ -15,7 +15,7 @@ from typing import Dict, List, Optional, Union import torch -from pytorch_lightning.core.module import _jit_is_scripting +from lightning.pytorch.core.module import _jit_is_scripting from nemo.core.classes import typecheck from nemo.core.neural_types import NeuralType diff --git a/nemo/core/classes/modelPT.py b/nemo/core/classes/modelPT.py index a15f769e9d88..88ff47caf8c2 100644 --- a/nemo/core/classes/modelPT.py +++ b/nemo/core/classes/modelPT.py @@ -35,9 +35,9 @@ HAVE_MEGATRON_CORE = False +from lightning.pytorch import LightningModule, Trainer +from lightning.pytorch.utilities import model_summary, rank_zero_only from omegaconf import DictConfig, OmegaConf, open_dict -from pytorch_lightning import LightningModule, Trainer -from pytorch_lightning.utilities import model_summary, rank_zero_only from nemo import package_info from nemo.core import optim @@ -79,7 +79,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): """ if trainer is not None and not isinstance(trainer, Trainer): raise ValueError( - f"trainer constructor argument must be either None or pytorch_lightning.Trainer. But got {type(trainer)} instead." + f"trainer constructor argument must be either None or lightning.pytorch.Trainer. But got {type(trainer)} instead." ) super().__init__() @@ -211,6 +211,12 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): self._memory_profile_started = False self._memory_profile_complete = False + # Setup chakra profiling if it has been enabled in the model config + self._setup_chakra_profiling() + + # A flag for the profile generation + self._chakra_profile_in_progress = False + def __init_subclass__(cls) -> None: cls._save_restore_connector = SaveRestoreConnector() @@ -1744,6 +1750,78 @@ def update_save_restore_connector(cls, save_restore_connector): else: setattr(cls, '_save_restore_connector', save_restore_connector) + def _setup_chakra_profiling(self): + """Enables chakra profiling + To use, add the following options to the model config: + ## Chakra profiling options + chakra_profile: + enabled: False + start_step: 2 # Global batch to start profiling + end_step: 2 # Global batch to end profiling + warmup_steps: 0 # Global batch to start profiling + active_steps: 1 # Global batch to start profiling + trace_dir: None # Path to store the profile output file + """ + if self.cfg.get('chakra_profile', None) is not None: + if self.cfg.chakra_profile.get('enabled', False): + + from torch.profiler import ExecutionTraceObserver + from nemo.utils.env_var_parsing import get_envint + + self._chakra_profile_enabled = True + self._chakra_profile_start_step = self.cfg.chakra_profile.get('start_step', 0) + self._chakra_profile_end_step = self.cfg.chakra_profile.get('end_step', 0) + trace_dir = self.cfg.chakra_profile.get('trace_dir', None) + + if trace_dir is None or not os.path.isdir(trace_dir): + raise ValueError(f'chakra profile output path ({trace_dir}) is not set or does not exist.') + + trace_dir = Path(trace_dir) + warmup_steps = self.cfg.chakra_profile.get('warmup_steps', 0) + active_steps = self.cfg.chakra_profile.get('active_steps', 1) + + job_id = get_envint("SLURM_JOB_ID", 0) + + self._chakra_trace_dir = trace_dir / f'{job_id}_chakra' + self._kineto_trace_dir = trace_dir / f'{job_id}_kineto' + + self._chakra_trace_dir.mkdir(parents=True, exist_ok=True) + self._kineto_trace_dir.mkdir(parents=True, exist_ok=True) + + if isinstance(self._chakra_profile_start_step, int): + logging.info(f'chakra profiling setup with start_step: {self._chakra_profile_start_step}') + else: + raise ValueError( + f'chakra start_step must be of type int. Found: {type(self._chakra_profile_start_step)}' + ) + + if isinstance(self._chakra_profile_end_step, int): + logging.info(f'chakra profiling setup with end_step: {self._chakra_profile_end_step}') + else: + raise ValueError( + f'chakra end_step must be of type int. Found: {type(self._chakra_profile_end_step)}' + ) + + if self._chakra_profile_end_step >= self._chakra_profile_start_step: + pass + else: + raise ValueError(f'chakra end_step must be greater than or equal to chakra start_step') + + if self.cfg.nsys_profile.get('enabled', False): + raise Exception( + f"Profiler conflict: Chakra profiling and Nsys profiling cannot be enabled at the same time." + ) + + self._et = ExecutionTraceObserver() + self._prof = torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + schedule=torch.profiler.schedule(wait=0, warmup=warmup_steps, active=active_steps), + execution_trace_observer=self._et, + ) + def _setup_profiling(self): """Enables nsys profiling To use, add the following optoins to the model config: @@ -1848,11 +1926,22 @@ def on_train_start(self): def on_train_batch_start(self, batch: Any, batch_idx: int, unused: int = 0) -> Optional[int]: """PyTorch Lightning hook: https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#on-train-batch-start - We use it here to enable nsys profiling and dynamic freezing. + We use it here to enable profiling and dynamic freezing. """ - - # nsys profiling if self.device.type == 'cuda': + if hasattr(self, '_chakra_profile_enabled'): + if self._chakra_profile_enabled and not self._chakra_profile_in_progress: + if ( + self.trainer.global_step >= self._chakra_profile_start_step + and self.trainer.global_step <= self._chakra_profile_end_step + ): + logging.info( + f"====== Start chakra profiling from global_step {self.trainer.global_step} ======" + ) + self._et.register_callback(str(self._chakra_trace_dir / f'rank-{get_rank()}.json')) + self._prof.start() + self._chakra_profile_in_progress = True + if hasattr(self, '_nsys_profile_enabled'): if self._nsys_profile_enabled and not self._nsys_profile_started: if batch_idx >= self._nsys_profile_start_step and get_rank() in self._nsys_profile_ranks: @@ -1898,6 +1987,18 @@ def on_train_batch_end(self, outputs, batch: Any, batch_idx: int, unused: int = """ if self.device.type == 'cuda': + if hasattr(self, '_chakra_profile_enabled'): + # self.trainer.global_step is increaeasd before on_train_batch_end + if self._chakra_profile_enabled and self._chakra_profile_in_progress: + if self.trainer.global_step - 1 >= self._chakra_profile_end_step: + logging.info(f"====== End chakra profiling at global_step {self.trainer.global_step} ======") + self._prof.stop() + self._prof.export_chrome_trace(str(self._kineto_trace_dir / f'rank-{get_rank()}.json')) + self._et.unregister_callback() + self._chakra_profile_in_progress = False + elif self.trainer.global_step - 1 >= self._chakra_profile_start_step: + self._prof.step() + if hasattr(self, '_nsys_profile_enabled'): if self._nsys_profile_enabled and not self._nsys_profile_complete: if batch_idx >= self._nsys_profile_end_step and get_rank() in self._nsys_profile_ranks: diff --git a/nemo/core/connectors/save_restore_connector.py b/nemo/core/connectors/save_restore_connector.py index cd9971a9c383..2c4c826d1daf 100644 --- a/nemo/core/connectors/save_restore_connector.py +++ b/nemo/core/connectors/save_restore_connector.py @@ -23,9 +23,9 @@ from typing import Callable, Generator, Optional, Set, Union import torch +from lightning.pytorch.trainer.trainer import Trainer from omegaconf import DictConfig, OmegaConf from omegaconf.omegaconf import open_dict -from pytorch_lightning.trainer.trainer import Trainer from nemo.core import classes as nemo_classes # to avoid circular import do not import ModelPT directly from nemo.utils import logging, model_utils diff --git a/nemo/core/utils/k2_guard.py b/nemo/core/utils/k2_guard.py index a9f64ce39c6b..b0e86d319ec0 100644 --- a/nemo/core/utils/k2_guard.py +++ b/nemo/core/utils/k2_guard.py @@ -21,8 +21,9 @@ import textwrap +from lightning.pytorch.utilities.imports import package_available from packaging.version import Version -from pytorch_lightning.utilities.imports import package_available + from nemo.core.utils.k2_utils import K2_INSTALLATION_MESSAGE __K2_MINIMUM_MAJOR_VERSION = 1 diff --git a/nemo/deploy/deploy_base.py b/nemo/deploy/deploy_base.py index 63746199bac6..41e0e7ddbdc9 100644 --- a/nemo/deploy/deploy_base.py +++ b/nemo/deploy/deploy_base.py @@ -18,7 +18,7 @@ use_pytorch_lightning = True try: - from pytorch_lightning import Trainer + from lightning.pytorch import Trainer except Exception: use_pytorch_lightning = False diff --git a/nemo/deploy/nlp/__init__.py b/nemo/deploy/nlp/__init__.py index 5ebbe6816664..633544e300ed 100644 --- a/nemo/deploy/nlp/__init__.py +++ b/nemo/deploy/nlp/__init__.py @@ -12,15 +12,5 @@ # See the License for the specific language governing permissions and # limitations under the License. - -use_query_llm = True -try: - from nemo.deploy.nlp.query_llm import NemoQueryLLM, NemoQueryLLMPyTorch -except Exception: - use_query_llm = False - -use_megatron_llm = True -try: - from nemo.deploy.nlp.megatronllm_deployable import MegatronLLMDeployable -except Exception: - use_megatron_llm = False +from nemo.deploy.nlp.megatronllm_deployable import MegatronLLMDeployable +from nemo.deploy.nlp.query_llm import NemoQueryLLM, NemoQueryLLMPyTorch diff --git a/nemo/deploy/nlp/megatronllm_deployable.py b/nemo/deploy/nlp/megatronllm_deployable.py index 64cf6114ceba..0ce5991cdc95 100644 --- a/nemo/deploy/nlp/megatronllm_deployable.py +++ b/nemo/deploy/nlp/megatronllm_deployable.py @@ -20,7 +20,7 @@ import numpy as np import torch import wrapt -from pytorch_lightning.trainer.trainer import Trainer +from lightning.pytorch.trainer.trainer import Trainer from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel from nemo.collections.nlp.modules.common.text_generation_utils import ( diff --git a/nemo/deploy/nlp/query_llm.py b/nemo/deploy/nlp/query_llm.py index 7e873db6b5b1..e1d21bb54b76 100644 --- a/nemo/deploy/nlp/query_llm.py +++ b/nemo/deploy/nlp/query_llm.py @@ -174,6 +174,7 @@ def query_llm( end_strings=None, init_timeout=60.0, openai_format_response: bool = False, + output_generation_logits: bool = False, ): """ Query the Triton server synchronously and return a list of responses. @@ -190,6 +191,8 @@ def query_llm( no_repeat_ngram_size (int): no repeat ngram size. task_id (str): downstream task id if virtual tokens are used. init_timeout (flat): timeout for the connection. + openai_format_response: return response similar to OpenAI API format + output_generation_logits: return generation logits from model on PyTriton """ prompts = str_list2numpy(prompts) @@ -248,6 +251,9 @@ def query_llm( if end_strings is not None: inputs["end_strings"] = str_list2numpy(end_strings) + if output_generation_logits is not None: + inputs["output_generation_logits"] = np.full(prompts.shape, output_generation_logits, dtype=np.bool_) + with ModelClient(self.url, self.model_name, init_timeout_s=init_timeout) as client: result_dict = client.infer_batch(**inputs) output_type = client.model_config.outputs[0].dtype @@ -269,6 +275,9 @@ def query_llm( "model": self.model_name, "choices": [{"text": str(sentences)}], } + # Convert gneration logits to a list to make it json serializable and add it to openai_response dict + if output_generation_logits: + openai_response["choices"][0]["generation_logits"] = result_dict["generation_logits"].tolist() return openai_response else: return sentences diff --git a/nemo/deploy/service/rest_model_api.py b/nemo/deploy/service/rest_model_api.py index fbc774883faa..64afea167295 100644 --- a/nemo/deploy/service/rest_model_api.py +++ b/nemo/deploy/service/rest_model_api.py @@ -8,8 +8,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - -import json import os from pathlib import Path import requests @@ -19,6 +17,7 @@ from pydantic_settings import BaseSettings from nemo.deploy.nlp import NemoQueryLLM +from nemo.utils import logging class TritonSettings(BaseSettings): @@ -29,14 +28,13 @@ class TritonSettings(BaseSettings): def __init__(self): super(TritonSettings, self).__init__() try: - with open(os.path.join(Path.cwd(), 'nemo/deploy/service/config.json')) as config: - config_json = json.load(config) - self._triton_service_port = config_json["triton_service_port"] - self._triton_service_ip = config_json["triton_service_ip"] - self._triton_request_timeout = config_json["triton_request_timeout"] - self._openai_format_response = config_json["openai_format_response"] + self._triton_service_port = int(os.environ.get('TRITON_PORT', 8080)) + self._triton_service_ip = os.environ.get('TRITON_HTTP_ADDRESS', '0.0.0.0') + self._triton_request_timeout = int(os.environ.get('TRITON_REQUEST_TIMEOUT', 60)) + self._openai_format_response = os.environ.get('OPENAI_FORMAT_RESPONSE', 'False').lower() == 'true' + self._output_generation_logits = os.environ.get('OUTPUT_GENERATION_LOGITS', 'False').lower() == 'true' except Exception as error: - print("An exception occurred:", error) + logging.error("An exception occurred trying to retrieve set args in TritonSettings class. Error:", error) return @property @@ -54,11 +52,17 @@ def triton_request_timeout(self): @property def openai_format_response(self): """ - Retuns the response from Triton server in OpenAI compatible formar if set to True, - default set in config.json is false. + Retuns the response from Triton server in OpenAI compatible format if set to True. """ return self._openai_format_response + @property + def output_generation_logits(self): + """ + Retuns the generation logits along with text in Triton server output if set to True. + """ + return self._output_generation_logits + app = FastAPI() triton_settings = TritonSettings() @@ -70,19 +74,27 @@ class CompletionRequest(BaseModel): max_tokens: int = 512 temperature: float = 1.0 top_p: float = 0.0 - n: int = 1 + top_k: int = 1 stream: bool = False stop: str | None = None frequency_penalty: float = 1.0 -@app.get("/triton_health") +@app.get("/v1/health") +def health_check(): + return {"status": "ok"} + + +@app.get("/v1/triton_health") async def check_triton_health(): """ This method exposes endpoint "/triton_health" which can be used to verify if Triton server is accessible while running the REST or FastAPI application. - Verify by running: curl http://service_http_address:service_port/triton_health and the returned status should inform if the server is accessible. + Verify by running: curl http://service_http_address:service_port/v1/triton_health and the returned status should inform if the server is accessible. """ - triton_url = f"triton_settings.triton_service_ip:str(triton_settings.triton_service_port)/v2/health/ready" + triton_url = ( + f"http://{triton_settings.triton_service_ip}:{str(triton_settings.triton_service_port)}/v2/health/ready" + ) + logging.info(f"Attempting to connect to Triton server at: {triton_url}") try: response = requests.get(triton_url, timeout=5) if response.status_code == 200: @@ -101,11 +113,13 @@ def completions_v1(request: CompletionRequest): output = nq.query_llm( prompts=[request.prompt], max_output_len=request.max_tokens, - top_k=request.n, + # when these below params are passed as None + top_k=request.top_k, top_p=request.top_p, temperature=request.temperature, init_timeout=triton_settings.triton_request_timeout, openai_format_response=triton_settings.openai_format_response, + output_generation_logits=triton_settings.output_generation_logits, ) if triton_settings.openai_format_response: return output @@ -114,5 +128,5 @@ def completions_v1(request: CompletionRequest): "output": output[0][0], } except Exception as error: - print("An exception occurred:", error) + logging.error("An exception occurred with the post request to /v1/completions/ endpoint:", error) return {"error": "An exception occurred"} diff --git a/nemo/export/__init__.py b/nemo/export/__init__.py index d9155f923f18..6b1f8c90aa8f 100644 --- a/nemo/export/__init__.py +++ b/nemo/export/__init__.py @@ -11,3 +11,5 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from nemo.export.tensorrt_lazy_compiler import trt_compile diff --git a/nemo/export/tensorrt_lazy_compiler.py b/nemo/export/tensorrt_lazy_compiler.py new file mode 100644 index 000000000000..ab40278efa94 --- /dev/null +++ b/nemo/export/tensorrt_lazy_compiler.py @@ -0,0 +1,714 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import inspect +import os +import tempfile +import threading +from collections import OrderedDict +from logging import getLogger +from pathlib import Path +from types import MethodType +from typing import Any, Dict, List, Sequence, Tuple, Union + +import torch + +from nemo.utils.export_utils import add_casts_around_norms, replace_for_export +from nemo.utils.import_utils import safe_import + +polygraphy, polygraphy_imported = safe_import("polygraphy") +if polygraphy_imported: + from polygraphy.backend.common import bytes_from_path + from polygraphy.backend.trt import ( + CreateConfig, + Profile, + engine_bytes_from_network, + engine_from_bytes, + network_from_onnx_path, + ) + +trt, trt_imported = safe_import("tensorrt") +torch_tensorrt, _ = safe_import("torch_tensorrt") +cudart, _ = safe_import("cuda.cudart") + +lock_sm = threading.Lock() + + +def trt_to_torch_dtype_dict(): + """ + Map of TRT dtype -> Torch dtype + """ + return { + trt.int32: torch.int32, + trt.float32: torch.float32, + trt.float16: torch.float16, + trt.bfloat16: torch.float16, + trt.int64: torch.int64, + trt.int8: torch.int8, + trt.bool: torch.bool, + } + + +def get_profile_shapes(input_shape: Sequence[int], dynamic_batchsize: Sequence[int] | None): + """ + Given a sample input shape, calculate min/opt/max shapes according to dynamic_batchsize. + """ + + def scale_batch_size(input_shape: Sequence[int], scale_num: int): + scale_shape = [*input_shape] + scale_shape[0] = scale_num + return scale_shape + + # Use the dynamic batchsize range to generate the min, opt and max model input shape + if dynamic_batchsize: + min_input_shape = scale_batch_size(input_shape, dynamic_batchsize[0]) + opt_input_shape = scale_batch_size(input_shape, dynamic_batchsize[1]) + max_input_shape = scale_batch_size(input_shape, dynamic_batchsize[2]) + else: + min_input_shape = opt_input_shape = max_input_shape = input_shape + return min_input_shape, opt_input_shape, max_input_shape + + +def get_dynamic_axes(profiles): + """ + This method calculates dynamic_axes to use in onnx.export(). + Args: + profiles: [[min,opt,max],...] list of profile dimensions + """ + dynamic_axes: dict[str, list[int]] = {} + if not profiles: + return dynamic_axes + for profile in profiles: + for key in profile: + axes = [] + vals = profile[key] + for i in range(len(vals[0])): + if vals[0][i] != vals[2][i]: + axes.append(i) + if len(axes) > 0: + dynamic_axes[key] = axes + return dynamic_axes + + +def cuassert(cuda_ret): + """ + Error reporting method for CUDA calls. + Args: + cuda_ret: CUDA return code. + """ + err = cuda_ret[0] + if err != 0: + raise RuntimeError(f"CUDA ERROR: {err}") + if len(cuda_ret) > 1: + return cuda_ret[1] + return None + + +class ShapeError(Exception): + """ + Exception class to report errors from setting TRT plan input shapes + """ + + pass + + +class TRTEngine: + """ + An auxiliary class to implement running of TRT optimized engines + + """ + + def __init__(self, plan_path, logger=None): + """ + Loads serialized engine, creates execution context and activates it + Args: + plan_path: path to serialized TRT engine. + logger: optional logger object + """ + self.plan_path = plan_path + self.logger = logger or getLogger("trt_compile") + self.logger.info(f"Loading TensorRT engine: {self.plan_path}") + self.engine = engine_from_bytes(bytes_from_path(self.plan_path)) + self.tensors = OrderedDict() + self.cuda_graph_instance = None # cuda graph + self.context = self.engine.create_execution_context() + self.input_names = [] + self.output_names = [] + self.dtypes = [] + self.cur_profile = 0 + self.input_table = {} + dtype_dict = trt_to_torch_dtype_dict() + for idx in range(self.engine.num_io_tensors): + binding = self.engine[idx] + if self.engine.get_tensor_mode(binding) == trt.TensorIOMode.INPUT: + self.input_names.append(binding) + elif self.engine.get_tensor_mode(binding) == trt.TensorIOMode.OUTPUT: + self.output_names.append(binding) + dtype = dtype_dict[self.engine.get_tensor_dtype(binding)] + self.dtypes.append(dtype) + self.logger.info( + f"Loaded TensorRT engine: {self.plan_path}.\nInputs: {self.input_names}\nOutputs: {self.output_names}" + ) + + def allocate_buffers(self, device): + """ + Allocates outputs to run TRT engine + Args: + device: GPU device to allocate memory on + """ + ctx = self.context + + for i, binding in enumerate(self.output_names): + shape = list(ctx.get_tensor_shape(binding)) + if binding not in self.tensors or list(self.tensors[binding].shape) != shape: + t = torch.empty(shape, dtype=self.dtypes[i], device=device).contiguous() + self.tensors[binding] = t + ctx.set_tensor_address(binding, t.data_ptr()) + + def set_inputs(self, feed_dict, stream): + """ + Sets input bindings for TRT engine according to feed_dict + Args: + feed_dict: a dictionary [str->Tensor] + stream: CUDA stream to use + """ + e = self.engine + ctx = self.context + + last_profile = self.cur_profile + + def try_set_inputs(): + for binding in self.input_names: + t = feed_dict.get(self.input_table[binding], None) + if t is not None: + t = t.contiguous() + shape = t.shape + ctx.set_input_shape(binding, shape) + ctx.set_tensor_address(binding, t.data_ptr()) + + while True: + try: + try_set_inputs() + break + except ShapeError: + next_profile = (self.cur_profile + 1) % e.num_optimization_profiles + if next_profile == last_profile: + raise + self.cur_profile = next_profile + ctx.set_optimization_profile_async(self.cur_profile, stream) + except Exception: + raise + left = ctx.infer_shapes() + assert len(left) == 0 + + def infer(self, stream, use_cuda_graph=False): + """ + Runs TRT engine. + Args: + stream: CUDA stream to run on + use_cuda_graph: use CUDA graph. Note: requires all inputs to be the same GPU memory between calls. + """ + if use_cuda_graph: + if self.cuda_graph_instance is not None: + cuassert(cudart.cudaGraphLaunch(self.cuda_graph_instance, stream)) + cuassert(cudart.cudaStreamSynchronize(stream)) + else: + # do inference before CUDA graph capture + noerror = self.context.execute_async_v3(stream) + if not noerror: + raise ValueError("ERROR: inference failed.") + # capture cuda graph + cuassert( + cudart.cudaStreamBeginCapture( + stream, cudart.cudaStreamCaptureMode.cudaStreamCaptureModeThreadLocal + ) + ) + self.context.execute_async_v3(stream) + graph = cuassert(cudart.cudaStreamEndCapture(stream)) + self.cuda_graph_instance = cuassert(cudart.cudaGraphInstantiate(graph, 0)) + self.logger.info("CUDA Graph captured!") + else: + noerror = self.context.execute_async_v3(stream) + cuassert(cudart.cudaStreamSynchronize(stream)) + if not noerror: + raise ValueError("ERROR: inference failed.") + + return self.tensors + + +def make_tensor(d): + """ + Creates a new tensor from d, returns d if d is already a tensor + """ + return d if isinstance(d, torch.Tensor) else torch.tensor(d).cuda() + + +def unroll_input(input_names, input_example): + """ + Simulates list/tuple unrolling during ONNX export + """ + unrolled_input = {} + for name in input_names: + val = input_example[name] + if val is not None: + if isinstance(val, list) or isinstance(val, tuple): + for i in range(len(val)): + unrolled_input[f"{name}_{i}"] = make_tensor(val[i]) + else: + unrolled_input[name] = make_tensor(val) + return unrolled_input + + +def parse_groups( + ret: List[torch.Tensor], output_lists: List[List[int]] +) -> Tuple[Union[torch.Tensor, List[torch.Tensor]], ...]: + """ + Implements parsing of 'output_lists' arg of trt_compile(). + + Args: + ret: plain list of Tensors + + output_lists: list of output group sizes: to form some Lists/Tuples out of 'ret' List, this will be a list + of group dimensions, like [[], [5], [-1]] for returning Tensor, list of 5 items and dynamic list. + Format: [[group_n] | [], ...] + [] or group_n == 0 : next output from ret is a scalar + group_n > 0 : next output from ret is a list of group_n length + group_n == -1: next output is a dynamic list. This entry can be at any + position in output_lists, but can appear only once. + Returns: + Tuple of Union[torch.Tensor, List[torch.Tensor]], according to the grouping in output_lists + + """ + groups: Tuple[Union[torch.Tensor, List[torch.Tensor]], ...] = tuple() + cur = 0 + for l in range(len(output_lists)): + gl = output_lists[l] + assert len(gl) == 0 or len(gl) == 1 + if len(gl) == 0 or gl[0] == 0: + groups = (*groups, ret[cur]) + cur = cur + 1 + elif gl[0] > 0: + groups = (*groups, ret[cur : cur + gl[0]]) + cur = cur + gl[0] + elif gl[0] == -1: + rev_groups: Tuple[Union[torch.Tensor, List[torch.Tensor]], ...] = tuple() + rcur = len(ret) + for rl in range(len(output_lists) - 1, l, -1): + rgl = output_lists[rl] + assert len(rgl) == 0 or len(rgl) == 1 + if len(rgl) == 0 or rgl[0] == 0: + rcur = rcur - 1 + rev_groups = (*rev_groups, ret[rcur]) + elif rgl[0] > 0: + rcur = rcur - rgl[0] + rev_groups = (*rev_groups, ret[rcur : rcur + rgl[0]]) + else: + raise ValueError("Two -1 lists in output") + groups = (*groups, ret[cur:rcur], *rev_groups[::-1]) + break + return groups + + +class TrtCompiler: + """ + This class implements: + - TRT lazy persistent export + - Running TRT with optional fallback to Torch + (for TRT engines with limited profiles) + """ + + def __init__( + self, + model, + plan_path, + precision="fp16", + method="onnx", + input_names=None, + output_names=None, + output_lists=None, + export_args=None, + build_args=None, + input_profiles=None, + dynamic_batchsize=None, + use_cuda_graph=False, + timestamp=None, + fallback=False, + forward_override=None, + logger=None, + ): + """ + Initialization method: + Tries to load persistent serialized TRT engine + Saves its arguments for lazy TRT build on first forward() call + Args: + model: Model to "wrap". + plan_path : Path where to save persistent serialized TRT engine. + precision: TRT builder precision o engine model. Should be 'fp32'|'tf32'|'fp16'|'bf16'. + method: One of 'onnx'|'torch_trt'. + Default is 'onnx' (torch.onnx.export()->TRT). This is the most stable and efficient option. + 'torch_trt' may not work for some nets. Also AMP must be turned off for it to work. + input_names: Optional list of input names. If None, will be read from the function signature. + output_names: Optional list of output names. Note: If not None, patched forward() will return a dictionary. + output_lists: Optional list of output group sizes: when forward() returns Lists/Tuples, this will be a list + of their dimensions, like [[], [5], [-1]] for Tensor, list of 5 items and dynamic list. + export_args: Optional args to pass to export method. See onnx.export() and Torch-TensorRT docs for details. + build_args: Optional args to pass to TRT builder. See polygraphy.Config for details. + input_profiles: Optional list of profiles for TRT builder and ONNX export. + Each profile is a map of the form : {"input id" : [min_shape, opt_shape, max_shape], ...}. + dynamic_batchsize: A sequence with three elements to define the input batch size range for the model to be + converted. Should be a sequence like [MIN_BATCH, OPT_BATCH, MAX_BATCH]. + [note]: If neither input_profiles nor dynamic_batchsize specified, static shapes will be used. + use_cuda_graph: Use CUDA Graph for inference. Note: inputs have to be the same GPU memory between calls! + timestamp: Optional timestamp to rebuild TRT engine (e.g. if config file changes). + fallback: Allow to fall back to Pytorch when TRT inference fails (e.g, shapes exceed max profile). + """ + + method_vals = ["onnx", "torch_trt"] + if method not in method_vals: + raise ValueError(f"trt_compile(): 'method' should be one of {method_vals}, got: {method}.") + precision_vals = ["fp32", "tf32", "fp16", "bf16"] + if precision not in precision_vals: + raise ValueError(f"trt_compile(): 'precision' should be one of {precision_vals}, got: {precision}.") + + self.plan_path = plan_path + self.precision = precision + self.method = method + self.return_dict = output_names is not None + self.output_names = output_names or [] + self.output_lists = output_lists or [] + self.profiles = input_profiles or [] + self.dynamic_batchsize = dynamic_batchsize + self.export_args = export_args or {} + self.build_args = build_args or {} + self.engine: TRTEngine | None = None + self.use_cuda_graph = use_cuda_graph + self.fallback = fallback + self.disabled = False + + self.logger = logger or getLogger("trt_compile") + self.argspec = inspect.getfullargspec(model.forward) + # Normally we read input_names from forward() but can be overridden + if input_names is None: + input_names = self.argspec.args[1:] + self.defaults = {} + if self.argspec.defaults is not None: + for i in range(len(self.argspec.defaults)): + d = self.argspec.defaults[-i - 1] + if d is not None: + d = make_tensor(d) + self.defaults[self.argspec.args[-i - 1]] = d + + self.input_names = input_names + self.old_forward = model.forward + + # Force engine rebuild if older than the timestamp + if timestamp is not None and os.path.exists(self.plan_path) and os.path.getmtime(self.plan_path) < timestamp: + os.remove(self.plan_path) + + def _inputs_to_dict(self, input_example): + trt_inputs = {} + for i, inp in enumerate(input_example): + input_name = self.input_names[i] + trt_inputs[input_name] = inp + return trt_inputs + + def _load_engine(self): + """ + Loads TRT plan from disk and activates its execution context. + """ + try: + self.engine = TRTEngine(self.plan_path, self.logger) + # Make sure we have names correct + input_table = {} + for name in self.engine.input_names: + if name.startswith("__") and name not in self.input_names: + orig_name = name[2:] + else: + orig_name = name + input_table[name] = orig_name + self.engine.input_table = input_table + self.logger.info(f"Engine loaded, inputs:{self.engine.input_table}") + except Exception as e: + self.logger.info(f"Exception while loading the engine:\n{e}") + + def forward(self, model, argv, kwargs): + """ + Main forward method: + Builds TRT engine if not available yet. + Tries to run TRT engine + If exception thrown and self.callback==True: falls back to original Pytorch + + Args: Passing through whatever args wrapped module's forward() has + Returns: Passing through wrapped module's forward() return value(s) + + """ + args = self.defaults + args.update(kwargs) + if len(argv) > 0: + args.update(self._inputs_to_dict(argv)) + + if self.engine is None and not self.disabled: + # Restore original forward for export + new_forward = model.forward + model.forward = self.old_forward + try: + self._load_engine() + if self.engine is None: + build_args = args.copy() + with torch.no_grad(): + self._build_and_save(model, build_args) + # This will reassign input_names from the engine + self._load_engine() + assert self.engine is not None + except Exception as e: + if self.fallback: + self.logger.info(f"Failed to build engine: {e}") + self.disabled = True + else: + raise e + if not self.disabled and not self.fallback: + # Delete all parameters + for param in model.parameters(): + del param + # Call empty_cache to release GPU memory + torch.cuda.empty_cache() + # restore TRT hook + model.forward = new_forward + # Run the engine + try: + if self.engine is not None: + # forward_trt is not thread safe as we do not use per-thread execution contexts + with lock_sm: + device = torch.cuda.current_device() + stream = torch.cuda.Stream(device=device) + self.engine.set_inputs(unroll_input(self.input_names, args), stream.cuda_stream) + self.engine.allocate_buffers(device=device) + # Need this to synchronize with Torch stream + stream.wait_stream(torch.cuda.current_stream()) + ret = self.engine.infer(stream.cuda_stream, use_cuda_graph=self.use_cuda_graph) + # if output_names is not None, return dictionary + if not self.return_dict: + ret = list(ret.values()) + if self.output_lists: + ret = parse_groups(ret, self.output_lists) + elif len(ret) == 1: + ret = ret[0] + return ret + except Exception as e: + if self.fallback: + self.logger.info(f"Exception: {e}\nFalling back to Pytorch ...") + else: + raise e + return self.old_forward(*argv, **kwargs) + + def _onnx_to_trt(self, onnx_path): + """ + Builds TRT engine from ONNX file at onnx_path and saves to self.plan_path + """ + + profiles = [] + for profile in self.profiles: + p = Profile() + for id, val in profile.items(): + p.add(id, min=val[0], opt=val[1], max=val[2]) + profiles.append(p) + + build_args = self.build_args.copy() + build_args["tf32"] = self.precision != "fp32" + if self.precision == "fp16": + build_args["fp16"] = True + elif self.precision == "bf16": + build_args["bf16"] = True + + self.logger.info(f"Building TensorRT engine for {onnx_path}: {self.plan_path}") + network = network_from_onnx_path(onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM]) + return engine_bytes_from_network(network, config=CreateConfig(profiles=profiles, **build_args)) + + def _build_and_save(self, model, input_example): + """ + If TRT engine is not ready, exports model to ONNX, + builds TRT engine and saves serialized TRT engine to the disk. + Args: + input_example: passed to onnx.export() + """ + + if self.engine is not None: + return + + export_args = self.export_args + engine_bytes = None + + add_casts_around_norms(model) + replace_for_export(model) + + if self.method == "torch_trt": + enabled_precisions = [torch.float32] + if self.precision == "fp16": + enabled_precisions.append(torch.float16) + elif self.precision == "bf16": + enabled_precisions.append(torch.bfloat16) + inputs = list(input_example.values()) + + def get_torch_trt_input(input_shape, dynamic_batchsize): + min_input_shape, opt_input_shape, max_input_shape = get_profile_shapes(input_shape, dynamic_batchsize) + return torch_tensorrt.Input( + min_shape=min_input_shape, opt_shape=opt_input_shape, max_shape=max_input_shape + ) + + tt_inputs = [get_torch_trt_input(i.shape, self.dynamic_batchsize) for i in inputs] + engine_bytes = torch_tensorrt.convert_method_to_trt_engine( + model, + "forward", + arg_inputs=tt_inputs, + enabled_precisions=enabled_precisions, + **export_args, + ) + else: + dbs = self.dynamic_batchsize + if dbs: + if len(self.profiles) > 0: + raise ValueError("ERROR: Both dynamic_batchsize and input_profiles set for TrtCompiler!") + if len(dbs) != 3: + raise ValueError("dynamic_batchsize has to have len ==3 ") + profile = {} + for id, val in input_example.items(): + + def add_profile(id, val): + sh = val.shape + if len(sh) > 0: + sh = sh[1:] + profile[id] = [[dbs[0], *sh], [dbs[1], *sh], [dbs[2], *sh]] + + if isinstance(val, list) or isinstance(val, tuple): + for i in range(len(val)): + add_profile(f"{id}_{i}", val[i]) + elif isinstance(val, torch.Tensor): + add_profile(id, val) + self.profiles = [profile] + + self.dynamic_axes = get_dynamic_axes(self.profiles) + + if len(self.dynamic_axes) > 0: + export_args.update({"dynamic_axes": self.dynamic_axes}) + + # Use temporary directory for easy cleanup in case of external weights + with tempfile.TemporaryDirectory() as tmpdir: + if export_args.get("dynamo", False): + input_names = None + else: + input_names = list(unroll_input(self.input_names, input_example).keys()) + onnx_path = str(Path(tmpdir) / "model.onnx") + self.logger.info( + f"Exporting to {onnx_path}:\n" + + f"output_names={self.output_names}\ninput_names={self.input_names}\nexport args: {export_args}" + ) + torch.onnx.export( + model, + (input_example,), + onnx_path, + input_names=input_names, + output_names=self.output_names, + **export_args, + ) + if polygraphy_imported: + from polygraphy.backend.onnx.loader import fold_constants, onnx_from_path, save_onnx + + onnx_model = fold_constants(onnx_from_path(onnx_path), size_threshold=16 * 1000 * 1000) + save_onnx(onnx_model, onnx_path) + self.logger.info("Export to ONNX successful.") + engine_bytes = self._onnx_to_trt(onnx_path) + if engine_bytes: + open(self.plan_path, "wb").write(engine_bytes) + + +def trt_forward(self, *argv, **kwargs): + """ + Patch function to replace original model's forward() with. + Redirects to TrtCompiler.forward() + """ + return self._trt_compiler.forward(self, argv, kwargs) + + +def trt_compile( + model: torch.nn.Module, + base_path: str, + args: Dict[str, Any] | None = None, + submodule: Union[str, List[str]] | None = None, + logger: Any | None = None, +) -> torch.nn.Module: + """ + Instruments model or submodule(s) with TrtCompiler and replaces its forward() with TRT hook. + Note: TRT 10.3 is recommended for best performance. Some nets may even fail to work with TRT 8.x + Args: + model: module to patch with TrtCompiler object. + base_path: TRT plan(s) saved to f"{base_path}[.{submodule}].plan" path. + dirname(base_path) must exist, base_path does not have to. + If base_path does point to existing file (e.g. associated checkpoint), + that file becomes a dependency - its mtime is added to args["timestamp"]. + args: Optional dict : unpacked and passed to TrtCompiler() - see TrtCompiler above for details. + submodule: Optional hierarchical id(s) of submodule to patch, e.g. ['image_decoder.decoder'] + If None, TrtCompiler patch is applied to the whole model. + Otherwise, submodule (or list of) is being patched. + logger: Optional logger for diagnostics. + Returns: + Always returns same model passed in as argument. This is for ease of use in configs. + """ + + default_args: Dict[str, Any] = { + "method": "onnx", + "precision": "fp16", + "build_args": {"builder_optimization_level": 5, "precision_constraints": "obey"}, + } + + default_args.update(args or {}) + args = default_args + + if trt_imported and polygraphy_imported and torch.cuda.is_available(): + # if "path" filename point to existing file (e.g. checkpoint) + # it's also treated as dependency + if os.path.exists(base_path): + timestamp = int(os.path.getmtime(base_path)) + if "timestamp" in args: + timestamp = max(int(args["timestamp"]), timestamp) + args["timestamp"] = timestamp + + def wrap(model, path): + if not hasattr(model, "_trt_compiler"): + model.orig_forward = model.forward + wrapper = TrtCompiler(model, path + ".plan", logger=logger, **args) + model._trt_compiler = wrapper + model.forward = MethodType(trt_forward, model) + + def find_sub(parent, submodule): + idx = submodule.find(".") + # if there is "." in name, call recursively + if idx != -1: + parent_name = submodule[:idx] + parent = getattr(parent, parent_name) + submodule = submodule[idx + 1 :] + return find_sub(parent, submodule) + return parent, submodule + + if submodule is not None: + if isinstance(submodule, str): + submodule = [submodule] + for s in submodule: + parent, sub = find_sub(model, s) + wrap(getattr(parent, sub), base_path + "." + s) + else: + wrap(model, base_path) + else: + logger = logger or getLogger("trt_compile") + logger.warning("TensorRT and/or polygraphy packages are not available! trt_compile() has no effect.") + + return model diff --git a/nemo/export/tensorrt_llm.py b/nemo/export/tensorrt_llm.py index 08b0b822cad4..c4d1a645b244 100644 --- a/nemo/export/tensorrt_llm.py +++ b/nemo/export/tensorrt_llm.py @@ -30,10 +30,11 @@ import wrapt from tensorrt_llm._utils import numpy_to_torch +from nemo.collections.nlp.parts.utils_funcs import torch_dtype_from_precision from nemo.deploy import ITritonDeployable from nemo.export.tarutils import TarPath, unpack_tarball from nemo.export.trt_llm.converter.model_converter import model_to_trtllm_ckpt -from nemo.export.trt_llm.converter.model_to_trt_llm_ckpt import dist_model_to_trt_llm_ckpt +from nemo.export.trt_llm.converter.model_to_trt_llm_ckpt import dist_model_to_trt_llm_ckpt, get_layer_prefix from nemo.export.trt_llm.converter.utils import init_model_parallel_from_nemo from nemo.export.trt_llm.nemo_ckpt_loader.nemo_file import ( build_tokenizer, @@ -65,6 +66,8 @@ @wrapt.decorator def noop_decorator(func): + """No op decorator""" + def wrapper(*args, **kwargs): return func(*args, **kwargs) @@ -80,6 +83,7 @@ def wrapper(*args, **kwargs): use_pytriton = False +# pylint: disable=line-too-long class TensorRTLLM(ITritonDeployable): """ Exports nemo checkpoints to TensorRT-LLM and run fast inference. @@ -180,6 +184,8 @@ def export( reduce_fusion: bool = True, fp8_quantized: Optional[bool] = None, fp8_kvcache: Optional[bool] = None, + gather_context_logits: Optional[bool] = False, + gather_generation_logits: Optional[bool] = False, ): """ Exports nemo checkpoints to TensorRT-LLM. @@ -218,6 +224,8 @@ def export( reduce_fusion (bool): enables fusing extra kernels after custom TRT-LLM allReduce fp8_quantized (Optional[bool]): enables exporting to FP8 TRT-LLM checkpoints. If not set, autodetects the type. fp8_kvcache (Optional[bool]): enables FP8 KV-cache quantization. If not set, autodetects the type. + gather_context_logits (Optional[bool]): if True, enables gather_context_logits while building trtllm engine. Default: False + gather_generation_logits (Optional[bool]): if True, enables gather_generation_logits while building trtllm engine. Default: False """ if n_gpus is not None: warnings.warn( @@ -343,43 +351,14 @@ def export( DEFAULT_CONVERSION_DICT, ) from megatron.core.export.trtllm.trtllm_helper import TRTLLMHelper - from megatron.core.transformer.transformer_config import TransformerConfig from tensorrt_llm.layers import MoeConfig - def get_transformer_config(nemo_model_config): - normalization = nemo_model_config.get('normalization', 'layernorm') - transformer_config_normalization = 'LayerNorm' - layernorm_zero_centered_gamma = False - if normalization == 'layernorm1p': - layernorm_zero_centered_gamma = True - elif normalization == 'rmsnorm': - transformer_config_normalization = 'RMSNorm' - - conf = TransformerConfig( - num_layers=nemo_model_config.get('num_layers'), - moe_router_topk=nemo_model_config.get('moe_router_topk', 0), - num_attention_heads=nemo_model_config.get('num_attention_heads'), - num_query_groups=nemo_model_config.get( - 'num_query_groups', nemo_model_config['num_attention_heads'] - ), - kv_channels=nemo_model_config.get("kv_channels", None), - hidden_size=nemo_model_config.get('hidden_size'), - ffn_hidden_size=nemo_model_config.get('ffn_hidden_size'), - layernorm_epsilon=nemo_model_config.get('layernorm_epsilon'), - add_bias_linear=nemo_model_config.get('bias'), - num_moe_experts=nemo_model_config.get('num_moe_experts', None), - normalization=transformer_config_normalization, - layernorm_zero_centered_gamma=layernorm_zero_centered_gamma, - ) - - return conf - # We build the transformer config using the nemo model config. - transformer_config = get_transformer_config(model_configs) + transformer_config = self.get_transformer_config(model_configs) input_model_type = getattr(ModelType, model_type) # MCore export supports some default conversion dictionaries - mcore_model_conversion_dict = DEFAULT_CONVERSION_DICT[input_model_type] + mcore_model_conversion_dict = DEFAULT_CONVERSION_DICT # All Mcore conversion dicts start with "decoder.layers.4.blah.blah" , while nemo models start with "model.decoder.layers.4.blahblah". so we append model. to the keys nemo_model_conversion_dict = { f'model.{key}': value for key, value in mcore_model_conversion_dict.items() @@ -495,14 +474,19 @@ def get_transformer_config(nemo_model_config): multiple_profiles=multiple_profiles, gpt_attention_plugin=gpt_attention_plugin, gemm_plugin=gemm_plugin, + gather_context_logits=gather_context_logits, + gather_generation_logits=gather_generation_logits, ) tokenizer_path = os.path.join(nemo_export_dir, "tokenizer.model") tokenizer_path_nemo2 = os.path.join(nemo_export_dir, "nemo_context") + vocab_path = os.path.join(nemo_export_dir, "vocab.json") if os.path.exists(tokenizer_path): shutil.copy(tokenizer_path, self.model_dir) elif os.path.exists(tokenizer_path_nemo2): shutil.copytree(tokenizer_path_nemo2, Path(self.model_dir) / "nemo_context") + elif os.path.exists(vocab_path): + shutil.copy(vocab_path, os.path.join(self.model_dir, "vocab.json")) else: self.tokenizer.save_pretrained(os.path.join(self.model_dir, 'huggingface_tokenizer')) @@ -518,6 +502,34 @@ def get_transformer_config(nemo_model_config): if load_model: self._load() + def get_transformer_config(self, nemo_model_config): + """Given nemo model config get transformer config""" + from megatron.core.transformer.transformer_config import TransformerConfig + + normalization = nemo_model_config.get('normalization', 'layernorm') + transformer_config_normalization = 'LayerNorm' + layernorm_zero_centered_gamma = False + if normalization == 'layernorm1p': + layernorm_zero_centered_gamma = True + elif normalization == 'rmsnorm': + transformer_config_normalization = 'RMSNorm' + + conf = TransformerConfig( + num_layers=nemo_model_config.get('num_layers'), + moe_router_topk=nemo_model_config.get('moe_router_topk', 0), + num_attention_heads=nemo_model_config.get('num_attention_heads'), + num_query_groups=nemo_model_config.get('num_query_groups', nemo_model_config['num_attention_heads']), + kv_channels=nemo_model_config.get("kv_channels", None), + hidden_size=nemo_model_config.get('hidden_size'), + ffn_hidden_size=nemo_model_config.get('ffn_hidden_size'), + layernorm_epsilon=nemo_model_config.get('layernorm_epsilon'), + add_bias_linear=nemo_model_config.get('bias'), + num_moe_experts=nemo_model_config.get('num_moe_experts', None), + normalization=transformer_config_normalization, + layernorm_zero_centered_gamma=layernorm_zero_centered_gamma, + ) + return conf + def convert_to_safe_tensors( self, nemo_checkpoint_path: str, @@ -530,6 +542,7 @@ def convert_to_safe_tensors( use_embedding_sharing: bool = False, dtype: str = "bfloat16", ): + """Convert to safe tensor""" gpus_per_node = tensor_parallelism_size if gpus_per_node is None else gpus_per_node if Path(self.model_dir).exists(): @@ -595,6 +608,167 @@ def convert_to_safe_tensors( if tensorrt_llm.mpi_world_size() > 1: tensorrt_llm.mpi_barrier() + def gather_and_reshard_model(self, model_config, model, storage_dtype): + """ + Accumulate all vp model chunks together, and reshard model (i.e) gather all pp ranks + if required and return the final model state dict + """ + + def _get_layer_index(split_key): + for index, key in enumerate(split_key): + if key == "layers": + return index + 1 + raise ValueError(f"Unknown layer name format: {split_key}") + + def rename_layer_num(param_name, layer_num): + split_key = param_name.split(".") + layer_index = int(_get_layer_index(split_key)) + split_key[layer_index] = str(layer_num) + return ".".join(split_key) + + def get_layer_num(param_name): + split_key = param_name.split(".") + layer_index = int(_get_layer_index(split_key)) + return int(split_key[layer_index]) + + from megatron.core import parallel_state + + tp_size = parallel_state.get_tensor_model_parallel_world_size() + pp_rank = parallel_state.get_pipeline_model_parallel_rank() + pp_first_rank = parallel_state.get_pipeline_model_parallel_first_rank() + pp_last_rank = parallel_state.get_pipeline_model_parallel_last_rank() + pp_size = parallel_state.get_pipeline_model_parallel_world_size() + pp_group = parallel_state.get_pipeline_model_parallel_group() + vp_size = parallel_state.get_virtual_pipeline_model_parallel_world_size() + if not vp_size: + vp_size = 1 + + inference_tp_size = self.tp_size + inference_pp_size = self.pp_size + reshard_model = False + if inference_tp_size != tp_size or inference_pp_size != pp_size: + LOGGER.info("Training/Generation model parallelism resharding enabled") + if inference_pp_size == 1 and pp_size > 1 and inference_tp_size == tp_size: + reshard_model = True + else: + raise NotImplementedError( + f"NeMo currently only supports PP>1 -> PP=1 resharding, other types of resharding will come in future releases." + ) + + num_layers = model_config["num_layers"] + layers_per_pp = num_layers // pp_size + layers_per_chunk = layers_per_pp // vp_size + + tl_params = {} + model_level_params = {} + if vp_size > 1: # consolidate params across model chunks + for idx, model_chunk in enumerate(model): + for key, val in model_chunk.state_dict().items(): + if torch.is_tensor(val): + if 'layers' in key: + key2 = rename_layer_num(key, get_layer_num(key) + idx * pp_size * layers_per_chunk) + tl_params[key2] = val + else: + model_level_params[key] = val + else: + for key, val in model.state_dict().items(): + if torch.is_tensor(val): + if 'decoder.layers' in key: + tl_params[key] = val + else: + model_level_params[key] = val + + if vp_size > 1 or reshard_model: + # gather layers across pp ranks + gathered_params = {} + for key, val in tl_params.items(): + weight_list = [torch.zeros_like(val) for _ in range(pp_size)] + torch.distributed.all_gather(weight_list, val, group=pp_group) + for idx in range(pp_size): + layer_num = get_layer_num(key) + idx * layers_per_chunk + key2 = rename_layer_num(key, layer_num) + if not reshard_model: # Save only layers of 1 single PP stage + layers_start = layers_per_pp * pp_rank + layers_end = layers_per_pp * (pp_rank + 1) - 1 + if layer_num >= layers_start and layer_num <= layers_end: + key2 = rename_layer_num(key, layer_num % layers_per_pp) + gathered_params[key2] = weight_list[idx] + else: + gathered_params[key2] = weight_list[idx] + tl_params = gathered_params + + model_state_dict = model_level_params + model_state_dict.update(tl_params) + + def get_tensor_if_available(key, pp_src_idx, group): + tensor = model_state_dict.get(key) + if tensor is not None: + tensor_shape = [tensor.shape] + else: + tensor_shape = [None] + + torch.distributed.broadcast_object_list(tensor_shape, pp_src_idx, group=group) + + if tensor_shape[0] is None: + return None + if torch.distributed.get_rank() != pp_src_idx: + tensor = torch.empty(tensor_shape[0], dtype=storage_dtype).cuda() + + torch.distributed.broadcast(tensor.contiguous(), pp_src_idx, group=pp_group) + return tensor + + if reshard_model: + key = 'decoder.final_layernorm.weight' + tensor = get_tensor_if_available(key, pp_last_rank, pp_group) + if tensor is not None: + model_state_dict[key] = tensor + + key = 'decoder.final_layernorm.bias' + tensor = get_tensor_if_available(key, pp_last_rank, pp_group) + if tensor is not None: + model_state_dict[key] = tensor + + key = 'embedding.word_embeddings.weight' + tensor = get_tensor_if_available(key, pp_first_rank, pp_group) + if tensor is not None: + model_state_dict[key] = tensor + + key = 'output_layer.weight' + tensor = get_tensor_if_available(key, pp_last_rank, pp_group) + if tensor is not None: + model_state_dict[key] = tensor + + return model_state_dict + + def get_input_dtype(self, storage_dtype): + """ + Return mcore export dtype given torch dtype + """ + from megatron.core.export.data_type import DataType + + if storage_dtype == torch.bfloat16: + return DataType.bfloat16 + elif storage_dtype == torch.float32: + return DataType.float32 + elif storage_dtype == torch.float16: + return DataType.float16 + + def get_nemo_to_trtllm_conversion_dict(self, model_state_dict): + """MCore export supports some default conversion dictionaries + All Mcore conversion dicts start with "decoder.layers.4.blah.blah" , while nemo models sometimes start with "model.decoder.layers.4.blahblah". so we append model prefix. to the keys + """ + from megatron.core.export.trtllm.model_to_trllm_mapping.default_conversion_dict import DEFAULT_CONVERSION_DICT + + model_prefix, _ = get_layer_prefix(layer_names=model_state_dict.keys(), is_mcore=True) + + nemo_model_conversion_dict = {} + for key, value in DEFAULT_CONVERSION_DICT.items(): + if 'layers' in key: + nemo_model_conversion_dict[f'{model_prefix}.{key}'] = value + else: + nemo_model_conversion_dict[key] = value + return nemo_model_conversion_dict + def build( self, model, @@ -607,6 +781,7 @@ def build( max_batch_size: int = 4, use_refit: bool = True, reshard_model: bool = False, + use_mcore_path: bool = True, ): """ Convert a model parallel nemo model to TensorRT-LLM. @@ -621,31 +796,103 @@ def build( if self.dp_size > 1: self.model_dir = os.path.join(self.model_dir, f"dp_rank{self.dp_rank}") - weights, model_config = model_to_trtllm_ckpt( - model=model, - nemo_model_config=model_config, - nemo_export_dir=self.model_dir, - decoder_type=model_type, - tensor_parallel_size=self.tp_size, - pipeline_parallel_size=self.pp_size, - gpus_per_node=gpus_per_node, - use_parallel_embedding=True, - use_distributed_convert=True, - model_parallel_rank=self.mp_rank, - vocab_size=self.tokenizer.vocab_size, - ) + if use_mcore_path: + from megatron.core.export.model_type import ModelType + from megatron.core.export.trtllm.trtllm_helper import TRTLLMHelper + from tensorrt_llm.layers import MoeConfig + + storage_dtype = torch_dtype_from_precision(model_config.precision) + model_state_dict = self.gather_and_reshard_model(model_config, model, storage_dtype) + # We build the transformer config using the nemo model config. + transformer_config = self.get_transformer_config(model_config) + input_model_type = getattr(ModelType, model_type) + + nemo_model_conversion_dict = self.get_nemo_to_trtllm_conversion_dict(model_state_dict) + + self.trtllm_helper = TRTLLMHelper( + transformer_config=transformer_config, + model_type=input_model_type, + trtllm_conversion_dict=nemo_model_conversion_dict, + position_embedding_type=model_config.get('position_embedding_type'), + max_position_embeddings=model_config.get('max_position_embeddings'), + rotary_percentage=model_config.get('rotary_percentage', 1.0), + rotary_base=model_config.get('rotary_base', 10000), + moe_tp_mode=model_config.get('moe_tp_mode', 2), + multi_query_mode=model_config.get("multi_query_mode", False), + activation=model_config.get('activation', "gelu"), + seq_len_interpolation_factor=model_config.get("seq_len_interpolation_factor"), + moe_renorm_mode=model_config.get( + 'moe_renorm_mode', MoeConfig.ExpertScaleNormalizationMode.RENORMALIZE + ), + share_embeddings_and_output_weights=model_config.get("share_embeddings_and_output_weights", False), + ) + + input_dtype = self.get_input_dtype(storage_dtype) + + trtllm_model_weights_list, trtllm_model_config_list = ( + self.trtllm_helper.get_trtllm_pretrained_config_and_model_weights( + model_state_dict=model_state_dict, + dtype=input_dtype, + state_dict_split_by_layer_numbers=True, + on_device_distributed_conversion=True, + vocab_size=self.tokenizer.vocab_size, + gpus_per_node=gpus_per_node, + ) + ) + trtllm_model_config = trtllm_model_config_list[0] + trtllm_model_weights = trtllm_model_weights_list[0] + + if reshard_model: + assert self.pp_size == 1, 'Reshard is true, but pp size is not one' + # MCORE Export will use parallel_state to determine pp . + # Since we reshard to pp = 1, we need to modify the config and mapping + world_size = self.tp_size * self.pp_size + trtllm_model_config.pp_size = self.pp_size + trtllm_model_config.world_size = world_size + trtllm_model_config.mapping = tensorrt_llm.Mapping( + world_size=world_size, + rank=self.mp_rank, + tp_size=self.tp_size, + pp_size=self.pp_size, + ) + + engine = self.trtllm_helper.build_and_save_engine( + max_input_len=max_input_len, + max_output_len=max_output_len, + max_seq_len=max_input_len + max_output_len, + max_batch_size=max_batch_size, + trtllm_model_config=trtllm_model_config, + trtllm_model_weights=trtllm_model_weights, + engine_dir=self.model_dir, + use_refit=use_refit, + ) + else: + weights, model_config = model_to_trtllm_ckpt( + model=model, + nemo_model_config=model_config, + nemo_export_dir=self.model_dir, + decoder_type=model_type, + tensor_parallel_size=self.tp_size, + pipeline_parallel_size=self.pp_size, + gpus_per_node=gpus_per_node, + use_parallel_embedding=True, + use_distributed_convert=True, + model_parallel_rank=self.mp_rank, + vocab_size=self.tokenizer.vocab_size, + ) + + engine = build_and_save_engine( + max_input_len=max_input_len, + max_output_len=max_output_len, + max_seq_len=max_input_len + max_output_len, + max_batch_size=max_batch_size, + model_config=model_config[0], + model_weights=weights[0], + model_dir=self.model_dir, + model_type=model_type, + use_refit=use_refit, + ) - engine = build_and_save_engine( - max_input_len=max_input_len, - max_output_len=max_output_len, - max_seq_len=max_input_len + max_output_len, - max_batch_size=max_batch_size, - model_config=model_config[0], - model_weights=weights[0], - model_dir=self.model_dir, - model_type=model_type, - use_refit=use_refit, - ) torch.distributed.barrier() cfg_path = Path(os.path.join(self.model_dir, f'config_{torch.distributed.get_rank()}.json')) @@ -654,18 +901,33 @@ def build( load_distributed(self.model_dir, self.mp_rank, gpus_per_node) - def refit(self, model, model_config): + def refit(self, model, model_config, use_mcore_path=True): """ Refits an TensorRT engine using an instantiated nemo model. This function should only be used after calling build() """ - weights_dict = dist_model_to_trt_llm_ckpt( - model=model, - nemo_model_config=model_config, - inference_tp_size=self.tp_size, - inference_pp_size=self.pp_size, - tokenizer_vocab_size=self.tokenizer.vocab_size, - ) + weights_dict = None + if use_mcore_path: + storage_dtype = torch_dtype_from_precision(model_config.precision) + + model_state_dict = self.gather_and_reshard_model(model_config, model, storage_dtype) + + nemo_model_conversion_dict = self.get_nemo_to_trtllm_conversion_dict(model_state_dict) + self.trtllm_helper.weights_converter.convert( + model_state_dict=model_state_dict, + tokenizer_vocab_size=self.tokenizer.vocab_size, + trtllm_conversion_dict=nemo_model_conversion_dict, + ) + weights_dict = self.trtllm_helper.weights_converter.trtllm_model_weights + + else: + weights_dict = dist_model_to_trt_llm_ckpt( + model=model, + nemo_model_config=model_config, + inference_tp_size=self.tp_size, + inference_pp_size=self.pp_size, + tokenizer_vocab_size=self.tokenizer.vocab_size, + ) load_distributed(self.model_dir, self.mp_rank, self.gpus_per_node) gc.collect() torch.cuda.empty_cache() @@ -688,6 +950,7 @@ def forward( prompt_embeddings_checkpoint_path: str = None, streaming: bool = False, output_log_probs: bool = False, + output_generation_logits: bool = False, **sampling_kwargs, ): """ @@ -706,6 +969,7 @@ def forward( task_ids (List(str)): list of the task ids for the prompt tables. prompt_embeddings_table (List(float)): prompt embeddings table. prompt_embeddings_checkpoint_path (str): path for the nemo checkpoint for the prompt embedding table. + output_generation_logits (bool): if True returns generation_logits in the outout of generate method. sampling_kwargs: Additional kwargs to set in the SamplingConfig. """ @@ -784,6 +1048,7 @@ def forward( no_repeat_ngram_size=no_repeat_ngram_size, output_log_probs=output_log_probs, multiprocessed_env=multiprocessed_env, + output_generation_logits=output_generation_logits, **sampling_kwargs, ) else: @@ -806,6 +1071,7 @@ def forward( ) def add_prompt_table(self, task_name: str, prompt_embeddings_checkpoint_path: str): + """Add prompt table""" if self.model is None: raise Exception( "A nemo checkpoint should be exported to TensorRT-LLM and " @@ -827,6 +1093,7 @@ def add_prompt_table(self, task_name: str, prompt_embeddings_checkpoint_path: st self._prep_ptuning_table() def remove_prompt_table(self, task_name: str): + """Remove prompt table""" if self.ptuning_tables is not None: for i in range(len(self.ptuning_tables)): if self.ptuning_tables[i]["task_name"] == task_name: @@ -838,11 +1105,13 @@ def remove_prompt_table(self, task_name: str): @property def get_supported_models_list(self): + """Supported model list""" # gpt and gptnext are the same. Keeping the gptnext due to backward compatibility. return ["gpt", "gptnext", "llama", "falcon", "starcoder", "mixtral", "gemma"] @property def get_hidden_size(self): + """Get hidden size""" if self.config is None: return None else: @@ -850,6 +1119,7 @@ def get_hidden_size(self): @property def get_triton_input(self): + """Get triton input""" inputs = ( Tensor(name="prompts", shape=(-1,), dtype=bytes), Tensor(name="max_output_len", shape=(-1,), dtype=np.int_, optional=True), @@ -862,16 +1132,22 @@ def get_triton_input(self): Tensor(name="no_repeat_ngram_size", shape=(-1,), dtype=np.single, optional=True), Tensor(name="task_id", shape=(-1,), dtype=bytes, optional=True), Tensor(name="lora_uids", shape=(-1,), dtype=bytes, optional=True), + Tensor(name="output_generation_logits", shape=(-1,), dtype=np.bool_, optional=False), ) return inputs @property def get_triton_output(self): - outputs = (Tensor(name="outputs", shape=(-1,), dtype=bytes),) + outputs = ( + Tensor(name="outputs", shape=(-1,), dtype=bytes), + Tensor(name="generation_logits", shape=(-1,), dtype=np.single), + ) return outputs @batch def triton_infer_fn(self, **inputs: np.ndarray): + """Triton infer function for streaming""" + output_dict = {} try: infer_input = {"input_texts": str_ndarray2list(inputs.pop("prompts"))} if "max_output_len" in inputs: @@ -898,17 +1174,24 @@ def triton_infer_fn(self, **inputs: np.ndarray): if "lora_uids" in inputs: lora_uids = np.char.decode(inputs.pop("lora_uids").astype("bytes"), encoding="utf-8") infer_input["lora_uids"] = lora_uids[0].tolist() + if "output_generation_logits" in inputs: + infer_input["output_generation_logits"] = inputs.pop("output_generation_logits")[0][0] - output_texts = self.forward(**infer_input) - output = cast_output(output_texts, np.bytes_) + if infer_input["output_generation_logits"]: + output_texts, generation_logits = self.forward(**infer_input) + output_dict["generation_logits"] = np.array(generation_logits.cpu().numpy()) + else: + output_texts = self.forward(**infer_input) + output_dict["outputs"] = cast_output(output_texts, np.bytes_) except Exception as error: err_msg = "An error occurred: {0}".format(str(error)) - output = cast_output([err_msg], np.bytes_) + output_dict["outputs"] = cast_output([err_msg], np.bytes_) - return {"outputs": output} + return output_dict @batch def triton_infer_fn_streaming(self, **inputs: np.ndarray): + """Triton infer function for streaming""" try: infer_input = {"input_texts": str_ndarray2list(inputs.pop("prompts"))} if "max_output_len" in inputs: @@ -1118,4 +1401,5 @@ def _load(self): ) from error def unload_engine(self): + """Unload engine""" unload_engine() diff --git a/nemo/export/tiktoken_tokenizer.py b/nemo/export/tiktoken_tokenizer.py new file mode 100644 index 000000000000..d599620256fa --- /dev/null +++ b/nemo/export/tiktoken_tokenizer.py @@ -0,0 +1,123 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import base64 +import json +from pathlib import Path +from typing import Dict, Optional + +import numpy as np +import tiktoken +import torch + +PATTERN_TIKTOKEN = "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" +DEFAULT_TIKTOKEN_MAX_VOCAB = 2**17 # 131072 +SPECIAL_TOKENS = ["", "", ""] +SPECIAL_TOKEN_TEMPLATE = "" + + +def reload_mergeable_ranks( + path: str, + max_vocab: Optional[int] = None, +) -> Dict[bytes, int]: + """ + Reload the tokenizer JSON file and convert it to Tiktoken format. + """ + assert path.endswith(".json") + + # reload vocab + with open(path, "r", encoding='utf-8') as f: + vocab = json.load(f) + assert isinstance(vocab, list) + print(f"Vocab size: {len(vocab)}") + if max_vocab is not None: + vocab = vocab[:max_vocab] + print(f"Cutting vocab to first {len(vocab)} tokens.") + + # build ranks + ranks: Dict[bytes, int] = {} + for i, x in enumerate(vocab): + assert x.keys() == {"rank", "token_bytes", "token_str"} + assert x["rank"] == i + merge = base64.b64decode(x["token_bytes"]) + assert i >= 256 or merge == bytes([i]) + ranks[merge] = x["rank"] + + # sanity check + assert len(ranks) == len(vocab) + assert set(ranks.values()) == set(range(len(ranks))) + + return ranks + + +class TiktokenTokenizer: + def __init__(self, vocab_file: str): + + self.num_special_tokens = 1000 + vocab_size = DEFAULT_TIKTOKEN_MAX_VOCAB + pattern = PATTERN_TIKTOKEN + special_tokens = SPECIAL_TOKENS.copy() + inner_vocab_size = vocab_size - self.num_special_tokens + + token2id = reload_mergeable_ranks(vocab_file, max_vocab=inner_vocab_size) + self.tokenizer = tiktoken.Encoding( + name=Path(vocab_file).parent.name, + pat_str=pattern, + mergeable_ranks=token2id, + special_tokens={}, # special tokens are handled manually + ) + + # BOS / EOS / Pad token IDs + self._bos_id = special_tokens.index("") + self._eos_id = special_tokens.index("") + + def encode(self, text): + tokens = self.tokenizer.encode(text) + tokens = [t + self.num_special_tokens for t in tokens] + return tokens + + def decode(self, tokens): + # Filter out special tokens and adjust the remaining tokens + adjusted_tokens = [ + t - self.num_special_tokens + for t in tokens + if t not in {self._bos_id, self._eos_id} and t >= self.num_special_tokens + ] + + # Decode only if there are tokens left after filtering + if adjusted_tokens: + return self.tokenizer.decode(adjusted_tokens) + else: + return "" # Return an empty string if all tokens were filtered out + + def batch_decode(self, ids): + if isinstance(ids, np.ndarray) or torch.is_tensor(ids): + ids = ids.tolist() + + if isinstance(ids[0], list): + ids = ids[0] + + return self.decode(ids) + + @property + def pad_id(self): + return self._eos_id + + @property + def bos_token_id(self): + return self._bos_id + + @property + def eos_token_id(self): + return self._eos_id diff --git a/nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py b/nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py index db1aec0f5a55..b0e134ab0c35 100644 --- a/nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py +++ b/nemo/export/trt_llm/converter/model_to_trt_llm_ckpt.py @@ -161,7 +161,7 @@ def convert_model_to_trt_llm_ckpt( or nemo_model_config.get("layernorm_zero_centered_gamma", False), "tp_size": training_tp_size, "split_gated_activation": nemo_model_config.get("activation", "gelu") - in ["swiglu", "geglu", "fast-swiglu", "fast-geglu"] + in ["swiglu", "geglu", "fast-swiglu", "fast-geglu", "openai-gelu"] and (decoder_type == "gptnext" or is_mcore), "num_attention_heads": num_attention_heads, "num_kv_heads": num_kv_heads, @@ -336,7 +336,7 @@ def dist_model_to_trt_llm_ckpt( "apply_layernorm_1p": nemo_model_config.get("normalization", "") == "layernorm1p", "tp_size": tp_size, "split_gated_activation": nemo_model_config.get("activation", "gelu") - in ["swiglu", "geglu", "fast-swiglu", "fast-geglu"], + in ["swiglu", "geglu", "fast-swiglu", "fast-geglu", "openai-gelu"], "num_attention_heads": nemo_model_config["num_attention_heads"], "num_kv_heads": nemo_model_config.get('num_query_groups', nemo_model_config['num_attention_heads']), "convert_on_device": True, diff --git a/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py b/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py index 23d227d32acf..a2d745864d92 100644 --- a/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py +++ b/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py @@ -36,6 +36,7 @@ from nemo.export.sentencepiece_tokenizer import SentencePieceTokenizer from nemo.export.tarutils import TarPath, ZarrPathStore +from nemo.export.tiktoken_tokenizer import TiktokenTokenizer LOGGER = logging.getLogger("NeMo") @@ -235,7 +236,7 @@ def load_sharded_metadata(checkpoint_dir: Union[Path, TarPath], torch_tensor=Tru def update_tokenizer_paths(tokenizer_config: Dict, unpacked_checkpoints_dir): def _update_config_entry(key, file_pattern): - old_path = tokenizer_config[key] + old_path = tokenizer_config.get(key, None) if old_path is None: return old_path = Path(old_path) @@ -262,7 +263,7 @@ def copy_tokenizer_files(config, out_dir): } for key in basenames.keys(): - if config[key] is None: + if config.get(key, None) is None: continue path = config[key] @@ -275,6 +276,7 @@ def copy_tokenizer_files(config, out_dir): continue dst_path = out_dir / f"{basenames[key]}{path.suffix}" + config[key] = str(dst_path) LOGGER.debug(f"Copy tokenizer {key}: {path}->{dst_path}") # Copy 'path' to 'dst_path' without shutil.copy(...) because 'path' may be a TarPath @@ -282,6 +284,8 @@ def copy_tokenizer_files(config, out_dir): with open(dst_path, 'wb') as outfile: outfile.write(infile.read()) + return config + def get_tokenizer(tokenizer_dir_or_path: Union[str, Path]) -> PreTrainedTokenizer: """Loads the tokenizer from the decoded NeMo weights dir.""" @@ -291,6 +295,10 @@ def get_tokenizer(tokenizer_dir_or_path: Union[str, Path]) -> PreTrainedTokenize tokenizer_spec = io.load_context((tokenizer_dir_or_path / "nemo_context"), subpath="model.tokenizer") return build_tokenizer(tokenizer_spec) + elif os.path.exists(os.path.join(tokenizer_dir_or_path, "vocab.json")): + vocab_path = tokenizer_dir_or_path / "vocab.json" if tokenizer_dir_or_path.is_dir() else tokenizer_dir_or_path + tokenizer_config = {"library": "tiktoken", "vocab_file": str(vocab_path)} + return build_tokenizer(tokenizer_config) else: if (tokenizer_dir_or_path / "huggingface_tokenizer").is_dir(): return AutoTokenizer.from_pretrained(tokenizer_dir_or_path / "huggingface_tokenizer") @@ -307,6 +315,8 @@ def build_tokenizer(tokenizer): tokenizer_config = tokenizer if tokenizer_config["library"] == "sentencepiece": return SentencePieceTokenizer(model_path=tokenizer_config["model"]) + elif tokenizer_config["library"] == "tiktoken": + return TiktokenTokenizer(vocab_file=tokenizer_config["vocab_file"]) elif "GPT2" in tokenizer_config["type"]: tokenizer = GPT2Tokenizer(tokenizer_config["vocab_file"], tokenizer_config["merge_file"]) else: @@ -317,24 +327,31 @@ def build_tokenizer(tokenizer): if tokenizer.eos_token_id is None: tokenizer.add_special_tokens({"eos_token": ""}) else: - try: - # If NeMo tokenizer, monkey patch interface - from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec - - if isinstance(tokenizer, TokenizerSpec): - - def batch_encode_patch(self, ids): + # For NeMo tokenizers, monkey patch encode & batch_decode methods for unified interface + import nemo.collections.common.tokenizers as nemo_tokenizers + + if isinstance(tokenizer, nemo_tokenizers.TokenizerSpec): + if isinstance(tokenizer, nemo_tokenizers.AutoTokenizer): + # Unwrap the original methods of HF tokenizer + batch_decode = tokenizer.tokenizer.batch_decode + encode = tokenizer.tokenizer.encode + elif isinstance(tokenizer, nemo_tokenizers.SentencePieceTokenizer): + # Define HF equivalents based on available SP methods + def batch_decode(self, ids): if torch.is_tensor(ids): ids = ids.cpu().numpy() - ids = ids[0] if len(ids.shape) > 1 else ids - return self.ids_to_text(ids) + if isinstance(ids, np.ndarray): + ids = ids.tolist() + return self.tokenizer.decode(ids) + + encode = tokenizer.tokenizer.encode_as_ids + else: + raise NotImplementedError(f"Patching tokenizer methods for {type(tokenizer)} is not available") - tokenizer.bos_token_id = tokenizer.bos_id - tokenizer.eos_token_id = tokenizer.eos_id - tokenizer.encode = tokenizer.text_to_ids - TokenizerSpec.batch_decode = batch_encode_patch - except: - raise TypeError(f'Unsupported tokenizer build input: {type(tokenizer)}') + tokenizer.bos_token_id = tokenizer.bos_id + tokenizer.eos_token_id = tokenizer.eos_id + nemo_tokenizers.TokenizerSpec.encode = encode + nemo_tokenizers.TokenizerSpec.batch_decode = batch_decode return tokenizer @@ -366,9 +383,8 @@ def load_nemo_model(nemo_ckpt: Union[str, Path], nemo_export_dir: Union[str, Pat ) else: tokenizer_config = update_tokenizer_paths(nemo_model_config["tokenizer"], unpacked_checkpoint_dir) - copy_tokenizer_files(tokenizer_config, nemo_export_dir) + tokenizer_config = copy_tokenizer_files(tokenizer_config, nemo_export_dir) - tokenizer_config["model"] = os.path.join(nemo_export_dir, "tokenizer.model") tokenizer = build_tokenizer(tokenizer_config) elif (nemo_dir / "weights").exists(): dist_ckpt_folder = nemo_dir / "weights" diff --git a/nemo/export/trt_llm/qnemo/qnemo_to_tensorrt_llm.py b/nemo/export/trt_llm/qnemo/qnemo_to_tensorrt_llm.py index eac1ab743849..f601c8cb1c5a 100644 --- a/nemo/export/trt_llm/qnemo/qnemo_to_tensorrt_llm.py +++ b/nemo/export/trt_llm/qnemo/qnemo_to_tensorrt_llm.py @@ -18,7 +18,6 @@ import warnings from typing import List, Optional -import tensorrt_llm from tensorrt_llm.models import PretrainedConfig from nemo.export.trt_llm.qnemo.utils import CONFIG_NAME, WEIGHTS_NAME @@ -51,7 +50,7 @@ def qnemo_to_tensorrt_llm( warnings.warn( "Note that setting tensor_parallel_size, pipeline_parallel_size and use_parallel_embedding " - " parameters for quantized models is done on calibration step with nemo.export.quantize module." + " parameters for quantized models is done on the calibration step (in PTQ workflow)." " These parameters are ignored when building and running TensorRT-LLM engine below.", UserWarning, stacklevel=3, @@ -93,11 +92,7 @@ def qnemo_to_tensorrt_llm( build_cmd += f"--remove_input_padding {'enable' if remove_input_padding else 'disable'} " build_cmd += f"--multiple_profiles {'enable' if multiple_profiles else 'disable'} " build_cmd += f"--reduce_fusion {'enable' if reduce_fusion else 'disable'} " - # TODO: resolve version check for setting use_fused_mlp once we move to 0.13.0 in the NeMo container - if tensorrt_llm.__version__ >= "0.13.0": - build_cmd += f"--use_fused_mlp {'enable' if use_fused_mlp else 'disable'} " - else: - build_cmd += "--use_fused_mlp " if use_fused_mlp else "" + build_cmd += f"--use_fused_mlp {'enable' if use_fused_mlp else 'disable'} " if not use_qdq: build_cmd += f"--gemm_plugin auto " diff --git a/nemo/export/trt_llm/qnemo/tokenizer_utils.py b/nemo/export/trt_llm/qnemo/tokenizer_utils.py index beca40bcd3d7..37b45521dcca 100644 --- a/nemo/export/trt_llm/qnemo/tokenizer_utils.py +++ b/nemo/export/trt_llm/qnemo/tokenizer_utils.py @@ -12,37 +12,43 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging import os from omegaconf import OmegaConf from transformers import AutoTokenizer from nemo.export.sentencepiece_tokenizer import SentencePieceTokenizer +from nemo.export.tiktoken_tokenizer import TiktokenTokenizer # TODO: use get_nmt_tokenizer helper below to instantiate tokenizer once environment / dependencies get stable # from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer TOKENIZER_CONFIG_FILE = "tokenizer_config.yaml" TOKENIZER_DIR = "tokenizer" +LOGGER = logging.getLogger("NeMo") def get_nmt_tokenizer(nemo_checkpoint_path: str): """Build tokenizer from Nemo tokenizer config.""" - print(f"Initializing tokenizer from {TOKENIZER_CONFIG_FILE}") + LOGGER.info(f"Initializing tokenizer from {TOKENIZER_CONFIG_FILE}") tokenizer_cfg = OmegaConf.load(os.path.join(nemo_checkpoint_path, TOKENIZER_CONFIG_FILE)) library = tokenizer_cfg.library legacy = tokenizer_cfg.get("sentencepiece_legacy", library == "sentencepiece") if library == "huggingface": - print(f"Getting HuggingFace AutoTokenizer with pretrained_model_name: {tokenizer_cfg.type}") + LOGGER.info(f"Getting HuggingFace AutoTokenizer with pretrained_model_name: {tokenizer_cfg.type}") tokenizer = AutoTokenizer.from_pretrained(tokenizer_cfg["type"], use_fast=tokenizer_cfg.get("use_fast", False)) elif library == "sentencepiece": - print(f"Getting SentencePieceTokenizer with model: {tokenizer_cfg.model}") + LOGGER.info(f"Getting SentencePieceTokenizer with model: {tokenizer_cfg.model}") tokenizer = SentencePieceTokenizer( model_path=os.path.join(nemo_checkpoint_path, tokenizer_cfg.model), legacy=legacy ) + elif library == "tiktoken": + print(f"Getting TiktokenTokenizer with file: {tokenizer_cfg.vocab_file}") + tokenizer = TiktokenTokenizer(vocab_file=os.path.join(nemo_checkpoint_path, tokenizer_cfg.vocab_file)) else: raise NotImplementedError("Currently we only support 'huggingface' and 'sentencepiece' tokenizer libraries.") diff --git a/nemo/export/trt_llm/tensorrt_llm_build.py b/nemo/export/trt_llm/tensorrt_llm_build.py index 4be2d42ebe4d..b2b761483700 100755 --- a/nemo/export/trt_llm/tensorrt_llm_build.py +++ b/nemo/export/trt_llm/tensorrt_llm_build.py @@ -54,6 +54,8 @@ def build_and_save_engine( gpt_attention_plugin: str = "auto", gemm_plugin: str = "auto", reduce_fusion: bool = False, + gather_context_logits: bool = False, + gather_generation_logits: bool = False, ): architecture = "LLaMAForCausalLM" if model_config.architecture == "LlamaForCausalLM" else model_config.architecture try: @@ -96,8 +98,8 @@ def build_and_save_engine( 'max_num_tokens': max_num_tokens, 'opt_num_tokens': opt_num_tokens, 'max_prompt_embedding_table_size': max_prompt_embedding_table_size, - 'gather_context_logits': False, - 'gather_generation_logits': False, + 'gather_context_logits': gather_context_logits, + 'gather_generation_logits': gather_generation_logits, 'strongly_typed': False, 'builder_opt': None, 'use_refit': use_refit, @@ -118,14 +120,6 @@ def build_and_save_engine( build_config.lora_config = lora_config model = model_cls.from_config(model_config) - if not model_config.bias and model_config.architecture == 'GPTForCausalLM': - # NOTE: GPT models in megatron-core that set bias=False sets the bias false globally - # whereas bias=False in TRTLLM GPT models sets it false everywhere except - # LayerNorm. This change makes TRTLLM's implementation match megatron-core. - for name, module in model.named_modules(): - if isinstance(module, tensorrt_llm.layers.normalization.LayerNorm): - module.bias = None - module.register_parameter('bias', None) model = optimize_model( model, use_parallel_embedding=model_config.use_parallel_embedding, diff --git a/nemo/export/trt_llm/tensorrt_llm_run.py b/nemo/export/trt_llm/tensorrt_llm_run.py index bd7b8abd5f9e..ef67c918290f 100644 --- a/nemo/export/trt_llm/tensorrt_llm_run.py +++ b/nemo/export/trt_llm/tensorrt_llm_run.py @@ -647,6 +647,7 @@ def generate( streaming: bool = False, output_log_probs=False, multiprocessed_env=False, + output_generation_logits=False, **sampling_kwargs, ) -> Optional[List[List[str]]]: """Generate the output sequence from the input sequence. @@ -692,6 +693,7 @@ def generate( multiprocessed_env=multiprocessed_env, **sampling_kwargs, ) + assert outputs is not None if tensorrt_llm.mpi_rank() != 0: return None @@ -705,8 +707,8 @@ def generate( for b in range(output_ids.shape[0]) ] - if output_log_probs: - return output_lines_list, log_probs + if output_generation_logits: + return output_lines_list, outputs['generation_logits'] return output_lines_list diff --git a/nemo/export/vllm_exporter.py b/nemo/export/vllm_exporter.py index 0ce7d49126d3..97575058bd1c 100644 --- a/nemo/export/vllm_exporter.py +++ b/nemo/export/vllm_exporter.py @@ -222,7 +222,6 @@ def export( max_num_seqs=256, # Note: max_model_len can be derived by model_config if the input value is None max_model_len=model_config.max_model_len, - use_v2_block_manager=False, num_lookahead_slots=0, delay_factor=0.0, enable_chunked_prefill=False, @@ -403,6 +402,7 @@ def get_triton_input(self): Tensor(name="top_p", shape=(-1,), dtype=numpy.single, optional=True), Tensor(name="temperature", shape=(-1,), dtype=numpy.single, optional=True), Tensor(name="lora_uids", shape=(-1,), dtype=bytes, optional=True), + Tensor(name="output_generation_logits", shape=(-1,), dtype=numpy.bool_, optional=True), ) return inputs @@ -455,6 +455,7 @@ def forward( prompt_embeddings_checkpoint_path: Optional[str] = None, streaming: bool = False, output_log_probs: bool = False, + output_generation_logits: bool = False, ) -> Union[List[List[str]], Iterable[List[List[str]]]]: """ The forward function performs LLM evaluation on the provided array of prompts with other parameters shared, @@ -484,6 +485,9 @@ def forward( if output_log_probs: raise NotImplementedError("output_log_probs is not supported") + if output_generation_logits: + raise NotImplementedError("output_generation_logits is not supported") + request_ids = [] for index in range(len(input_texts)): prompt = input_texts[index] diff --git a/nemo/lightning/__init__.py b/nemo/lightning/__init__.py index 91d3b3f936d0..e01a2d5e5765 100644 --- a/nemo/lightning/__init__.py +++ b/nemo/lightning/__init__.py @@ -14,8 +14,8 @@ from typing import Union -from lightning_fabric.plugins.environments import slurm -from pytorch_lightning import plugins as _pl_plugins +from lightning.fabric.plugins.environments import slurm +from lightning.pytorch import plugins as _pl_plugins # This is here to import it once, which improves the speed of launch when in debug-mode from nemo.utils.import_utils import safe_import diff --git a/nemo/lightning/_strategy_lib.py b/nemo/lightning/_strategy_lib.py index 1bee71e26e17..1fb4b4e0a757 100644 --- a/nemo/lightning/_strategy_lib.py +++ b/nemo/lightning/_strategy_lib.py @@ -28,7 +28,7 @@ if TYPE_CHECKING: - from lightning_fabric.utilities.types import Optimizable + from lightning.fabric.utilities.types import Optimizable from megatron.core.model_parallel_config import ModelParallelConfig @@ -89,7 +89,8 @@ def init_parallel_ranks( seed=seed, pipeline_model_parallel_split_rank=getattr(parallel_config, "pipeline_model_parallel_split_rank", None), use_fp8=fp8, - init_mpi_proc_group=getattr(parallel_config, "tp_comm_overlap", False), + init_mpi_proc_group=getattr(parallel_config, "tp_comm_overlap", False) + and getattr(parallel_config, "tp_comm_bootstrap_backend", None) == 'mpi', # apex_transformer_log_level=self.cfg.get('apex_transformer_log_level', 30), ) @@ -514,6 +515,17 @@ def get_safe(param_id): def load_model_state_dict(megatron_parallel, checkpoint: Mapping[str, Any], strict: bool = True) -> None: from megatron.core import parallel_state + from megatron.core.dist_checkpointing.validation import StrictHandling, parse_strict_flag + + ## convert from StrictHandling to bool for PTL + if strict is not None and not isinstance(strict, bool): + strict = parse_strict_flag(strict) + strict_options = [ + StrictHandling.ASSUME_OK_UNEXPECTED, + StrictHandling.RAISE_UNEXPECTED, + StrictHandling.RAISE_ALL, + ] + strict = strict in strict_options for index, module in enumerate(megatron_parallel): if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None: diff --git a/nemo/lightning/base.py b/nemo/lightning/base.py index b6ba14726818..3b0b1c0c7234 100644 --- a/nemo/lightning/base.py +++ b/nemo/lightning/base.py @@ -19,7 +19,7 @@ import torch import torch.distributed -from pytorch_lightning import Trainer +from lightning.pytorch import Trainer from torch import nn diff --git a/nemo/lightning/data.py b/nemo/lightning/data.py index 9cf686464417..9cb685a096fa 100644 --- a/nemo/lightning/data.py +++ b/nemo/lightning/data.py @@ -19,7 +19,7 @@ from typing import List, Literal, Optional import torch -from pytorch_lightning.overrides.distributed import _IndexBatchSamplerWrapper +from lightning.pytorch.overrides.distributed import _IndexBatchSamplerWrapper from torch.utils.data import DataLoader, Dataset diff --git a/nemo/lightning/fabric/conversion.py b/nemo/lightning/fabric/conversion.py index 9ad713ec5261..d1c7affe3f40 100644 --- a/nemo/lightning/fabric/conversion.py +++ b/nemo/lightning/fabric/conversion.py @@ -15,10 +15,10 @@ from functools import singledispatch from typing import Any, TypeVar -from lightning_fabric import plugins as fl_plugins -from lightning_fabric import strategies as fl_strategies -from pytorch_lightning import plugins as pl_plugins -from pytorch_lightning import strategies as pl_strategies +from lightning.fabric import plugins as fl_plugins +from lightning.fabric import strategies as fl_strategies +from lightning.pytorch import plugins as pl_plugins +from lightning.pytorch import strategies as pl_strategies T = TypeVar('T') FabricT = TypeVar('FabricT') @@ -39,8 +39,8 @@ def to_fabric(obj: Any) -> Any: NotImplementedError: If no converter is registered for the object's type. Example: - >>> from pytorch_lightning.strategies import Strategy as PLStrategy - >>> from lightning_fabric.strategies import Strategy as FabricStrategy + >>> from lightning.pytorch.strategies import Strategy as PLStrategy + >>> from lightning.fabric.strategies import Strategy as FabricStrategy >>> from nemo.lightning.fabric.conversion import to_fabric >>> >>> # Define a custom PyTorch Lightning strategy @@ -70,7 +70,7 @@ def to_fabric(obj: Any) -> Any: f"No Fabric converter registered for {type(obj).__name__}. " f"To register a new conversion, use the @to_fabric.register decorator:\n\n" f"from nemo.lightning.fabric.conversion import to_fabric\n" - f"from lightning_fabric import strategies as fl_strategies\n\n" + f"from lightning.fabric import strategies as fl_strategies\n\n" f"@to_fabric.register({type(obj).__name__})\n" f"def _{type(obj).__name__.lower()}_converter(obj: {type(obj).__name__}) -> fl_strategies.Strategy:\n" f" return fl_strategies.SomeStrategy(\n" diff --git a/nemo/lightning/fabric/fabric.py b/nemo/lightning/fabric/fabric.py index 60eb518a1e42..7d604de749d6 100644 --- a/nemo/lightning/fabric/fabric.py +++ b/nemo/lightning/fabric/fabric.py @@ -17,8 +17,8 @@ from typing import TYPE_CHECKING, Optional, Protocol, Sequence, Type, TypeVar, Union, runtime_checkable import fiddle as fdl -import lightning_fabric as lb -import pytorch_lightning as pl +import lightning.fabric as lb +import lightning.pytorch as pl from torch import nn from typing_extensions import Self, override diff --git a/nemo/lightning/fabric/plugins.py b/nemo/lightning/fabric/plugins.py index 723b48b6b357..58bf5f5ca9f9 100644 --- a/nemo/lightning/fabric/plugins.py +++ b/nemo/lightning/fabric/plugins.py @@ -16,7 +16,7 @@ from typing import TYPE_CHECKING, Any, Generator, Literal, TypeVar import torch -from lightning_fabric.plugins.precision import MixedPrecision +from lightning.fabric.plugins.precision import MixedPrecision from torch import nn from torch.optim import Optimizer diff --git a/nemo/lightning/fabric/strategies.py b/nemo/lightning/fabric/strategies.py index 575f69a58caf..30a03504060f 100644 --- a/nemo/lightning/fabric/strategies.py +++ b/nemo/lightning/fabric/strategies.py @@ -29,21 +29,21 @@ ) import torch -from lightning_fabric.accelerators import CPUAccelerator -from lightning_fabric.accelerators.accelerator import Accelerator -from lightning_fabric.plugins.collectives.torch_collective import default_pg_timeout -from lightning_fabric.plugins.environments.cluster_environment import ClusterEnvironment -from lightning_fabric.plugins.io.checkpoint_io import CheckpointIO -from lightning_fabric.plugins.precision import Precision -from lightning_fabric.strategies import DDPStrategy -from lightning_fabric.strategies.strategy import _validate_keys_for_strict_loading -from lightning_fabric.utilities.types import _PATH, _Stateful +from lightning.fabric.accelerators import CPUAccelerator +from lightning.fabric.accelerators.accelerator import Accelerator +from lightning.fabric.plugins.collectives.torch_collective import default_pg_timeout +from lightning.fabric.plugins.environments.cluster_environment import ClusterEnvironment +from lightning.fabric.plugins.io.checkpoint_io import CheckpointIO +from lightning.fabric.plugins.precision import Precision +from lightning.fabric.strategies import DDPStrategy +from lightning.fabric.strategies.strategy import _validate_keys_for_strict_loading +from lightning.fabric.utilities.types import _PATH, _Stateful +from lightning.pytorch import LightningDataModule +from lightning.pytorch.loops.fetchers import _DataFetcher +from lightning.pytorch.plugins.io.wrapper import _WrappingCheckpointIO +from lightning.pytorch.utilities.combined_loader import CombinedLoader from megatron.core.distributed import DistributedDataParallelConfig from megatron.core.optimizer import OptimizerConfig -from pytorch_lightning import LightningDataModule -from pytorch_lightning.loops.fetchers import _DataFetcher -from pytorch_lightning.plugins.io.wrapper import _WrappingCheckpointIO -from pytorch_lightning.utilities.combined_loader import CombinedLoader from torch import Tensor, nn from torch.distributed.algorithms.ddp_comm_hooks.debugging_hooks import noop_hook from torch.nn import Module diff --git a/nemo/lightning/io/api.py b/nemo/lightning/io/api.py index be9372f2e79b..869ec6e613cb 100644 --- a/nemo/lightning/io/api.py +++ b/nemo/lightning/io/api.py @@ -16,7 +16,7 @@ from typing import Callable, Optional, Type, overload import fiddle as fdl -import pytorch_lightning as pl +import lightning.pytorch as pl from nemo.lightning.io.mixin import ConnectorMixin, ConnT, ModelConnector, load from nemo.lightning.io.pl import TrainerContext diff --git a/nemo/lightning/io/connector.py b/nemo/lightning/io/connector.py index e699f15565bd..a38be6ee8f0a 100644 --- a/nemo/lightning/io/connector.py +++ b/nemo/lightning/io/connector.py @@ -18,9 +18,9 @@ from pathlib import Path, PosixPath, PurePath, WindowsPath from typing import Generic, Optional, Tuple, TypeVar -import pytorch_lightning as pl +import lightning.pytorch as pl from filelock import FileLock, Timeout -from pytorch_lightning.trainer.states import TrainerFn +from lightning.pytorch.trainer.states import TrainerFn from nemo.lightning.ckpt_utils import ckpt_to_context_subdir diff --git a/nemo/lightning/io/pl.py b/nemo/lightning/io/pl.py index 10ed52b136c2..f2c70034fd50 100644 --- a/nemo/lightning/io/pl.py +++ b/nemo/lightning/io/pl.py @@ -17,12 +17,12 @@ from pathlib import Path from typing import Any, Callable, Dict, Generic, Optional, TypeVar, Union -import pytorch_lightning as pl +import lightning.pytorch as pl import torch -from lightning_fabric.plugins import CheckpointIO -from lightning_fabric.plugins.io.checkpoint_io import CheckpointIO -from lightning_fabric.utilities.cloud_io import get_filesystem -from lightning_fabric.utilities.types import _PATH +from lightning.fabric.plugins import CheckpointIO +from lightning.fabric.plugins.io.checkpoint_io import CheckpointIO +from lightning.fabric.utilities.cloud_io import get_filesystem +from lightning.fabric.utilities.types import _PATH from megatron.core.dist_checkpointing.serialization import ( get_default_load_sharded_strategy, get_default_save_sharded_strategy, @@ -155,9 +155,6 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_optio checkpoint_dir = ckpt_to_weights_subdir(path, is_saving=True) fs = get_filesystem(checkpoint_dir) - if fs.isdir(checkpoint_dir) and dist_checkpointing.check_is_distributed_checkpoint(checkpoint_dir): - logging.info(f'Distributed checkpoint at path {checkpoint_dir} already exists, skipping saving') - return fs.makedirs(checkpoint_dir, exist_ok=True) validate_sharding_integrity = not (self.validated_consistency and self.assume_constant_structure) @@ -173,7 +170,11 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_optio @override def load_checkpoint( - self, path: _PATH, sharded_state_dict=None, map_location: Optional[Callable] = None + self, + path: _PATH, + sharded_state_dict=None, + map_location: Optional[Callable] = None, + strict: Optional['StrictHandling'] | bool = None, ) -> Dict[str, Any]: """Loads checkpoint using :func:`torch.load`, with additional handling for ``fsspec`` remote loading of files. @@ -190,6 +191,7 @@ def load_checkpoint( """ from megatron.core import dist_checkpointing + from megatron.core.dist_checkpointing.validation import StrictHandling if map_location is not None: raise ValueError("`map_location` argument is not supported for `MegatronCheckpointIO.load_checkpoint`.") @@ -223,8 +225,21 @@ def load_checkpoint( if sharded_strategy is not None: logging.info(f'Using {sharded_strategy} dist-ckpt load strategy.') + if isinstance(strict, bool): + # For backward-compatibility reasons and a bug in MCore (strict check not applied to factories) + # we must apply a simple strict check here. + if not strict: + sharded_state_dict = self.adjust_non_strict_load(path, sharded_state_dict) + strict = StrictHandling.ASSUME_OK_UNEXPECTED if strict else StrictHandling.LOG_ALL + if strict is None: + # Default behavior + strict = StrictHandling.ASSUME_OK_UNEXPECTED + checkpoint = dist_checkpointing.load( - sharded_state_dict=sharded_state_dict, checkpoint_dir=str(path), sharded_strategy=sharded_strategy + sharded_state_dict=sharded_state_dict, + checkpoint_dir=str(path), + sharded_strategy=sharded_strategy, + strict=strict, ) checkpoint = _fix_tensors_device(checkpoint) @@ -287,6 +302,34 @@ def save_sharded_strategy(self) -> 'SaveShardedStrategy': self._save_sharded_strategy = self._determine_dist_ckpt_save_strategy() return self._save_sharded_strategy + def adjust_non_strict_load(self, path: _PATH, sharded_state_dict: Dict[str, Any]): + from megatron.core import dist_checkpointing + from megatron.core.dist_checkpointing.dict_utils import extract_matching_values + from megatron.core.dist_checkpointing.mapping import ShardedBase + + ckpt_sharded_metadata = dist_checkpointing.load_tensors_metadata(path) + loaded_keys = [] + missing_keys = [] + unexpected_keys = [] + + def should_remove_missing_sharded_base(x: Any): + if isinstance(x, ShardedBase): + if x.key in ckpt_sharded_metadata: + loaded_keys.append(x.key) + return False + else: + unexpected_keys.append(x.key) + return True + return False + + _, sharded_state_dict = extract_matching_values(sharded_state_dict, should_remove_missing_sharded_base) + logging.info(f'The following keys are not in the checkpoint and will not be loaded: {unexpected_keys}') + + # TODO: compute missing_keys by: + # 1. all_gather_object of loaded_keys + # 2. missing_keys = ckpt_sharded_metadata.keys() - loaded_keys + return sharded_state_dict + def _fix_tensors_device(ckpt: Dict) -> Dict: """Ensure checkpoint tensors are on the correct device.""" diff --git a/nemo/lightning/io/state.py b/nemo/lightning/io/state.py index 6632768ec8dd..f2c26aa4d495 100644 --- a/nemo/lightning/io/state.py +++ b/nemo/lightning/io/state.py @@ -242,7 +242,12 @@ def __call__(self, ctx: TransformCTX) -> TransformCTX: source_matches_dict = {k: _match_keys(list(source_dict.keys()), v) for k, v in source_key_dict.items()} target_matches = _match_keys(list(target_dict.keys()), target_key) param_names = list(filter(lambda x: x in source_matches_dict, fn_params)) - for layer_names_group in zip(*([source_matches_dict[v] for v in param_names] + [target_matches])): + source_matches = [ + source_matches_dict[v] if source_matches_dict[v].ndim > 0 else [source_matches_dict[v].item()] + for v in param_names + ] + target_matches = [target_matches if target_matches.ndim > 0 else [target_matches.item()]] + for layer_names_group in zip(*(source_matches + target_matches)): # Wrap in a list if it's a single layer (ie non-expert) if isinstance(layer_names_group[0], str): layer_names_group = [[x] for x in layer_names_group] diff --git a/nemo/lightning/megatron_parallel.py b/nemo/lightning/megatron_parallel.py index 6a3138b1da29..0f84f3be0a23 100644 --- a/nemo/lightning/megatron_parallel.py +++ b/nemo/lightning/megatron_parallel.py @@ -42,13 +42,12 @@ import torch import torch.distributed +from lightning.pytorch.utilities import move_data_to_device from megatron.core import parallel_state from megatron.core.distributed import DistributedDataParallel as McoreDDP from megatron.core.distributed import DistributedDataParallelConfig from megatron.core.optimizer import OptimizerConfig from megatron.core.transformer.transformer_config import TransformerConfig -from pytorch_lightning.trainer.states import TrainerFn -from pytorch_lightning.utilities import move_data_to_device from torch import Tensor, nn from typing_extensions import override @@ -58,7 +57,7 @@ STEP_OUTPUT = Optional[Union[Tensor, Mapping[str, Any]]] if TYPE_CHECKING: - import pytorch_lightning as pl + import lightning.pytorch as pl @runtime_checkable @@ -836,7 +835,7 @@ def add(self, *callbacks) -> "CallbackConnector": """ _pl_callback = None try: - import pytorch_lightning as pl + import lightning.pytorch as pl _pl_callback = pl.Callback except ImportError: diff --git a/nemo/lightning/nemo_logger.py b/nemo/lightning/nemo_logger.py index a901a3a8842a..79f622ebc6a8 100644 --- a/nemo/lightning/nemo_logger.py +++ b/nemo/lightning/nemo_logger.py @@ -18,10 +18,10 @@ from pathlib import Path from typing import List, Optional, Union -import lightning_fabric as fl -import pytorch_lightning as pl -from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint as PTLModelCheckpoint -from pytorch_lightning.loggers import Logger, TensorBoardLogger, WandbLogger +import lightning.fabric as fl +import lightning.pytorch as pl +from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint as PTLModelCheckpoint +from lightning.pytorch.loggers import Logger, TensorBoardLogger, WandbLogger from nemo.lightning.io.mixin import IOMixin from nemo.lightning.pytorch.callbacks import ModelCheckpoint diff --git a/nemo/lightning/pytorch/accelerate/__init__.py b/nemo/lightning/pytorch/accelerate/__init__.py new file mode 100644 index 000000000000..d9155f923f18 --- /dev/null +++ b/nemo/lightning/pytorch/accelerate/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/lightning/pytorch/accelerate/transformer_engine.py b/nemo/lightning/pytorch/accelerate/transformer_engine.py new file mode 100755 index 000000000000..8e621352d099 --- /dev/null +++ b/nemo/lightning/pytorch/accelerate/transformer_engine.py @@ -0,0 +1,123 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from types import MethodType + +import torch +from nemo.utils import logging +from nemo.utils.import_utils import safe_import_from + +te, HAVE_TE = safe_import_from("transformer_engine", "pytorch") + + +def te_accelerate(model, fp8_autocast=False): + """ + Replaces original model layers with TE's accelerated layers + Args: + model: HF model + fp8_autocast (bool): apply autocast or not + """ + + if not HAVE_TE: + logging.warning("Transformer Engine is not available and the module replacements " "will not be applied.") + else: + _apply_basic_module_replacement(model) + if fp8_autocast: + apply_fp8_autocast(model) + + +@torch.no_grad +def _apply_basic_module_replacement(model): + for name, module in model.named_children(): + if isinstance(module, torch.nn.Linear): + has_bias = module.bias is not None + if any(p % 16 != 0 for p in module.weight.shape): + continue + te_module = te.Linear( + module.in_features, module.out_features, bias=has_bias, params_dtype=module.weight.dtype + ) + te_module.weight.copy_(module.weight) + if has_bias: + te_module.bias.copy_(module.bias) + + setattr(model, name, te_module) + elif isinstance(module, torch.nn.LayerNorm): + te_module = te.LayerNorm(module.normalized_shape[0], eps=module.eps, params_dtype=module.weight.dtype) + te_module.weight.copy_(module.weight) + te_module.bias.copy_(module.bias) + setattr(model, name, te_module) + elif isinstance(module, torch.nn.RMSNorm): + te_module = te.RMSNorm(module.normalized_shape[0], eps=module.eps, dtype=module.weight.dtype) + te_module.weight.copy_(module.weight) + te_module.bias.copy_(module.bias) + setattr(model, name, te_module) + else: + _apply_basic_module_replacement(module) + + +def is_te_accelerated(model): + """ + Checks whether model has TE layers or not + Args: + model: HF model + """ + + if not HAVE_TE: + logging.warning("Transformer Engine is not available.") + return False + else: + for name, module in model.named_modules(): + if isinstance(module, (te.LayerNorm, te.Linear, te.TransformerLayer)): + return True + + return False + + +def apply_fp8_autocast(model, fp8_recipe_handler=None): + """ + Applies TE's autocast + Args: + model: HF model + fp8_recipe_handler: fpt handler + """ + + if not HAVE_TE: + logging.warning("Transformer Engine is not available and the FP8 autocast " "will not be applied.") + else: + import transformer_engine.common.recipe as te_recipe + + kwargs = fp8_recipe_handler.to_kwargs() if fp8_recipe_handler is not None else {} + if "fp8_format" in kwargs: + kwargs["fp8_format"] = getattr(te_recipe.Format, kwargs["fp8_format"]) + use_during_eval = kwargs.pop("use_autocast_during_eval", False) + fp8_recipe = te_recipe.DelayedScaling(**kwargs) + new_forward = _contextual_fp8_autocast(model.forward, fp8_recipe, use_during_eval) + + if hasattr(model.forward, "__func__"): + model.forward = MethodType(new_forward, model) + else: + model.forward = new_forward + + +def _contextual_fp8_autocast(model_forward, fp8_recipe, use_during_eval=False): + from transformer_engine.pytorch import fp8_autocast + + def forward(self, *args, **kwargs): + enabled = use_during_eval or self.training + with fp8_autocast(enabled=enabled, fp8_recipe=fp8_recipe): + return model_forward(*args, **kwargs) + + forward.__wrapped__ = model_forward + + return forward diff --git a/nemo/lightning/pytorch/callbacks/__init__.py b/nemo/lightning/pytorch/callbacks/__init__.py old mode 100644 new mode 100755 index 8da1a50dcd64..031f027e63b2 --- a/nemo/lightning/pytorch/callbacks/__init__.py +++ b/nemo/lightning/pytorch/callbacks/__init__.py @@ -16,6 +16,7 @@ from nemo.lightning.pytorch.callbacks.debugging import ParameterDebugger from nemo.lightning.pytorch.callbacks.garbage_collection import GarbageCollectionCallback from nemo.lightning.pytorch.callbacks.memory_profiler import MemoryProfileCallback +from nemo.lightning.pytorch.callbacks.model_callback import ModelCallback from nemo.lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint from nemo.lightning.pytorch.callbacks.model_transform import ModelTransform from nemo.lightning.pytorch.callbacks.nsys import NsysCallback @@ -36,4 +37,5 @@ "DdpParityChecker", "GarbageCollectionCallback", "ParameterDebugger", + "ModelCallback", ] diff --git a/nemo/lightning/pytorch/callbacks/ddp_parity_checker.py b/nemo/lightning/pytorch/callbacks/ddp_parity_checker.py index 391666fb8f32..320140d76f3a 100644 --- a/nemo/lightning/pytorch/callbacks/ddp_parity_checker.py +++ b/nemo/lightning/pytorch/callbacks/ddp_parity_checker.py @@ -15,8 +15,8 @@ from functools import cache import torch +from lightning.pytorch.callbacks.callback import Callback from megatron.core.utils import check_param_hashes_across_dp_replicas -from pytorch_lightning.callbacks.callback import Callback from nemo.lightning import io from nemo.utils import logging diff --git a/nemo/lightning/pytorch/callbacks/debugging.py b/nemo/lightning/pytorch/callbacks/debugging.py index 5f6e722ef89b..135e8e486837 100644 --- a/nemo/lightning/pytorch/callbacks/debugging.py +++ b/nemo/lightning/pytorch/callbacks/debugging.py @@ -14,9 +14,9 @@ from typing import Callable, Dict, List, Optional, Union -import pytorch_lightning as pl +import lightning.pytorch as pl import torch -from pytorch_lightning.callbacks import Callback +from lightning.pytorch.callbacks import Callback from nemo.lightning.pytorch.optim.megatron import MegatronOptimizerModule from nemo.utils import logging diff --git a/nemo/lightning/pytorch/callbacks/garbage_collection.py b/nemo/lightning/pytorch/callbacks/garbage_collection.py index ba4d378ee893..90e122f6d3e4 100644 --- a/nemo/lightning/pytorch/callbacks/garbage_collection.py +++ b/nemo/lightning/pytorch/callbacks/garbage_collection.py @@ -15,7 +15,7 @@ import gc from typing import Any -import pytorch_lightning as pl +import lightning.pytorch as pl from nemo.utils import logging diff --git a/nemo/lightning/pytorch/callbacks/megatron_comm_overlap.py b/nemo/lightning/pytorch/callbacks/megatron_comm_overlap.py index fc4312e2ff84..172aeaeb855d 100644 --- a/nemo/lightning/pytorch/callbacks/megatron_comm_overlap.py +++ b/nemo/lightning/pytorch/callbacks/megatron_comm_overlap.py @@ -13,12 +13,12 @@ # limitations under the License. from dataclasses import asdict, dataclass, fields -import pytorch_lightning as pl +import lightning.pytorch as pl +from lightning.pytorch.callbacks.callback import Callback from megatron.core import ModelParallelConfig from megatron.core.distributed import DistributedDataParallelConfig from megatron.core.optimizer import OptimizerConfig -from pytorch_lightning.callbacks.callback import Callback from nemo.collections.llm.recipes.tp_overlap_configs.userbuffers import TransformerLayerTPOverlapCfg from nemo.lightning.pytorch.strategies.megatron_strategy import MegatronStrategy, ParallelismConfig @@ -43,6 +43,7 @@ class _CommOverlapConfig: # Tensor parallel communication overlap (experimental) tp_comm_overlap: bool = None tp_comm_overlap_cfg: dict = None + tp_comm_bootstrap_backend: str = None # Pipeline parallel communication overlap overlap_p2p_comm: bool = None batch_p2p_comm: bool = None @@ -88,6 +89,7 @@ def __init__( self, tp_comm_overlap: bool = None, tp_comm_overlap_cfg: TransformerLayerTPOverlapCfg = None, + tp_comm_bootstrap_backend: str = None, overlap_p2p_comm: bool = None, batch_p2p_comm: bool = None, overlap_grad_reduce: bool = None, @@ -102,6 +104,7 @@ def __init__( self.user_comm_overlap_cfg = _CommOverlapConfig( tp_comm_overlap=tp_comm_overlap, tp_comm_overlap_cfg=tp_comm_overlap_cfg, + tp_comm_bootstrap_backend=tp_comm_bootstrap_backend, overlap_p2p_comm=overlap_p2p_comm, batch_p2p_comm=batch_p2p_comm, overlap_grad_reduce=overlap_grad_reduce, @@ -114,6 +117,7 @@ def __init__( ) self.tp_comm_overlap_cfg = None + self.tp_comm_bootstrap_backend = None self.need_tp_overlap_ub_init = False def _get_model_comm_overlap_cfgs( @@ -129,6 +133,7 @@ def _get_model_comm_overlap_cfgs( # Optimizations disabled by default, can be overriden by user comm_overlap_cfg.tp_comm_overlap = False comm_overlap_cfg.tp_comm_overlap_cfg = None + comm_overlap_cfg.tp_comm_bootstrap_backend = None comm_overlap_cfg.defer_embedding_wgrad_compute = False comm_overlap_cfg.wgrad_deferral_limit = -1 @@ -216,6 +221,7 @@ def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str) if trainer.model.config.tp_comm_overlap: self.tp_comm_overlap_cfg = comm_overlap_cfg.tp_comm_overlap_cfg + self.tp_comm_bootstrap_backend = comm_overlap_cfg.tp_comm_bootstrap_backend self.need_tp_overlap_ub_init = True # Data parallel overlap is only available with the Megatron DDP and Distributed optimizer @@ -258,6 +264,7 @@ def _init_te_userbuffers(self, model_parallel_cfg: ModelParallelConfig): tp_size=parallel_state.get_tensor_model_parallel_world_size(), use_fp8=fp8, ub_cfgs=self.tp_comm_overlap_cfg, + bootstrap_backend=self.tp_comm_bootstrap_backend, ) except Exception as error: raise Exception(f"Tensor parallel overlap: userbuffer initialization failed with {error}") diff --git a/nemo/lightning/pytorch/callbacks/memory_profiler.py b/nemo/lightning/pytorch/callbacks/memory_profiler.py index 5b2ee1d46e11..2813bd141a7a 100644 --- a/nemo/lightning/pytorch/callbacks/memory_profiler.py +++ b/nemo/lightning/pytorch/callbacks/memory_profiler.py @@ -15,7 +15,7 @@ import os import torch -from pytorch_lightning.callbacks.callback import Callback +from lightning.pytorch.callbacks.callback import Callback from torch.utils.viz._cycles import warn_tensor_cycles from nemo.lightning import io diff --git a/nemo/lightning/pytorch/callbacks/model_callback.py b/nemo/lightning/pytorch/callbacks/model_callback.py new file mode 100755 index 000000000000..0625e3f006c5 --- /dev/null +++ b/nemo/lightning/pytorch/callbacks/model_callback.py @@ -0,0 +1,145 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import inspect +from typing import Callable, Optional +from lightning.pytorch.callbacks import LambdaCallback + + +class ModelCallback(LambdaCallback): + """ + A callback that extends LambdaCallback to intelligently handle function parameters. + Functions can take either (trainer, pl_module), just (pl_module), or just (trainer). + + Supported parameter names: + - trainer, pl_trainer + - model, pl_model, pl_module, module + + Example: + >>> # Using with torch.compile + >>> callback = ModelCallback(on_train_start=torch.compile) + >>> + >>> # Using with thunder_compile + >>> callback = ModelCallback(on_train_start=thunder_compile) + >>> + >>> # Mix different callbacks + >>> callback = ModelCallback( + ... on_train_start=lambda model: torch.compile(model), + ... on_fit_start=lambda trainer, model: print(f"Starting fit with {model}") + ... ) + """ + + TRAINER_PARAMS = {'trainer', 'pl_trainer'} + MODEL_PARAMS = {'model', 'pl_model', 'pl_module', 'module'} + + def __init__( + self, + setup: Optional[Callable] = None, + teardown: Optional[Callable] = None, + on_fit_start: Optional[Callable] = None, + on_fit_end: Optional[Callable] = None, + on_sanity_check_start: Optional[Callable] = None, + on_sanity_check_end: Optional[Callable] = None, + on_train_batch_start: Optional[Callable] = None, + on_train_batch_end: Optional[Callable] = None, + on_train_epoch_start: Optional[Callable] = None, + on_train_epoch_end: Optional[Callable] = None, + on_validation_epoch_start: Optional[Callable] = None, + on_validation_epoch_end: Optional[Callable] = None, + on_test_epoch_start: Optional[Callable] = None, + on_test_epoch_end: Optional[Callable] = None, + on_validation_batch_start: Optional[Callable] = None, + on_validation_batch_end: Optional[Callable] = None, + on_test_batch_start: Optional[Callable] = None, + on_test_batch_end: Optional[Callable] = None, + on_train_start: Optional[Callable] = None, + on_train_end: Optional[Callable] = None, + on_validation_start: Optional[Callable] = None, + on_validation_end: Optional[Callable] = None, + on_test_start: Optional[Callable] = None, + on_test_end: Optional[Callable] = None, + on_exception: Optional[Callable] = None, + on_save_checkpoint: Optional[Callable] = None, + on_load_checkpoint: Optional[Callable] = None, + on_before_backward: Optional[Callable] = None, + on_after_backward: Optional[Callable] = None, + on_before_optimizer_step: Optional[Callable] = None, + on_before_zero_grad: Optional[Callable] = None, + on_predict_start: Optional[Callable] = None, + on_predict_end: Optional[Callable] = None, + on_predict_batch_start: Optional[Callable] = None, + on_predict_batch_end: Optional[Callable] = None, + on_predict_epoch_start: Optional[Callable] = None, + on_predict_epoch_end: Optional[Callable] = None, + ): + # Create a dictionary of non-None callbacks + callbacks = { + name: self._wrap_func(func) + for name, func in locals().items() + if name != 'self' and name != '__class__' and func is not None + } + + super().__init__(**callbacks) + + def _get_param_type(self, param_name: str) -> Optional[str]: + """Determine if a parameter name refers to trainer or model.""" + param_name = param_name.lower() + if param_name in self.TRAINER_PARAMS: + return 'trainer' + if param_name in self.MODEL_PARAMS: + return 'model' + return None + + def _wrap_func(self, func: Callable) -> Callable: + """Wraps a function to handle parameter inspection and passing.""" + sig = inspect.signature(func) + params = sig.parameters + + def wrapped(trainer, pl_module, *args, **kwargs): + call_args = {} + + for param_name, param in params.items(): + param_type = self._get_param_type(param_name) + + if param_type == 'trainer': + call_args[param_name] = trainer + elif param_type == 'model': + call_args[param_name] = pl_module + else: + # If parameter name is not recognized, use position to determine + if len(params) == 1: + call_args[param_name] = pl_module + elif len(params) == 2: + if len(call_args) == 0: + call_args[param_name] = trainer + else: + call_args[param_name] = pl_module + else: + raise ValueError( + f"Unable to determine parameter mapping for '{param_name}'. " + f"Please use recognized parameter names: " + f"trainer/pl_trainer for trainer, " + f"model/pl_model/pl_module/module for model." + ) + + try: + return func(**call_args) + except TypeError as e: + raise TypeError( + f"Failed to call callback function {func.__name__ if hasattr(func, '__name__') else func}. " + f"Attempted to pass arguments: {call_args.keys()}. Error: {str(e)}" + ) from e + + return wrapped diff --git a/nemo/lightning/pytorch/callbacks/model_checkpoint.py b/nemo/lightning/pytorch/callbacks/model_checkpoint.py index b384976d82bd..455022b1ba44 100644 --- a/nemo/lightning/pytorch/callbacks/model_checkpoint.py +++ b/nemo/lightning/pytorch/callbacks/model_checkpoint.py @@ -17,14 +17,14 @@ import shutil from datetime import timedelta from pathlib import Path -from typing import Any, Dict, Iterable, List, Optional, Union +from typing import Any, Dict, Iterable, List, Literal, Optional, Union -import pytorch_lightning +import lightning.pytorch import torch from _weakref import proxy -from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint as PTLModelCheckpoint -from pytorch_lightning.callbacks.model_checkpoint import _is_local_file_protocol -from pytorch_lightning.utilities import rank_zero_info +from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint as PTLModelCheckpoint +from lightning.pytorch.callbacks.model_checkpoint import _is_local_file_protocol +from lightning.pytorch.utilities import rank_zero_info from nemo.lightning.ckpt_utils import ckpt_to_dir from nemo.lightning.io.pl import TrainerContext @@ -63,7 +63,7 @@ def __init__( self, monitor: Optional[str] = "val_loss", verbose: bool = True, - save_last: Optional[bool] = True, + save_last: Optional[Union[bool, Literal["link"]]] = True, save_top_k: int = 3, save_weights_only: bool = False, ## TODO: check support mode: str = "min", @@ -312,7 +312,7 @@ def _del_model_without_trainer(self, filepath: str) -> None: if torch.distributed.is_initialized(): torch.distributed.barrier() - def _ema_callback(self, trainer: 'pytorch_lightning.Trainer'): + def _ema_callback(self, trainer: 'lightning.pytorch.Trainer'): from nemo.collections.common.callbacks import EMA ema_callback = None @@ -393,7 +393,7 @@ def remove_checkpoint_unfinished_marker(checkpoint_path: Union[Path, str], barri except: return - def file_exists(self, filepath: str, trainer: "pytorch_lightning.Trainer", check_dist_ckpt: bool = True) -> bool: + def file_exists(self, filepath: str, trainer: "lightning.pytorch.Trainer", check_dist_ckpt: bool = True) -> bool: """Checks if a file or a file without a suffix (distributed checkpoint) exists.""" exists = self._fs.exists(filepath) or (check_dist_ckpt and self._fs.exists(ckpt_to_dir(filepath))) return trainer.strategy.broadcast(exists) @@ -432,7 +432,7 @@ def _link_checkpoint(self, trainer: "pl.Trainer", filepath: str, linkpath: str, linkpath = ckpt_to_dir(linkpath) super()._link_checkpoint(trainer, filepath, linkpath) - def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str) -> None: + def _save_checkpoint(self, trainer: 'lightning.pytorch.Trainer', filepath: str) -> None: from nemo.utils.get_rank import is_global_rank_zero # barrier_after=True, so all ranks continue after the unfinished checkpoint marker is placed. @@ -499,7 +499,7 @@ def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str) finalize_fn() def _get_finalize_save_checkpoint_callback( - self, trainer: 'pytorch_lightning.Trainer', filepath: str, global_step: int + self, trainer: 'lightning.pytorch.Trainer', filepath: str, global_step: int ): """Creates a callback that can be used to finalize async (and sync) ckpt saves.""" @@ -534,7 +534,7 @@ def _cb(): return _cb - def _remove_checkpoint(self, trainer: "pytorch_lightning.Trainer", filepath: str, override_async=False) -> None: + def _remove_checkpoint(self, trainer: "lightning.pytorch.Trainer", filepath: str, override_async=False) -> None: """Performs checkpoint removal. With async save, `self._remove_checkpoint` is called before the checkpoint diff --git a/nemo/lightning/pytorch/callbacks/model_transform.py b/nemo/lightning/pytorch/callbacks/model_transform.py index 64602b501ac3..b3c3310aa30f 100644 --- a/nemo/lightning/pytorch/callbacks/model_transform.py +++ b/nemo/lightning/pytorch/callbacks/model_transform.py @@ -15,7 +15,7 @@ from functools import wraps from typing import Any, Callable, Optional, TypeVar -import pytorch_lightning as pl +import lightning.pytorch as pl from torch import nn from nemo.utils import logging @@ -85,7 +85,7 @@ def _maybe_apply_transform(self, trainer): def apply_transform(self, trainer): self.model_transform(trainer.model) - from pytorch_lightning.utilities import model_summary + from lightning.pytorch.utilities import model_summary logging.info( f"After applying model_transform:\n" f"{model_summary.summarize(trainer.lightning_module, max_depth=1)}" diff --git a/nemo/lightning/pytorch/callbacks/moe_token_drop.py b/nemo/lightning/pytorch/callbacks/moe_token_drop.py index 10483dca5096..b0c7ff7999eb 100644 --- a/nemo/lightning/pytorch/callbacks/moe_token_drop.py +++ b/nemo/lightning/pytorch/callbacks/moe_token_drop.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl +from lightning.pytorch.callbacks.callback import Callback from megatron.core import ModelParallelConfig -from pytorch_lightning.callbacks.callback import Callback from nemo.lightning.pytorch.strategies.megatron_strategy import MegatronStrategy diff --git a/nemo/lightning/pytorch/callbacks/nsys.py b/nemo/lightning/pytorch/callbacks/nsys.py index 2a5707d3166c..2a6bc3668b94 100644 --- a/nemo/lightning/pytorch/callbacks/nsys.py +++ b/nemo/lightning/pytorch/callbacks/nsys.py @@ -15,7 +15,7 @@ from typing import List, Optional import torch -from pytorch_lightning.callbacks.callback import Callback +from lightning.pytorch.callbacks.callback import Callback from nemo.utils import logging from nemo.utils.get_rank import get_rank @@ -74,10 +74,14 @@ def on_train_batch_start(self, trainer, pl_module, batch, batch_idx: int) -> Opt """ device = trainer.strategy.root_device - current_step = trainer.strategy.current_epoch_step + try: + # Not all strategies have this. e.g.: + # AttributeError: 'SingleDeviceStrategy' object has no attribute 'current_epoch_step' + current_step = trainer.strategy.current_epoch_step + except AttributeError: + current_step = self._nsys_profile_start_step if device.type == 'cuda': if current_step == self._nsys_profile_start_step and get_rank() in self._nsys_profile_ranks: - logging.info("====== Start nsys profiling ======") torch.cuda.cudart().cudaProfilerStart() if self._nsys_profile_gen_shape: torch.autograd.profiler.emit_nvtx(record_shapes=True).__enter__() @@ -91,9 +95,11 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx: int) """ device = trainer.strategy.root_device - current_step = trainer.strategy.current_epoch_step + try: + current_step = trainer.strategy.current_epoch_step + except AttributeError: + current_step = self._nsys_profile_end_step if device.type == 'cuda': if current_step == self._nsys_profile_end_step and get_rank() in self._nsys_profile_ranks: - logging.info("====== End nsys profiling ======") torch.cuda.cudart().cudaProfilerStop() torch.autograd.profiler.emit_nvtx().__exit__(None, None, None) diff --git a/nemo/lightning/pytorch/callbacks/peft.py b/nemo/lightning/pytorch/callbacks/peft.py index 5336615a4a38..c94e1f8e003e 100644 --- a/nemo/lightning/pytorch/callbacks/peft.py +++ b/nemo/lightning/pytorch/callbacks/peft.py @@ -18,12 +18,12 @@ from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple -import pytorch_lightning as pl +import lightning.pytorch as pl import torch import torch.nn as nn -from lightning_fabric.utilities.types import _PATH -from pytorch_lightning.plugins.io.wrapper import _WrappingCheckpointIO -from pytorch_lightning.trainer.states import TrainerFn +from lightning.fabric.utilities.types import _PATH +from lightning.pytorch.plugins.io.wrapper import _WrappingCheckpointIO +from lightning.pytorch.trainer.states import TrainerFn from typing_extensions import override from nemo.lightning.ckpt_utils import ADAPTER_META_FILENAME @@ -105,6 +105,8 @@ def __call__(self, model: nn.Module) -> nn.Module: else: model.walk(self.transform) + if hasattr(model, "trainer") and model.trainer.state.fn != TrainerFn.FITTING: + self.freeze_model(model) return model def freeze_model(self, model: nn.Module) -> None: @@ -126,9 +128,11 @@ def freeze_model(self, model: nn.Module) -> None: model.module.freeze() else: model.freeze() - model.train(mode=True) + if hasattr(model, "trainer") and model.trainer.state.fn == TrainerFn.FITTING: + model.train(mode=True) def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str) -> None: + """PTL callback setup function.""" from nemo.lightning.pytorch.strategies.utils import create_checkpoint_io super().setup(trainer, pl_module, stage=stage) @@ -160,6 +164,13 @@ def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str) trainer.strategy._setup_optimizers = False def apply_transform(self, trainer): + """ + This function does the following: + 1. Apply PEFT model transform. + 2. Set up model parallel and optimizer, which were skipped in setup + 3. Load weights and optimizer state dict + 4. Set up `finalize_model_grads` from mcore. + """ super().apply_transform(trainer) self.trainable_params = set( name for name, param in trainer.lightning_module.named_parameters() if param.requires_grad @@ -205,6 +216,10 @@ def apply_transform(self, trainer): ) def adapter_key_filter(self, key: str) -> bool: + """ + Given a key in the state dict, return whether the key is an adapter (or base model). + This function can be subclassed in each PEFT method class. + """ return key in self.trainable_params or ".adapter." in key or key.endswith(".adapters") @@ -241,6 +256,36 @@ def __init__(self, to_wrap: nn.Module, adapter: nn.Module): self.to_wrap = to_wrap self.adapter = adapter + def base_linear_forward(self, x): + """ + Run the forward method of the linear module `to_wrap`. + Return a tuple of three elements: linear_output, bias, layernorm_output + + x -> [layernorm/identity] -> layernorm_output -> [linear] -> linear_output, bias + + layernorm_output is different from input x only when linear layer is LayerNormColumnParallelLinear. + """ + linear_output = self.to_wrap(x) + assert isinstance( + linear_output, tuple + ), f"{self.to_wrap} should return a tuple but instead returns {linear_output}" + """ Four cases for the wrapped module's return values + 1. nothing: (out, None) + 2. return_bias: (out, bias) + 2. return_layernorm_output: ((out, ln_out), None) + 3. both: (out, bias, ln_out) + """ + bias = None + layernorm_output = x + if len(linear_output) == 2: + linear_output, bias = linear_output + if isinstance(linear_output, tuple) and len(linear_output) == 2: + linear_output, layernorm_output = linear_output + elif len(linear_output) == 3: + linear_output, bias, layernorm_output = linear_output + + return linear_output, bias, layernorm_output + def state_dict(self, destination=None, prefix='', keep_vars=False): """Retrieve the state dictionary of the wrapped module and adapter. @@ -295,33 +340,38 @@ def sharded_state_dict( sharded_state_dict.update(self.adapter.sharded_state_dict(f"{prefix}adapter.", sharded_offsets, metadata)) return sharded_state_dict - def load_state_dict(self, state_dict, strict=True): - """Load a state dictionary into the wrapped module and adapter. - This method overrides the default load_state_dict behavior to handle - loading states for both the main module and the adapter. +class WrappedAdapterIO(_WrappingCheckpointIO, AsyncCompatibleCheckpointIO): + """ + A wrapper class for checkpoint I/O operations, specifically designed for PEFT (Parameter-Efficient Fine-Tuning). - Args: - state_dict (dict): The state dictionary to load. - strict (bool): Whether to strictly enforce that the keys in state_dict - match the keys returned by this module's state_dict() - function. Defaults to True. - """ - # Check if the 'adapters' key is present in the state_dict - if 'adapters' in state_dict: - adapter_state_dict = state_dict.pop('adapters') - else: - adapter_state_dict = {} + This class handles the complexities of saving and loading checkpoints for both initial PEFT training and resuming + PEFT training. It ensures that only the necessary adapter weights are saved and loaded, while also preserving the + base model weights. - # Load the main module state dict - self.to_wrap.load_state_dict(state_dict, strict) + **Usage:** - # Load the adapter module state dict if present - if adapter_state_dict: - self.adapter.load_state_dict(adapter_state_dict, strict) + 1. **Initial PEFT Training:** + - The class handles the saving of only adapter weights. + - Metadata about the base model checkpoint is stored for future reference. + 2. **PEFT Resume:** + - The class loads both base model and adapter weights. + - The previously stored metadata is used to locate the correct base model checkpoint. + + **Attributes:** + + - `peft`: The PEFT instance associated with the wrapped checkpoint I/O. + - `model_ckpt_path`: The path to the base model checkpoint. + - `adapter_ckpt_path`: The path to the adapter checkpoint. + Note that the paths are set by save/load functions and users do not need to set them. + + **Methods:** + + - `save_checkpoint`: Saves the adapter weights and metadata to the specified path. + - `load_checkpoint`: Loads the base model and adapter weights based on the specified path and metadata. + """ -class WrappedAdapterIO(_WrappingCheckpointIO, AsyncCompatibleCheckpointIO): peft: Optional[PEFT] = None model_ckpt_path: Optional[Path] = None adapter_ckpt_path: Optional[Path] = None @@ -355,7 +405,11 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_optio @override def load_checkpoint( - self, path: _PATH, sharded_state_dict=None, map_location: Optional[Callable] = None + self, + path: _PATH, + sharded_state_dict=None, + map_location: Optional[Callable] = None, + strict: Optional['StrictHandling'] | bool = None, ) -> Dict[str, Any]: """ ===================== @@ -374,7 +428,8 @@ def load_checkpoint( As such, this function will be entered twice during PEFT training resume. For the FIRST TIME this function is called by trainer._checkpoint_connector._restore_modules_and_callbacks. - `path = AdapterPath(, base_model_path=)`, and sharded_state_dict contains only base model weights + `path = AdapterPath(, base_model_path=)`, and sharded_state_dict contains only base + model weights For the SECOND TIME this function is called by PEFT.apply_transform (above, in the same file). `path = PosixPath()`, and sharded_state_dict contains only adapter weights. @@ -401,7 +456,7 @@ def load_checkpoint( self.model_ckpt_path = path # Note: this will include the Trainer-state of the model-checkpoint - model_ckpt = self.checkpoint_io.load_checkpoint(path, sharded_state_dict, map_location) + model_ckpt = self.checkpoint_io.load_checkpoint(path, sharded_state_dict, map_location, strict) if adapter_ckpt is not None: ## PEFT Resume, FIRST TIME adapter_ckpt['state_dict'].update(model_ckpt['state_dict']) diff --git a/nemo/lightning/pytorch/callbacks/preemption.py b/nemo/lightning/pytorch/callbacks/preemption.py index 69ac378ed698..98b59a9da0d0 100644 --- a/nemo/lightning/pytorch/callbacks/preemption.py +++ b/nemo/lightning/pytorch/callbacks/preemption.py @@ -18,8 +18,8 @@ from typing import Optional import torch -from pytorch_lightning.callbacks import Callback -from pytorch_lightning.trainer.trainer import Trainer +from lightning.pytorch.callbacks import Callback +from lightning.pytorch.trainer.trainer import Trainer from nemo.lightning.io.mixin import IOMixin from nemo.utils import logging diff --git a/nemo/lightning/pytorch/callbacks/progress_bar.py b/nemo/lightning/pytorch/callbacks/progress_bar.py index 6912c3fc57d4..f3c3c4555bac 100644 --- a/nemo/lightning/pytorch/callbacks/progress_bar.py +++ b/nemo/lightning/pytorch/callbacks/progress_bar.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from pytorch_lightning.callbacks.progress import TQDMProgressBar -from pytorch_lightning.callbacks.progress.tqdm_progress import _update_n +from lightning.pytorch.callbacks.progress import TQDMProgressBar +from lightning.pytorch.callbacks.progress.tqdm_progress import _update_n class MegatronProgressBar(TQDMProgressBar): diff --git a/nemo/lightning/pytorch/callbacks/progress_printer.py b/nemo/lightning/pytorch/callbacks/progress_printer.py index d32f7d70cbdd..12d05ed2950c 100644 --- a/nemo/lightning/pytorch/callbacks/progress_printer.py +++ b/nemo/lightning/pytorch/callbacks/progress_printer.py @@ -15,9 +15,9 @@ from collections import defaultdict from typing import Any +from lightning.pytorch.callbacks.progress import ProgressBar +from lightning.pytorch.utilities.types import STEP_OUTPUT from megatron.core.num_microbatches_calculator import get_num_microbatches -from pytorch_lightning.callbacks.progress import ProgressBar -from pytorch_lightning.utilities.types import STEP_OUTPUT from typing_extensions import override diff --git a/nemo/lightning/pytorch/optim/base.py b/nemo/lightning/pytorch/optim/base.py index 1d476142941a..fec3b7c118a4 100644 --- a/nemo/lightning/pytorch/optim/base.py +++ b/nemo/lightning/pytorch/optim/base.py @@ -17,8 +17,8 @@ from copy import deepcopy from typing import List, Optional -import pytorch_lightning as L -from pytorch_lightning.utilities.types import OptimizerLRScheduler +import lightning.pytorch as L +from lightning.pytorch.utilities.types import OptimizerLRScheduler from torch.optim import Optimizer from nemo.lightning.io.mixin import IOMixin diff --git a/nemo/lightning/pytorch/optim/megatron.py b/nemo/lightning/pytorch/optim/megatron.py index 7ac413d4544f..9f9d2029be9e 100644 --- a/nemo/lightning/pytorch/optim/megatron.py +++ b/nemo/lightning/pytorch/optim/megatron.py @@ -15,7 +15,7 @@ import inspect from typing import Callable, List, Optional -import pytorch_lightning as pl +import lightning.pytorch as pl from megatron.core.distributed import finalize_model_grads from megatron.core.optimizer import OptimizerConfig from megatron.core.utils import get_model_config diff --git a/nemo/lightning/pytorch/optim/pytorch.py b/nemo/lightning/pytorch/optim/pytorch.py index 9d773917e4f4..ccd03f563ef8 100644 --- a/nemo/lightning/pytorch/optim/pytorch.py +++ b/nemo/lightning/pytorch/optim/pytorch.py @@ -14,8 +14,8 @@ from typing import Callable, List, Optional -import pytorch_lightning as pl -import pytorch_lightning as L +import lightning.pytorch as pl +import lightning.pytorch as L from torch.optim import Optimizer from torch.optim.optimizer import ParamsT diff --git a/nemo/lightning/pytorch/plugins/data_sampler.py b/nemo/lightning/pytorch/plugins/data_sampler.py index 024e2577c868..479e442d5ccb 100644 --- a/nemo/lightning/pytorch/plugins/data_sampler.py +++ b/nemo/lightning/pytorch/plugins/data_sampler.py @@ -16,7 +16,7 @@ import logging from typing import List, Literal, Optional -import pytorch_lightning as pl +import lightning.pytorch as pl from torch.utils.data import DataLoader from nemo.lightning.megatron_parallel import MegatronStep diff --git a/nemo/lightning/pytorch/plugins/mixed_precision.py b/nemo/lightning/pytorch/plugins/mixed_precision.py index 5c318b59e54a..830978ba11e7 100644 --- a/nemo/lightning/pytorch/plugins/mixed_precision.py +++ b/nemo/lightning/pytorch/plugins/mixed_precision.py @@ -16,9 +16,8 @@ from dataclasses import dataclass, fields from typing import Any, Callable, Generator, List, Literal, Tuple, TypeVar, Union -import pytorch_lightning as pl import torch -from pytorch_lightning.plugins.precision import Precision +from lightning.pytorch.plugins.precision import Precision from torch.nn import Module from torch.optim import Optimizer diff --git a/nemo/lightning/pytorch/strategies/fsdp_strategy.py b/nemo/lightning/pytorch/strategies/fsdp_strategy.py index 83d5781c0dde..05d0f3d629f3 100644 --- a/nemo/lightning/pytorch/strategies/fsdp_strategy.py +++ b/nemo/lightning/pytorch/strategies/fsdp_strategy.py @@ -12,19 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os import shutil from collections import OrderedDict from pathlib import Path from typing import Any, Dict, Optional, Union -import pytorch_lightning as pl +import lightning.pytorch as pl import torch -from lightning_fabric.plugins import CheckpointIO -from lightning_fabric.strategies.fsdp import _get_sharded_state_dict_context +from lightning.fabric.plugins import CheckpointIO +from lightning.fabric.strategies.fsdp import _get_sharded_state_dict_context +from lightning.pytorch.strategies.fsdp import FSDPStrategy as PLFSDPStrategy +from lightning.pytorch.trainer.states import TrainerFn +from lightning.pytorch.utilities.types import STEP_OUTPUT from megatron.core.transformer.transformer_layer import TransformerLayer -from pytorch_lightning.strategies.fsdp import FSDPStrategy as PLFSDPStrategy -from pytorch_lightning.trainer.states import TrainerFn -from pytorch_lightning.utilities.types import STEP_OUTPUT from torch.distributed.checkpoint.state_dict import ( # get_state_dict, StateDictOptions, get_optimizer_state_dict, @@ -212,17 +213,17 @@ def save_checkpoint( checkpoint["sharded_state_dict"] = pyt_to_mcore_state_dict(checkpoint.pop("state_dict")) checkpoint["state_dict"] = OrderedDict([]) - ## replace unsharded optimizer_states with sharded dict. - ## note that if trainer.save_checkpoint(path, save_weights_only=True) is called, - ## the checkpoint will contain only model weights. Optimizer states will be omitted. - if ( - "optimizer_states" in checkpoint - and self.trainer.state.fn == TrainerFn.FITTING - and self.ckpt_save_optimizer - ): + if "optimizer_states" in checkpoint and self.trainer.state.fn == TrainerFn.FITTING: + # Clear the optimizer states. This handles the case where ckpt_save_optimizer=False + # Ideally, the optimizer state dicts should not be generated in this case checkpoint["optimizer_states"] = {} - checkpoint['optimizer'] = get_optimizer_state_dict(self.model, self.optimizers) - pyt_to_mcore_state_dict(checkpoint['optimizer']['state'], prefix="optimizer.state.") + + ## replace unsharded optimizer_states with sharded dict. + ## note that if trainer.save_checkpoint(path, save_weights_only=True) is called, + ## the checkpoint will contain only model weights. Optimizer states will be omitted. + if self.ckpt_save_optimizer: + checkpoint['optimizer'] = get_optimizer_state_dict(self.model, self.optimizers) + pyt_to_mcore_state_dict(checkpoint['optimizer']['state'], prefix="optimizer.state.") self.checkpoint_io.save_checkpoint(checkpoint, filepath, storage_options=storage_options) diff --git a/nemo/lightning/pytorch/strategies/megatron_strategy.py b/nemo/lightning/pytorch/strategies/megatron_strategy.py index c62a90313b45..b74677b01b09 100644 --- a/nemo/lightning/pytorch/strategies/megatron_strategy.py +++ b/nemo/lightning/pytorch/strategies/megatron_strategy.py @@ -35,20 +35,20 @@ cast, ) -import pytorch_lightning as pl +import lightning.pytorch as pl import torch import torch.distributed -from lightning_fabric.plugins import CheckpointIO, ClusterEnvironment -from lightning_fabric.utilities.optimizer import _optimizer_to_device, _optimizers_to_device +from lightning.fabric.plugins import CheckpointIO, ClusterEnvironment +from lightning.fabric.utilities.optimizer import _optimizer_to_device, _optimizers_to_device +from lightning.pytorch.accelerators import CPUAccelerator +from lightning.pytorch.loops import _AutomaticOptimization, evaluation_loop, fit_loop, prediction_loop +from lightning.pytorch.loops.fetchers import _DataLoaderIterDataFetcher +from lightning.pytorch.overrides.distributed import _sync_module_states +from lightning.pytorch.strategies.ddp import DDPStrategy +from lightning.pytorch.trainer.states import RunningStage, TrainerFn +from lightning.pytorch.utilities.types import STEP_OUTPUT from megatron.core.distributed import DistributedDataParallelConfig from megatron.core.optimizer import OptimizerConfig -from pytorch_lightning.accelerators import CPUAccelerator -from pytorch_lightning.loops import _AutomaticOptimization, evaluation_loop, fit_loop, prediction_loop -from pytorch_lightning.loops.fetchers import _DataLoaderIterDataFetcher -from pytorch_lightning.overrides.distributed import _sync_module_states -from pytorch_lightning.strategies.ddp import DDPStrategy -from pytorch_lightning.trainer.states import RunningStage, TrainerFn -from pytorch_lightning.utilities.types import STEP_OUTPUT from torch import nn from torch.distributed.algorithms.ddp_comm_hooks.debugging_hooks import noop_hook from torch.nn.parallel import DistributedDataParallel @@ -156,6 +156,9 @@ class MegatronStrategy(DDPStrategy, io.IOMixin): ckpt_load_directly_on_device (bool): if True, loads the weights directly on GPU. Has effect only for `zarr` based checkpoints (PyT Distributed always loads on device). Defaults to True. + ckpt_load_strictness (StrictHandling, optional): defines loading strictness. + If not None, overwrites the `strict` flag passed to `load_checkpoint`. + Defaults to None. setup_optimizers (bool): Whether to call the trainer's setup_optimizers function to perform any necessary conversions of optimizer parameters and move optimizer parameters to the correct device. Defaults to True. @@ -204,6 +207,7 @@ def __init__( ckpt_parallel_load: bool = True, ckpt_parallel_save_optim: bool = True, ckpt_load_directly_on_device: bool = True, + ckpt_load_strictness: Optional['StrictHandling'] = None, setup_optimizers: bool = True, init_model_parallel: bool = True, replace_progress_bar: bool = True, @@ -238,6 +242,7 @@ def __init__( self.lazy_init = lazy_init self.ckpt_load_optimizer = ckpt_load_optimizer self.ckpt_save_optimizer = ckpt_save_optimizer + self.ckpt_load_strictness = ckpt_load_strictness self.pipeline_dtype = pipeline_dtype self._setup_optimizers = setup_optimizers self._init_model_parallel = init_model_parallel @@ -278,7 +283,7 @@ def connect(self, model: pl.LightningModule) -> None: """Attaches a model to strategy.""" super().connect(model) - assert not 'is_hf_model' in model.__dict__, "Cannot use HfAutoModelForCausalLM with MegatronParallel" + assert not 'is_hf_model' in model.__dict__, "Cannot use HFAutoModelForCausalLM with MegatronParallel" dtype_config = getattr(self._precision_plugin, "dtype_config", None) if self.pipeline_dtype is None and dtype_config: @@ -693,16 +698,16 @@ def save_checkpoint( if "sharded_state_dict" not in checkpoint: checkpoint["sharded_state_dict"] = self.megatron_parallel.sharded_state_dict() - ## replace unsharded optimizer_states with sharded dict. - ## note that if trainer.save_checkpoint(path, save_weights_only=True) is called, - ## the checkpoint will contain only model weights. Optimizer states will be omitted. - if ( - "optimizer_states" in checkpoint - and self.trainer.state.fn == TrainerFn.FITTING - and self.ckpt_save_optimizer - ): + if "optimizer_states" in checkpoint and self.trainer.state.fn == TrainerFn.FITTING: + # Clear the optimizer states. This handles the case where ckpt_save_optimizer=False + # Ideally, the optimizer state dicts should not be generated in this case checkpoint["optimizer_states"] = {} - checkpoint["optimizer"] = [self.optimizer_sharded_state_dict()] + + ## replace unsharded optimizer_states with sharded dict. + ## note that if trainer.save_checkpoint(path, save_weights_only=True) is called, + ## the checkpoint will contain only model weights. Optimizer states will be omitted. + if self.ckpt_save_optimizer: + checkpoint["optimizer"] = [self.optimizer_sharded_state_dict()] self.checkpoint_io.save_checkpoint(checkpoint, filepath, storage_options=storage_options) @@ -733,7 +738,19 @@ def load_checkpoint(self, checkpoint_path: Union[str, Path], selective_restore: if self.lightning_module.optimizers(use_pl_optimizer=False): sharded_state_dict["optimizer"] = [self.optimizer_sharded_state_dict(is_loading=True)] - checkpoint = self.checkpoint_io.load_checkpoint(checkpoint_path, sharded_state_dict=sharded_state_dict) + strict = ( + self.lightning_module.strict_loading if self.ckpt_load_strictness is None else self.ckpt_load_strictness + ) + checkpoint = self.checkpoint_io.load_checkpoint( + checkpoint_path, sharded_state_dict=sharded_state_dict, strict=strict + ) + + if selective_restore: + final_checkpoint = {} + for key in sharded_state_dict.keys(): + final_checkpoint[key] = checkpoint[key] + + return final_checkpoint return checkpoint @@ -748,7 +765,8 @@ def selective_restore(self) -> None: if self.restore_config.load_model_state: logging.info(f"Restoring model weights from {self.restore_config}") - self.load_model_state_dict(checkpoint=checkpoint) + strict = True if self.ckpt_load_strictness is None else self.ckpt_load_strictness + self.load_model_state_dict(checkpoint=checkpoint, strict=strict) if self.restore_config.load_optim_state: logging.info(f"Restoring optimizer states from {self.restore_config}") @@ -783,6 +801,7 @@ def load_model_state_dict(self, checkpoint: Mapping[str, Any], strict: bool = Tr """loads model state dict""" assert self.megatron_parallel is not None + strict = strict if self.ckpt_load_strictness is None else self.ckpt_load_strictness _strategy_lib.load_model_state_dict(self.megatron_parallel, checkpoint, strict=strict) if not 'optimizer' in checkpoint: diff --git a/nemo/lightning/pytorch/strategies/utils.py b/nemo/lightning/pytorch/strategies/utils.py index 43a5a9243aa5..4f5a78419d6d 100644 --- a/nemo/lightning/pytorch/strategies/utils.py +++ b/nemo/lightning/pytorch/strategies/utils.py @@ -17,15 +17,14 @@ from pathlib import Path from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, cast -import pytorch_lightning as pl +import lightning.pytorch as pl import torch -from lightning_fabric.plugins import ClusterEnvironment +from lightning.fabric.plugins import ClusterEnvironment +from lightning.pytorch.callbacks import TQDMProgressBar from megatron.core import parallel_state from megatron.core.dist_checkpointing.mapping import ShardedBase, ShardedObject, ShardedTensor from megatron.core.dist_checkpointing.strategies.torch import sharded_tensor_to_torch_sharded_tensor from megatron.core.transformer.utils import _get_extra_state_offsets -from pytorch_lightning.callbacks import TQDMProgressBar -from pytorch_lightning.plugins.io.wrapper import _WrappingCheckpointIO from torch.distributed._sharded_tensor import ShardedTensor as TorchShardedTensor from torch.distributed._tensor import DTensor, Replicate, Shard from torch.distributed.device_mesh import DeviceMesh diff --git a/nemo/lightning/pytorch/trainer.py b/nemo/lightning/pytorch/trainer.py index c97c59ef524d..701c1cde4eaf 100644 --- a/nemo/lightning/pytorch/trainer.py +++ b/nemo/lightning/pytorch/trainer.py @@ -16,9 +16,9 @@ from copy import deepcopy import fiddle as fdl -import pytorch_lightning as pl -from pytorch_lightning.loops import _TrainingEpochLoop -from pytorch_lightning.loops.fetchers import _DataFetcher +import lightning.pytorch as pl +from lightning.pytorch.loops import _TrainingEpochLoop +from lightning.pytorch.loops.fetchers import _DataFetcher from typing_extensions import Self from nemo.lightning.fabric.conversion import to_fabric diff --git a/nemo/lightning/resume.py b/nemo/lightning/resume.py index 412ca8665b84..6d6ddda1fd80 100644 --- a/nemo/lightning/resume.py +++ b/nemo/lightning/resume.py @@ -18,8 +18,8 @@ from pathlib import Path, PosixPath, WindowsPath from typing import Optional, Union -import lightning_fabric as fl -import pytorch_lightning as pl +import lightning.fabric as fl +import lightning.pytorch as pl from nemo.lightning import io from nemo.lightning.base import NEMO_MODELS_CACHE @@ -37,17 +37,25 @@ def _try_restore_tokenizer(model, ckpt_path): + from nemo.collections.common.tokenizers import TokenizerSpec from nemo.lightning.io import load_context try: tokenizer = load_context(ckpt_path, "model.tokenizer") + except ValueError as e: + logging.warning( + f"Encountered error while trying to restore tokenizer. Tokenizer is not restored. " f"Original error: {e}" + ) + return model + + if isinstance(tokenizer, TokenizerSpec): model.tokenizer = tokenizer model.__io__.tokenizer = tokenizer.__io__ - except: - # Ignore if the ckpt doesn't have a tokenizer. - pass - finally: - return model + else: + # Ignore if the ckpt doesn't have a tokenizer. type(tokenizer)==TrainerContext in this case. + logging.warning("Checkpoint does not have model.tokenizer field. Tokenizer is not restored.") + + return model @dataclass(kw_only=True) @@ -56,8 +64,10 @@ class AutoResume: checkpoints in NeMo. Attributes: - restore_config (Optional[RestoreConfig]): Optional config for selectively restoring specific parts like model weights, optimizer states, etc. - If the config contains a path from HF or another non-NeMo checkpoint format, the checkpoint will be automatically converted to a NeMo compatible format. + restore_config (Optional[RestoreConfig]): Optional config for selectively restoring specific parts like model + weights, optimizer states, etc. + If the config contains a path from HF or another non-NeMo checkpoint format, the checkpoint will be + automatically converted to a NeMo compatible format. resume_from_folder or the run's log_dir takes precedence over restore_config. resume_from_directory (str): Path to the checkpointing directory to restore from. resume_from_path (str): Path to a specific checkpoint to restore from. @@ -209,17 +219,22 @@ def _find_trainer_ckpt_path(self) -> Optional[Path]: if not checkpoint_dir.exists() or (not len(end_checkpoints) > 0 and not len(last_checkpoints) > 0): if self.resume_ignore_no_checkpoint: - warn = f"There were no checkpoints found in checkpoint_dir or no checkpoint folder at checkpoint_dir :{checkpoint_dir}. " + warn = ( + f"There were no checkpoints found in checkpoint_dir or no checkpoint folder at checkpoint_dir " + f":{checkpoint_dir}. " + ) if checkpoint is None: warn += "Training from scratch." logging.warning(warn) else: if self.restore_config: - # resume_if_exists is True but run is not resumable. Do not fail and try to do selective restore later instead. + # resume_if_exists is True but run is not resumable. Do not fail and try to do selective restore + # later instead. return None else: raise NotFoundError( - f"There were no checkpoints found in checkpoint_dir or no checkpoint folder at checkpoint_dir :{checkpoint_dir}. Cannot resume." + f"There were no checkpoints found in checkpoint_dir or no checkpoint folder at checkpoint_dir " + f":{checkpoint_dir}. Cannot resume." ) elif len(end_checkpoints) > 0: if not self.resume_past_end: @@ -240,7 +255,8 @@ def _find_trainer_ckpt_path(self) -> Optional[Path]: # Select the checkpoint with the latest modified time checkpoint = sorted(last_checkpoints, key=lambda pth: pth.lstat().st_mtime, reverse=True)[0] logging.warning( - f"Multiple checkpoints {last_checkpoints} matches *last.ckpt. Selecting one with the latest modified time." + f"Multiple checkpoints {last_checkpoints} matches *last.ckpt. Selecting one with the latest " + f"modified time." ) else: checkpoint = last_checkpoints[0] diff --git a/nemo/lightning/run/plugins.py b/nemo/lightning/run/plugins.py index 9d2936e567ec..645665723706 100644 --- a/nemo/lightning/run/plugins.py +++ b/nemo/lightning/run/plugins.py @@ -20,17 +20,22 @@ import nemo_run as run import yaml +from lightning.pytorch import Callback +from lightning.pytorch.loggers import WandbLogger from nemo_run.core.serialization.yaml import YamlSerializer -from pytorch_lightning import Callback -from pytorch_lightning.loggers import WandbLogger from nemo.lightning.pytorch.callbacks import NsysCallback, PreemptionCallback from nemo.lightning.pytorch.strategies.megatron_strategy import MegatronStrategy from nemo.utils import logging +from nemo.utils.import_utils import safe_import + +res_module, HAVE_RES = safe_import('nvidia_resiliency_ext.ptl_resiliency') + # This file contains plugins based on NeMo-Run's run.Plugin API. # Plugins operate both on a configured task and an executor at the same time, and are specific to NeMo-Run. -# If you are adding functionality that goes directly into the Pytorch Lightning trainer, you may consider adding a callback instead of a plugin. +# If you are adding functionality that goes directly into the Pytorch Lightning trainer, +# you may consider adding a callback instead of a plugin. def _merge_callbacks(partial: run.Partial, callbacks: list[run.Config[Callback]]): @@ -79,6 +84,55 @@ def setup(self, task: run.Partial | run.Script, executor: run.Executor): _merge_callbacks(task, callbacks=self.callbacks) +@dataclass(kw_only=True) +class FaultTolerancePlugin(run.Plugin): + """ + A plugin for setting up the fault tolerance callback from nvidia-resiliency-ext. + This plugin enables workload hang detection, automatic calculation of timeouts used for hang detection, detection of rank(s) terminated due to an error and workload respawning in case of a failure. + Note: FaultTolerancePlugin does not work with the NsysPlugin. + Args: + num_in_process_restarts (int): Max number of restarts on failure, within the same job. Default is 3. + num_job_retries_on_failure (int): Max number of new job restarts on failure. Default is 2. + initial_rank_heartbeat_timeout (int): Timeouts are time intervals used by a rank monitor to detect that a rank is not alive. This is the max timeout for the initial heartbeat. Default is 1800. + rank_heartbeat_timeout (int): This is the timeout for subsequent hearbeats after the initial heartbeat. Default is 300. + """ + + num_in_process_restarts: int = 3 + num_job_retries_on_failure: int = 2 + initial_rank_heartbeat_timeout: int = 1800 + rank_heartbeat_timeout: int = 300 + + def setup(self, task: run.Partial | run.Script, executor: run.Executor): + + assert HAVE_RES, "nvidia-resiliency-ext.ptl_resiliency is required to use the FaultTolerancePlugin." + + executor.launcher = run.FaultTolerance( + max_restarts=self.num_in_process_restarts, + initial_rank_heartbeat_timeout=self.initial_rank_heartbeat_timeout, + rank_heartbeat_timeout=self.rank_heartbeat_timeout, + ) + executor.retries = self.num_job_retries_on_failure + + assert isinstance(task, run.Partial) + + callbacks = [ + run.Config( + res_module.FaultToleranceCallback, autoresume=True, calculate_timeouts=True, exp_dir=task.log.log_dir + ) + ] + + assert not executor.launcher.nsys_profile, "Nsys not supported with the FaultTolerancePlugin." + if hasattr(task, "trainer") and hasattr(task.trainer, "callbacks"): + assert all( + map( + lambda cb: not cb.__fn_or_cls__ == NsysCallback if "__fn_or_cls__" in dir(cb) else True, + task.trainer.callbacks, + ) + ), "Nsys not supported with FaultTolerancePlugin." + + _merge_callbacks(task, callbacks=callbacks) + + @dataclass(kw_only=True) class NsysPlugin(run.Plugin): """ @@ -260,8 +314,11 @@ class PerfEnvPlugin(run.Plugin): enable_layernorm_sm_margin: bool = True layernorm_sm_margin: int = 16 enable_vboost: bool = False + nccl_pp_comm_chunksize: int = None def get_vboost_srun_cmd(self, nodes, job_dir): + "Create the vboost `sudo nvidia-smi boost-slider --vboost 1` command" + import shlex vboost_cmd = " ".join( @@ -281,12 +338,13 @@ def get_vboost_srun_cmd(self, nodes, job_dir): return vboost_cmd def setup(self, task: run.Partial | run.Script, executor: run.Executor): + """Enable the performance environment settings""" if task.trainer.strategy.__fn_or_cls__ == MegatronStrategy: # Force program order kernel launch for TP, CP overlap tp_size = task.trainer.strategy.tensor_model_parallel_size cp_size = task.trainer.strategy.context_parallel_size - if tp_size > 1 and cp_size > 1: + if tp_size > 1 or cp_size > 1: executor.env_vars["CUDA_DEVICE_MAX_CONNECTIONS"] = "1" # Set LayerNorm SM margin to support the overlap with LayerNorm kernel @@ -294,6 +352,13 @@ def setup(self, task: run.Partial | run.Script, executor: run.Executor): executor.env_vars["NVTE_FWD_LAYERNORM_SM_MARGIN"] = str(self.layernorm_sm_margin) executor.env_vars["NVTE_BWD_LAYERNORM_SM_MARGIN"] = str(self.layernorm_sm_margin) + # Set the chunk size of P2P communications. Using a large chunk size reduces the + # buffering overhead from the communication kernel execution time + pp_size = task.trainer.strategy.pipeline_model_parallel_size + if pp_size > 1 and self.nccl_pp_comm_chunksize is not None: + assert isinstance(self.nccl_pp_comm_chunksize, int) and self.nccl_pp_comm_chunksize > 1 + executor.env_vars["NCCL_P2P_NET_CHUNKSIZE"] = str(self.nccl_pp_comm_chunksize) + # Improve perf by steering power to tensor cores, may not work on all systems if self.enable_vboost and isinstance(executor, run.SlurmExecutor): vboost_cmd = self.get_vboost_srun_cmd(executor.nodes, executor.job_dir) diff --git a/nemo/utils/callbacks/cuda_graph.py b/nemo/utils/callbacks/cuda_graph.py index c78196934108..b44006828963 100644 --- a/nemo/utils/callbacks/cuda_graph.py +++ b/nemo/utils/callbacks/cuda_graph.py @@ -37,15 +37,15 @@ from types import MethodType from typing import Any, Dict -import pytorch_lightning as pl +import lightning.pytorch as pl import torch -from pytorch_lightning import LightningModule -from pytorch_lightning.callbacks import Callback -from pytorch_lightning.loops.optimization.automatic import ClosureResult -from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection, _ResultMetric -from pytorch_lightning.utilities import CombinedLoader, rank_zero_info -from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature -from pytorch_lightning.utilities.types import STEP_OUTPUT +from lightning.pytorch import LightningModule +from lightning.pytorch.callbacks import Callback +from lightning.pytorch.loops.optimization.automatic import ClosureResult +from lightning.pytorch.trainer.connectors.logger_connector.result import _ResultCollection, _ResultMetric +from lightning.pytorch.utilities import CombinedLoader, rank_zero_info +from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature +from lightning.pytorch.utilities.types import STEP_OUTPUT from torch.nn.parallel import DistributedDataParallel __all__ = ["CUDAGraphCallback"] @@ -431,8 +431,8 @@ def on_save_checkpoint( Called when saving a checkpoint to give you a chance to store anything else you might want to save. Args: - trainer: the current :class:`~pytorch_lightning.trainer.Trainer` instance. - pl_module: the current :class:`~pytorch_lightning.core.module.LightningModule` instance. + trainer: the current :class:`~lightning.pytorch.trainer.Trainer` instance. + pl_module: the current :class:`~lightning.pytorch.core.module.LightningModule` instance. checkpoint: the checkpoint dictionary that will be saved. """ # Since we've add bound method to optimizer and lr_scheduler, it can lead to more diff --git a/nemo/utils/callbacks/dist_ckpt_io.py b/nemo/utils/callbacks/dist_ckpt_io.py index 091075488878..b78ec9b4ac0f 100644 --- a/nemo/utils/callbacks/dist_ckpt_io.py +++ b/nemo/utils/callbacks/dist_ckpt_io.py @@ -19,12 +19,12 @@ from time import time from typing import Any, Dict, Optional, Union -import pytorch_lightning as pl -from lightning_fabric.plugins import CheckpointIO -from lightning_fabric.utilities.cloud_io import get_filesystem -from lightning_fabric.utilities.types import _PATH -from pytorch_lightning import Callback -from pytorch_lightning.plugins.io.wrapper import _WrappingCheckpointIO +import lightning.pytorch as pl +from lightning.fabric.plugins import CheckpointIO +from lightning.fabric.utilities.cloud_io import get_filesystem +from lightning.fabric.utilities.types import _PATH +from lightning.pytorch import Callback +from lightning.pytorch.plugins.io.wrapper import _WrappingCheckpointIO from nemo.utils import logging diff --git a/nemo/utils/callbacks/nemo_model_checkpoint.py b/nemo/utils/callbacks/nemo_model_checkpoint.py index dc1da9ce1875..8fe3beaaa985 100644 --- a/nemo/utils/callbacks/nemo_model_checkpoint.py +++ b/nemo/utils/callbacks/nemo_model_checkpoint.py @@ -19,14 +19,13 @@ from pathlib import Path from typing import Any, Dict, Iterable, List, Optional, Tuple, Union -import pytorch_lightning +import lightning.pytorch import torch from _weakref import proxy - -from lightning_fabric.utilities.cloud_io import get_filesystem -from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint, _is_local_file_protocol -from pytorch_lightning.trainer import call -from pytorch_lightning.utilities import rank_zero_info +from lightning.fabric.utilities.cloud_io import get_filesystem +from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint, _is_local_file_protocol +from lightning.pytorch.trainer import call +from lightning.pytorch.utilities import rank_zero_info from nemo.collections.common.callbacks import EMA from nemo.utils import logging @@ -357,7 +356,7 @@ def _del_model_without_trainer(self, filepath: str) -> None: except: logging.info(f"Tried to remove checkpoint: {filepath} but failed.") - def _ema_callback(self, trainer: 'pytorch_lightning.Trainer') -> Optional[EMA]: + def _ema_callback(self, trainer: 'lightning.pytorch.Trainer') -> Optional[EMA]: ema_callback = None for callback in trainer.callbacks: if isinstance(callback, EMA): @@ -506,12 +505,12 @@ def remove_checkpoint_unfinished_marker(checkpoint_path: Union[Path, str], barri except: return - def file_exists(self, filepath: str, trainer: "pytorch_lightning.Trainer", check_dist_ckpt: bool = True) -> bool: + def file_exists(self, filepath: str, trainer: "lightning.pytorch.Trainer", check_dist_ckpt: bool = True) -> bool: """Checks if a file or a file without a suffix (distributed checkpoint) exists.""" exists = self._fs.exists(filepath) or (check_dist_ckpt and self._fs.exists(ckpt_to_dir(filepath))) return trainer.strategy.broadcast(exists) - def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str) -> None: + def _save_checkpoint(self, trainer: 'lightning.pytorch.Trainer', filepath: str) -> None: # barrier_after=True, so all ranks continue after the unfinished checkpoint marker is placed. # if anything goes wrong during checkpointing, we should be able to detect that data is incomplete. self.set_checkpoint_unfinished_marker(filepath, barrier_after=True) @@ -552,7 +551,7 @@ def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str) self._drop_optimizer_states(trainer, filepath, storage_options) def _get_finalize_save_checkpoint_callback( - self, trainer: 'pytorch_lightning.Trainer', filepath: str, global_step: int + self, trainer: 'lightning.pytorch.Trainer', filepath: str, global_step: int ): """Creates a callback that can be used to finalize async (and sync) ckpt saves.""" @@ -585,7 +584,7 @@ def _cb(): return _cb - def _remove_checkpoint(self, trainer: "pytorch_lightning.Trainer", filepath: str, override_async=False) -> None: + def _remove_checkpoint(self, trainer: "lightning.pytorch.Trainer", filepath: str, override_async=False) -> None: """Performs checkpoint removal or deferred removal. With async save, `self._remove_checkpoint` is called before the checkpoint diff --git a/nemo/utils/callbacks/preemption.py b/nemo/utils/callbacks/preemption.py index e9b5f95022f3..178fe94cee7c 100644 --- a/nemo/utils/callbacks/preemption.py +++ b/nemo/utils/callbacks/preemption.py @@ -16,7 +16,7 @@ import sys import torch -from pytorch_lightning.callbacks import Callback +from lightning.pytorch.callbacks import Callback from nemo.utils import logging @@ -24,7 +24,7 @@ class PreemptionCallback(Callback): """ PreemptionCallback class creates a callback that checks for preemption during training at the end of every step. - Upon preemption the callback provides a function to gracefully exit the training immediately and also saves the current state in a checkpoint as *last.ckpt. + Upon preemption the callback provides a function to gracefully exit the training immediately and also saves the current state in a checkpoint as *last.ckpt. (to be able to start from the same step without wasting any compute while resuming the next time). PreemptionCallback is always enabled by default via the arg create_preemption_callback under ExpManagerConfig. To disable please pass @@ -47,7 +47,7 @@ def interrupted(self): def on_train_start(self, trainer, pl_module): """ - Defines custom handlers at the beginning of training to be executed when the + Defines custom handlers at the beginning of training to be executed when the preemption signal is received. """ diff --git a/nemo/utils/callbacks/s3_checkpoint_io.py b/nemo/utils/callbacks/s3_checkpoint_io.py index 7a9f984fee1b..4a48198311a2 100644 --- a/nemo/utils/callbacks/s3_checkpoint_io.py +++ b/nemo/utils/callbacks/s3_checkpoint_io.py @@ -22,7 +22,7 @@ from typing import Any, Callable, Dict, Optional, Union import torch -from lightning_fabric.plugins.io.checkpoint_io import CheckpointIO +from lightning.fabric.plugins.io.checkpoint_io import CheckpointIO from nemo.utils import logging from nemo.utils.s3_utils import ( diff --git a/nemo/utils/cloud.py b/nemo/utils/cloud.py index 7245567d636c..d565028bdf8c 100644 --- a/nemo/utils/cloud.py +++ b/nemo/utils/cloud.py @@ -17,8 +17,8 @@ from time import sleep import wget -from pytorch_lightning.plugins.environments import LightningEnvironment -from pytorch_lightning.strategies import DDPStrategy, StrategyRegistry +from lightning.pytorch.plugins.environments import LightningEnvironment +from lightning.pytorch.strategies import DDPStrategy, StrategyRegistry from nemo.utils import logging @@ -105,7 +105,10 @@ def initialize_sagemaker() -> None: """ StrategyRegistry.register( - name='smddp', strategy=SageMakerDDPStrategy, process_group_backend="smddp", find_unused_parameters=False, + name='smddp', + strategy=SageMakerDDPStrategy, + process_group_backend="smddp", + find_unused_parameters=False, ) def _install_system_libraries() -> None: diff --git a/nemo/utils/exp_manager.py b/nemo/utils/exp_manager.py index b512bc57cbab..04c43c46d247 100644 --- a/nemo/utils/exp_manager.py +++ b/nemo/utils/exp_manager.py @@ -26,18 +26,18 @@ from shutil import copy, move from typing import Any, Collection, Dict, List, Optional, Tuple, Union -import pytorch_lightning +import lightning.pytorch import torch from hydra.core.hydra_config import HydraConfig from hydra.utils import get_original_cwd +from lightning.pytorch.callbacks import Callback, ModelCheckpoint +from lightning.pytorch.callbacks.early_stopping import EarlyStopping +from lightning.pytorch.callbacks.timer import Interval, Timer +from lightning.pytorch.loggers import MLFlowLogger, NeptuneLogger, TensorBoardLogger, WandbLogger +from lightning.pytorch.loops import _TrainingEpochLoop +from lightning.pytorch.strategies.ddp import DDPStrategy +from lightning.pytorch.trainer.connectors.checkpoint_connector import _CheckpointConnector from omegaconf import DictConfig, OmegaConf, open_dict -from pytorch_lightning.callbacks import Callback, ModelCheckpoint -from pytorch_lightning.callbacks.early_stopping import EarlyStopping -from pytorch_lightning.callbacks.timer import Interval, Timer -from pytorch_lightning.loggers import MLFlowLogger, NeptuneLogger, TensorBoardLogger, WandbLogger -from pytorch_lightning.loops import _TrainingEpochLoop -from pytorch_lightning.strategies.ddp import DDPStrategy -from pytorch_lightning.trainer.connectors.checkpoint_connector import _CheckpointConnector from nemo.collections.common.callbacks import EMA from nemo.constants import NEMO_ENV_VARNAME_TESTING, NEMO_ENV_VARNAME_VERSION @@ -343,7 +343,7 @@ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx) self._on_batch_end("validation_step_timing in s", trainer, pl_module) -def exp_manager(trainer: 'pytorch_lightning.Trainer', cfg: Optional[Union[DictConfig, Dict]] = None) -> Optional[Path]: +def exp_manager(trainer: 'lightning.pytorch.Trainer', cfg: Optional[Union[DictConfig, Dict]] = None) -> Optional[Path]: """ exp_manager is a helper function used to manage folders for experiments. It follows the pytorch lightning paradigm of exp_dir/model_or_experiment_name/version. If the lightning trainer has a logger, exp_manager will get exp_dir, @@ -362,7 +362,7 @@ def exp_manager(trainer: 'pytorch_lightning.Trainer', cfg: Optional[Union[DictCo resume_if_exists is set to True, creating the version folders is ignored. Args: - trainer (pytorch_lightning.Trainer): The lightning trainer. + trainer (lightning.pytorch.Trainer): The lightning trainer. cfg (DictConfig, dict): Can have the following keys: - explicit_log_dir (str, Path): Can be used to override exp_dir/name/version folder creation. Defaults to @@ -680,7 +680,7 @@ def exp_manager(trainer: 'pytorch_lightning.Trainer', cfg: Optional[Union[DictCo return log_dir -def error_checks(trainer: 'pytorch_lightning.Trainer', cfg: Optional[Union[DictConfig, Dict]] = None): +def error_checks(trainer: 'lightning.pytorch.Trainer', cfg: Optional[Union[DictConfig, Dict]] = None): """ Checks that the passed trainer is compliant with NeMo and exp_manager's passed configuration. Checks that: - Throws error when hydra has changed the working directory. This causes issues with lightning's DDP @@ -728,7 +728,7 @@ def _filter_out_unfinished_checkpoints(checkpoint_paths: Collection[Union[Path, def check_resume( - trainer: 'pytorch_lightning.Trainer', + trainer: 'lightning.pytorch.Trainer', log_dir: str, resume_if_exists: bool = False, resume_past_end: bool = False, @@ -886,7 +886,7 @@ def check_resume( def check_explicit_log_dir( - trainer: 'pytorch_lightning.Trainer', explicit_log_dir: Union[Path, str], exp_dir: str, name: str, version: str + trainer: 'lightning.pytorch.Trainer', explicit_log_dir: Union[Path, str], exp_dir: str, name: str, version: str ) -> Tuple[Path, str, str, str]: """Checks that the passed arguments are compatible with explicit_log_dir. @@ -917,7 +917,7 @@ def check_explicit_log_dir( def get_log_dir( - trainer: 'pytorch_lightning.Trainer', + trainer: 'lightning.pytorch.Trainer', exp_dir: str = None, name: str = None, version: str = None, @@ -1025,7 +1025,7 @@ def get_git_diff(): def configure_loggers( - trainer: 'pytorch_lightning.Trainer', + trainer: 'lightning.pytorch.Trainer', exp_dir: [Path, str], log_dir: [Path, str], name: str, @@ -1136,7 +1136,7 @@ def resume_start(self, checkpoint_path=None) -> None: def configure_checkpointing( - trainer: 'pytorch_lightning.Trainer', + trainer: 'lightning.pytorch.Trainer', log_dir: Path, name: str, resume: bool, @@ -1257,12 +1257,12 @@ def _check_time_remaining(self, trainer: "pl.Trainer") -> None: monitor_candidates = checkpoint_callback._monitor_candidates(trainer) checkpoint_callback._save_last_checkpoint(trainer, monitor_candidates) # Throw this exception to signal to Lightning to terminate gracefully. - from pytorch_lightning.utilities.exceptions import _TunerExitException + from lightning.pytorch.utilities.exceptions import _TunerExitException raise _TunerExitException() -def configure_no_restart_validation_training_loop(trainer: pytorch_lightning.Trainer) -> None: +def configure_no_restart_validation_training_loop(trainer: lightning.pytorch.Trainer) -> None: if type(trainer.fit_loop.epoch_loop) != _TrainingEpochLoop: warnings.warn("Detected custom epoch loop. Skipping no validation on restart support.", UserWarning) return diff --git a/nemo/utils/lightning_logger_patch.py b/nemo/utils/lightning_logger_patch.py index 1b21ce3b1ae5..1528146c64b5 100644 --- a/nemo/utils/lightning_logger_patch.py +++ b/nemo/utils/lightning_logger_patch.py @@ -15,7 +15,7 @@ import logging as _logging from logging.handlers import MemoryHandler -import pytorch_lightning as pl +import lightning.pytorch as pl HANDLERS = {} PATCHED = False diff --git a/nemo/utils/loggers/clearml_logger.py b/nemo/utils/loggers/clearml_logger.py index 4e2063705b4f..c7c3945ad853 100644 --- a/nemo/utils/loggers/clearml_logger.py +++ b/nemo/utils/loggers/clearml_logger.py @@ -19,11 +19,11 @@ from typing import Any, List, Literal, Mapping, Optional, Union import pandas as pd +from lightning.pytorch.callbacks import Checkpoint +from lightning.pytorch.loggers import Logger +from lightning.pytorch.utilities.parsing import AttributeDict from lightning_utilities.core.apply_func import apply_to_collection from omegaconf import DictConfig, ListConfig, OmegaConf -from pytorch_lightning.callbacks import Checkpoint -from pytorch_lightning.loggers import Logger -from pytorch_lightning.utilities.parsing import AttributeDict from torch import Tensor from nemo.utils import logging diff --git a/nemo/utils/loggers/dllogger.py b/nemo/utils/loggers/dllogger.py index cdeef63b75f7..871d7ee3f7a2 100644 --- a/nemo/utils/loggers/dllogger.py +++ b/nemo/utils/loggers/dllogger.py @@ -17,11 +17,11 @@ from pathlib import Path from typing import Optional +from lightning.pytorch.loggers import Logger +from lightning.pytorch.utilities import rank_zero_only +from lightning.pytorch.utilities.parsing import AttributeDict from lightning_utilities.core.apply_func import apply_to_collection from omegaconf import DictConfig, ListConfig, OmegaConf -from pytorch_lightning.loggers import Logger -from pytorch_lightning.utilities import rank_zero_only -from pytorch_lightning.utilities.parsing import AttributeDict from nemo.utils import logging @@ -34,7 +34,7 @@ HAVE_DLLOGGER = False try: - from lightning_fabric.utilities.logger import _convert_params, _flatten_dict, _sanitize_callable_params + from lightning.fabric.utilities.logger import _convert_params, _flatten_dict, _sanitize_callable_params PL_LOGGER_UTILITIES = True except (ImportError, ModuleNotFoundError): diff --git a/nemo/utils/sequence_packing_utils.py b/nemo/utils/sequence_packing_utils.py index cee2be248f73..2ca03ce44b67 100644 --- a/nemo/utils/sequence_packing_utils.py +++ b/nemo/utils/sequence_packing_utils.py @@ -115,7 +115,7 @@ def create_hist(dataset: np.array, truncate_seq_len: int): logging.info("Creating histogram from tokenized dataset...") sequences = collections.defaultdict(list) - counts = [0] * truncate_seq_len + counts = [0] * (truncate_seq_len + 1) for item_dict in dataset: # Minus 1 here to account for the fact that transformer input and label have one less token than the full sequence @@ -129,7 +129,7 @@ def create_hist(dataset: np.array, truncate_seq_len: int): logging.debug(counts) histogram = [] - for seq_len in range(truncate_seq_len): + for seq_len in range(truncate_seq_len + 1): histogram.append(len(sequences[seq_len])) return sequences, histogram diff --git a/requirements/requirements_asr.txt b/requirements/requirements_asr.txt index d28b3f7980a7..783f7a483dc5 100644 --- a/requirements/requirements_asr.txt +++ b/requirements/requirements_asr.txt @@ -8,6 +8,7 @@ kaldiio lhotse>=1.26.0 librosa>=0.10.2 marshmallow +optuna packaging pyannote.core pyannote.metrics diff --git a/requirements/requirements_lightning.txt b/requirements/requirements_lightning.txt index e8020f244821..adca2283f577 100644 --- a/requirements/requirements_lightning.txt +++ b/requirements/requirements_lightning.txt @@ -1,8 +1,8 @@ cloudpickle fiddle hydra-core>1.3,<=1.3.2 +lightning>2.2.1 omegaconf<=2.3 -pytorch-lightning>2.2.1 torchmetrics>=0.11.0 transformers>=4.45.0 wandb diff --git a/requirements/requirements_multimodal.txt b/requirements/requirements_multimodal.txt index 18abe82c9f96..aa33b3b55127 100644 --- a/requirements/requirements_multimodal.txt +++ b/requirements/requirements_multimodal.txt @@ -5,7 +5,7 @@ diffusers>=0.19.3 einops_exts imageio kornia -megatron-energon +megatron-energon<3.0.0 nerfacc>=0.5.3 open_clip_torch==2.24.0 PyMCubes diff --git a/requirements/requirements_tts.txt b/requirements/requirements_tts.txt index 0d499feb3b1f..6d20e0f2250f 100644 --- a/requirements/requirements_tts.txt +++ b/requirements/requirements_tts.txt @@ -11,3 +11,5 @@ nltk pandas pypinyin pypinyin-dict +seaborn + diff --git a/requirements/requirements_vllm.txt b/requirements/requirements_vllm.txt index 414e05078680..6f5c8880f632 100644 --- a/requirements/requirements_vllm.txt +++ b/requirements/requirements_vllm.txt @@ -1 +1,19 @@ -vllm==0.5.3.post1 +# Minimal set of NeMo requirements to run vLLM export & deployment in /opt/venv in a NeMo container +braceexpand +faiss-cpu +h5py +hydra-core>1.3,<=1.3.2 +ijson +jieba +lightning>2.2.1 +matplotlib>=3.3.2 +omegaconf<=2.3 +onnx>=1.7.0 +OpenCC +pangu +rouge_score +sacrebleu +scikit-learn +vllm==0.6.3 +webdataset>=0.2.86 +wget diff --git a/scripts/checkpoint_averaging/average_model_checkpoints.py b/scripts/checkpoint_averaging/average_model_checkpoints.py index 06c522f1e192..ce88bba9716b 100644 --- a/scripts/checkpoint_averaging/average_model_checkpoints.py +++ b/scripts/checkpoint_averaging/average_model_checkpoints.py @@ -60,7 +60,7 @@ import os -import pytorch_lightning as pl +import lightning.pytorch as pl import torch from omegaconf import OmegaConf, open_dict diff --git a/scripts/checkpoint_averaging/megatron_checkpoint_averaging.py b/scripts/checkpoint_averaging/megatron_checkpoint_averaging.py index 59f02a117da4..7b964fd7bade 100755 --- a/scripts/checkpoint_averaging/megatron_checkpoint_averaging.py +++ b/scripts/checkpoint_averaging/megatron_checkpoint_averaging.py @@ -35,8 +35,8 @@ import sys import torch +from lightning.pytorch.trainer.trainer import Trainer from omegaconf.omegaconf import OmegaConf, open_dict -from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy, NLPSaveRestoreConnector from nemo.core import ModelPT @@ -60,7 +60,10 @@ def main(): help='A list of Python file names to "from FILE import *" (Needed when some classes were defined in __main__ of a script)', ) parser.add_argument( - '--class_path', type=str, default='', help='A path to class "module.submodule.class" (if given)', + '--class_path', + type=str, + default='', + help='A path to class "module.submodule.class" (if given)', ) args = parser.parse_args() diff --git a/scripts/checkpoint_converters/convert_baichuan2_hf_to_nemo.py b/scripts/checkpoint_converters/convert_baichuan2_hf_to_nemo.py index b87f7e028cdb..b35fb201865e 100644 --- a/scripts/checkpoint_converters/convert_baichuan2_hf_to_nemo.py +++ b/scripts/checkpoint_converters/convert_baichuan2_hf_to_nemo.py @@ -25,9 +25,9 @@ from collections import OrderedDict import torch +from lightning.pytorch.core.saving import _load_state as ptl_load_state +from lightning.pytorch.trainer.trainer import Trainer from omegaconf import OmegaConf -from pytorch_lightning.core.saving import _load_state as ptl_load_state -from pytorch_lightning.trainer.trainer import Trainer from transformers import AutoModelForCausalLM, AutoTokenizer from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel @@ -158,7 +158,7 @@ def convert(args): scaler = None if precision in [16, '16', '16-mixed']: scaler = GradScaler( - init_scale=nemo_config.get('native_amp_init_scale', 2 ** 32), + init_scale=nemo_config.get('native_amp_init_scale', 2**32), growth_interval=nemo_config.get('native_amp_growth_interval', 1000), hysteresis=nemo_config.get('hysteresis', 2), ) diff --git a/scripts/checkpoint_converters/convert_baichuan2_nemo_to_hf.py b/scripts/checkpoint_converters/convert_baichuan2_nemo_to_hf.py index ec048e4b6f19..335989309791 100644 --- a/scripts/checkpoint_converters/convert_baichuan2_nemo_to_hf.py +++ b/scripts/checkpoint_converters/convert_baichuan2_nemo_to_hf.py @@ -17,7 +17,7 @@ from collections import OrderedDict import torch -from pytorch_lightning import Trainer +from lightning.pytorch import Trainer from transformers import AutoModelForCausalLM from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel @@ -128,7 +128,7 @@ def convert(input_nemo_file, output_hf_file, precision=None, cpu_only=False) -> ffn_hidden_size = model.cfg.ffn_hidden_size num_query_groups = model.cfg.get("num_query_groups", head_num) # different num_query_groups for 70B - head_size = hidden_size // head_num + head_size = model.cfg.get("kv_channels") or (hidden_size // head_num) # equivalent to hf's head_dim heads_per_group = head_num // num_query_groups qkv_total_dim = head_num + 2 * num_query_groups diff --git a/scripts/checkpoint_converters/convert_bert_nemo_to_hf.py b/scripts/checkpoint_converters/convert_bert_nemo_to_hf.py index e970ea29fca2..0ec5cc1e474b 100644 --- a/scripts/checkpoint_converters/convert_bert_nemo_to_hf.py +++ b/scripts/checkpoint_converters/convert_bert_nemo_to_hf.py @@ -26,7 +26,7 @@ import torch import torch.nn.functional as F -from pytorch_lightning import Trainer +from lightning.pytorch import Trainer from transformers import AutoTokenizer, BertConfig, BertModel from nemo.collections.nlp.models.language_modeling.megatron_bert_model import MegatronBertModel @@ -207,10 +207,16 @@ def convert_config(ref_config, hf_state_dict): def get_args(): parser = ArgumentParser() parser.add_argument( - "--input_name_or_path", type=str, required=True, help="Path to .nemo file", + "--input_name_or_path", + type=str, + required=True, + help="Path to .nemo file", ) parser.add_argument( - "--output_path", type=str, required=True, help="Output HF model path", + "--output_path", + type=str, + required=True, + help="Output HF model path", ) args = parser.parse_args() diff --git a/scripts/checkpoint_converters/convert_chatglm_hf_to_nemo.py b/scripts/checkpoint_converters/convert_chatglm_hf_to_nemo.py index 363e4de09ef7..2545181ce968 100644 --- a/scripts/checkpoint_converters/convert_chatglm_hf_to_nemo.py +++ b/scripts/checkpoint_converters/convert_chatglm_hf_to_nemo.py @@ -25,8 +25,8 @@ from collections import OrderedDict import torch +from lightning.pytorch.trainer.trainer import Trainer from omegaconf import OmegaConf -from pytorch_lightning.trainer.trainer import Trainer from transformers import AutoModel, AutoTokenizer from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel @@ -126,7 +126,7 @@ def convert(args): scaler = None if precision in [16, '16', '16-mixed']: scaler = GradScaler( - init_scale=nemo_config.get('native_amp_init_scale', 2 ** 32), + init_scale=nemo_config.get('native_amp_init_scale', 2**32), growth_interval=nemo_config.get('native_amp_growth_interval', 1000), hysteresis=nemo_config.get('hysteresis', 2), ) @@ -211,7 +211,11 @@ def convert(args): qkv_bias = torch.cat((qkv_bias, q[i * heads_per_group : (i + 1) * heads_per_group, :])) qkv_bias = torch.cat((qkv_bias, k[i : i + 1, :])) qkv_bias = torch.cat((qkv_bias, v[i : i + 1, :])) - qkv_bias = qkv_bias.reshape([head_size * (head_num + 2 * num_query_groups),]) + qkv_bias = qkv_bias.reshape( + [ + head_size * (head_num + 2 * num_query_groups), + ] + ) if mcore_gpt: qkv_weights_base_name = f'model.decoder.layers.{l}.self_attention.linear_qkv.weight' diff --git a/scripts/checkpoint_converters/convert_chatglm_nemo_to_hf.py b/scripts/checkpoint_converters/convert_chatglm_nemo_to_hf.py index 5a8e52ee8be5..241e4254a9be 100644 --- a/scripts/checkpoint_converters/convert_chatglm_nemo_to_hf.py +++ b/scripts/checkpoint_converters/convert_chatglm_nemo_to_hf.py @@ -17,7 +17,7 @@ from collections import OrderedDict import torch -from pytorch_lightning import Trainer +from lightning.pytorch import Trainer from transformers import AutoModel from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel @@ -126,7 +126,7 @@ def convert(input_nemo_file, output_hf_file, precision=None, cpu_only=False) -> num_layers = model.cfg.num_layers num_query_groups = model.cfg.get("num_query_groups", head_num) # different num_query_groups for 70B - head_size = hidden_size // head_num + head_size = model.cfg.get("kv_channels") or (hidden_size // head_num) # equivalent to hf's head_dim heads_per_group = head_num // num_query_groups # 32 / 2 = 16 qkv_total_dim = head_num + 2 * num_query_groups # 32 + 2 * 2 = 36 diff --git a/scripts/checkpoint_converters/convert_clip_hf_to_nemo.py b/scripts/checkpoint_converters/convert_clip_hf_to_nemo.py index 2b8156ad4b26..c47444534604 100644 --- a/scripts/checkpoint_converters/convert_clip_hf_to_nemo.py +++ b/scripts/checkpoint_converters/convert_clip_hf_to_nemo.py @@ -38,9 +38,9 @@ from argparse import ArgumentParser import torch +from lightning.pytorch.plugins.environments import TorchElasticEnvironment +from lightning.pytorch.trainer.trainer import Trainer from omegaconf import OmegaConf -from pytorch_lightning.plugins.environments import TorchElasticEnvironment -from pytorch_lightning.trainer.trainer import Trainer from transformers import CLIPModel from nemo.collections.multimodal.models.vision_language_foundation.clip.megatron_clip_models import MegatronCLIPModel diff --git a/scripts/checkpoint_converters/convert_falcon_hf_to_nemo.py b/scripts/checkpoint_converters/convert_falcon_hf_to_nemo.py index ae8885f4de93..8a880a290484 100644 --- a/scripts/checkpoint_converters/convert_falcon_hf_to_nemo.py +++ b/scripts/checkpoint_converters/convert_falcon_hf_to_nemo.py @@ -32,7 +32,7 @@ import time from typing import Dict -import pytorch_lightning as pl +import lightning.pytorch as pl import torch import yaml from omegaconf import OmegaConf @@ -83,11 +83,11 @@ def get_new_key(old_key): def load_falcon_config(args) -> FalconConfig: - """ Helper utility to load FalconConfig. + """Helper utility to load FalconConfig. Legacy Falcon-7B and Falcon-40B are not compatible with `transformers.FalconConfig` and `transformers.FalconModel`. need to manually set the config values - and force to `falcon` model type. + and force to `falcon` model type. """ config = FalconConfig.from_pretrained(args.input_name_or_path) if config.model_type == 'RefinedWeb': diff --git a/scripts/checkpoint_converters/convert_falcon_nemo_to_hf.py b/scripts/checkpoint_converters/convert_falcon_nemo_to_hf.py index da8f15b92649..cc1d99b6d1c6 100644 --- a/scripts/checkpoint_converters/convert_falcon_nemo_to_hf.py +++ b/scripts/checkpoint_converters/convert_falcon_nemo_to_hf.py @@ -17,7 +17,7 @@ from collections import OrderedDict import torch -from pytorch_lightning import Trainer +from lightning.pytorch import Trainer from transformers import AutoModelForCausalLM from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel diff --git a/scripts/checkpoint_converters/convert_gpt_nemo_to_mcore.py b/scripts/checkpoint_converters/convert_gpt_nemo_to_mcore.py index 35039f8d02e9..61443a3bcb28 100644 --- a/scripts/checkpoint_converters/convert_gpt_nemo_to_mcore.py +++ b/scripts/checkpoint_converters/convert_gpt_nemo_to_mcore.py @@ -17,8 +17,8 @@ from collections import OrderedDict import torch +from lightning.pytorch.trainer.trainer import Trainer from omegaconf import OmegaConf, open_dict -from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy diff --git a/scripts/checkpoint_converters/convert_llama_hf_to_nemo.py b/scripts/checkpoint_converters/convert_llama_hf_to_nemo.py index 4eb8cb6330ca..44de38497b44 100644 --- a/scripts/checkpoint_converters/convert_llama_hf_to_nemo.py +++ b/scripts/checkpoint_converters/convert_llama_hf_to_nemo.py @@ -27,8 +27,8 @@ from collections import OrderedDict import torch +from lightning.pytorch.trainer.trainer import Trainer from omegaconf import OmegaConf -from pytorch_lightning.trainer.trainer import Trainer from transformers import AutoTokenizer, LlamaForCausalLM, LlamaTokenizer from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel diff --git a/scripts/checkpoint_converters/convert_llama_hf_to_nemo_load.py b/scripts/checkpoint_converters/convert_llama_hf_to_nemo_load.py index 42d3e77ce4c8..75bd0d0ab6ed 100644 --- a/scripts/checkpoint_converters/convert_llama_hf_to_nemo_load.py +++ b/scripts/checkpoint_converters/convert_llama_hf_to_nemo_load.py @@ -28,8 +28,8 @@ from collections import OrderedDict import torch +from lightning.pytorch.trainer.trainer import Trainer from omegaconf import OmegaConf -from pytorch_lightning.trainer.trainer import Trainer from transformers import AutoTokenizer, LlamaForCausalLM, LlamaTokenizer from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel diff --git a/scripts/checkpoint_converters/convert_llama_hf_to_nemo_save_dict.py b/scripts/checkpoint_converters/convert_llama_hf_to_nemo_save_dict.py index f7096996e5b1..4a8a409a88fd 100644 --- a/scripts/checkpoint_converters/convert_llama_hf_to_nemo_save_dict.py +++ b/scripts/checkpoint_converters/convert_llama_hf_to_nemo_save_dict.py @@ -27,8 +27,8 @@ from collections import OrderedDict import torch +from lightning.pytorch.trainer.trainer import Trainer from omegaconf import OmegaConf -from pytorch_lightning.trainer.trainer import Trainer from transformers import AutoTokenizer, LlamaForCausalLM, LlamaTokenizer from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel diff --git a/scripts/checkpoint_converters/convert_llama_nemo_to_hf.py b/scripts/checkpoint_converters/convert_llama_nemo_to_hf.py index a3c40676a980..87b7151aa961 100644 --- a/scripts/checkpoint_converters/convert_llama_nemo_to_hf.py +++ b/scripts/checkpoint_converters/convert_llama_nemo_to_hf.py @@ -17,8 +17,8 @@ from collections import OrderedDict import torch +from lightning.pytorch import Trainer from omegaconf import open_dict -from pytorch_lightning import Trainer from transformers import AutoModelForCausalLM, LlamaTokenizer, LlamaTokenizerFast, convert_slow_tokenizer from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel @@ -26,7 +26,7 @@ from nemo.utils import logging """ -Script to convert a llama2 checkpoint in nemo (mcore path) into a HuggingFace checkpoint. +Script to convert a llama checkpoint in nemo (mcore path) into a HuggingFace checkpoint. This script can be used to 1) generate only the HF weights, or 2) generate an entire HF model folder. 1) Generate only HF weights from a nemo file: @@ -37,13 +37,21 @@ 2) Generate the full HF model folder + python convert_llama_nemo_to_hf.py \ + --input_name_or_path /path/to/file.nemo or /path/to/extracted_folder \ + --output_path /path/to/pytorch_model.bin \ + --hf_input_path /path/to/input_hf_folder \ + --hf_output_path /path/to/output_hf_folder + +3) Generate the full HF model folder with a custom tokenizer + python convert_llama_nemo_to_hf.py \ --input_name_or_path /path/to/file.nemo or /path/to/extracted_folder \ --output_path /path/to/pytorch_model.bin \ --hf_input_path /path/to/input_hf_folder \ --hf_output_path /path/to/output_hf_folder \ - --input_tokenizer /path/to/tokenizer \ - --hf_output_tokenizer /path/to/output_tokenizer \ + --input_tokenizer /path/to/custom_nemo_tokenizer.model \ + --hf_output_tokenizer /path/to/output_tokenizer Use the --cpu-only flag if the model cannot fit in the GPU (e.g. Llama2 70b). However this option makes the conversion script significantly slower. @@ -143,7 +151,7 @@ def convert(input_nemo_file, output_hf_file, precision=None, cpu_only=False) -> ffn_hidden_size = model.cfg.ffn_hidden_size num_query_groups = model.cfg.get("num_query_groups", head_num) # different num_query_groups for 70B - head_size = hidden_size // head_num + head_size = model.cfg.get("kv_channels") or (hidden_size // head_num) # equivalent to hf's head_dim heads_per_group = head_num // num_query_groups qkv_total_dim = head_num + 2 * num_query_groups @@ -246,21 +254,25 @@ def replace_hf_weights_and_tokenizer( nemo_exported = torch.load(weights_file) if tokenizer_path: - tokenizer = LlamaTokenizer.from_pretrained( - tokenizer_path, - local_files_only=True, - legacy=False, - ) - tmp_tokenizer = convert_slow_tokenizer.convert_slow_tokenizer(tokenizer) - fast_tokenizer = LlamaTokenizerFast(tokenizer_object=tmp_tokenizer) - tokenizer_length = len(fast_tokenizer) - model.resize_token_embeddings(tokenizer_length) + try: + tokenizer = LlamaTokenizer.from_pretrained( + tokenizer_path, + local_files_only=True, + legacy=False, + ) + tmp_tokenizer = convert_slow_tokenizer.convert_slow_tokenizer(tokenizer) + fast_tokenizer = LlamaTokenizerFast(tokenizer_object=tmp_tokenizer) + tokenizer_length = len(fast_tokenizer) + model.resize_token_embeddings(tokenizer_length) + except: + tokenizer = None + logging.warning("Could not load custom tokenizer, proceeding with default tokenizer") model.load_state_dict(nemo_exported) model.save_pretrained(output_hf_path) logging.info(f"Full HF model saved to {output_hf_path}") - if tokenizer_path: + if tokenizer_path and (tokenizer is not None): fast_tokenizer.save_pretrained(output_hf_tokenizer) tokenizer.save_pretrained(output_hf_tokenizer) logging.info(f"Tokenizer saved to {output_hf_tokenizer}") diff --git a/scripts/checkpoint_converters/convert_mistral_7b_hf_to_nemo.py b/scripts/checkpoint_converters/convert_mistral_7b_hf_to_nemo.py index 4bceb250999f..3cf5bbd4acf9 100644 --- a/scripts/checkpoint_converters/convert_mistral_7b_hf_to_nemo.py +++ b/scripts/checkpoint_converters/convert_mistral_7b_hf_to_nemo.py @@ -29,9 +29,9 @@ import torch import torch.nn +from lightning.pytorch.core.saving import _load_state as ptl_load_state +from lightning.pytorch.trainer.trainer import Trainer from omegaconf import OmegaConf -from pytorch_lightning.core.saving import _load_state as ptl_load_state -from pytorch_lightning.trainer.trainer import Trainer from transformers import AutoModelForCausalLM, AutoTokenizer from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel diff --git a/scripts/checkpoint_converters/convert_mistral_7b_nemo_to_hf.py b/scripts/checkpoint_converters/convert_mistral_7b_nemo_to_hf.py index b8c30a1b929d..1f0a31076f8e 100644 --- a/scripts/checkpoint_converters/convert_mistral_7b_nemo_to_hf.py +++ b/scripts/checkpoint_converters/convert_mistral_7b_nemo_to_hf.py @@ -25,7 +25,7 @@ import torch import torch.nn -from pytorch_lightning.trainer.trainer import Trainer +from lightning.pytorch.trainer.trainer import Trainer from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel @@ -134,7 +134,7 @@ def convert(in_file, precision=None, cpu_only=True) -> None: num_layers = model.cfg.num_layers num_query_groups = model.cfg.get("num_query_groups", head_num) # different num_query_groups for 70B - head_size = model.cfg.get('kv_channels', hidden_size // head_num) + head_size = model.cfg.get("kv_channels") or (hidden_size // head_num) # equivalent to hf's head_dim heads_per_group = head_num // num_query_groups qkv_total_dim = head_num + 2 * num_query_groups diff --git a/scripts/checkpoint_converters/convert_mixtral_hf_to_nemo.py b/scripts/checkpoint_converters/convert_mixtral_hf_to_nemo.py index 36e4c0c2c3ea..a75c6876e70a 100644 --- a/scripts/checkpoint_converters/convert_mixtral_hf_to_nemo.py +++ b/scripts/checkpoint_converters/convert_mixtral_hf_to_nemo.py @@ -30,9 +30,9 @@ import megatron.core.parallel_state as parallel_state import torch import torch.nn +from lightning.pytorch.core.saving import _load_state as ptl_load_state +from lightning.pytorch.trainer.trainer import Trainer from omegaconf import OmegaConf -from pytorch_lightning.core.saving import _load_state as ptl_load_state -from pytorch_lightning.trainer.trainer import Trainer from transformers import AutoModelForCausalLM, AutoTokenizer from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel diff --git a/scripts/checkpoint_converters/convert_mixtral_nemo_to_hf.py b/scripts/checkpoint_converters/convert_mixtral_nemo_to_hf.py index 2bac2eaad616..eb934803f164 100644 --- a/scripts/checkpoint_converters/convert_mixtral_nemo_to_hf.py +++ b/scripts/checkpoint_converters/convert_mixtral_nemo_to_hf.py @@ -26,7 +26,7 @@ import megatron.core.parallel_state as parallel_state import torch import torch.nn -from pytorch_lightning.trainer.trainer import Trainer +from lightning.pytorch.trainer.trainer import Trainer from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel @@ -137,7 +137,7 @@ def convert(in_file, precision=None) -> None: num_layers = model.cfg.num_layers num_query_groups = model.cfg.get("num_query_groups", head_num) # different num_query_groups for 70B - head_size = hidden_size // head_num + head_size = model.cfg.get("kv_channels") or (hidden_size // head_num) # equivalent to hf's head_dim heads_per_group = head_num // num_query_groups qkv_total_dim = head_num + 2 * num_query_groups diff --git a/scripts/checkpoint_converters/convert_mpt_hf_to_nemo.py b/scripts/checkpoint_converters/convert_mpt_hf_to_nemo.py index e7d81f709092..d4a450a8e046 100644 --- a/scripts/checkpoint_converters/convert_mpt_hf_to_nemo.py +++ b/scripts/checkpoint_converters/convert_mpt_hf_to_nemo.py @@ -56,7 +56,7 @@ import argparse import os -import pytorch_lightning as pl +import lightning.pytorch as pl import torch import yaml from omegaconf import OmegaConf @@ -68,7 +68,11 @@ if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument( - "--input_name_or_path", type=str, default=None, required=True, help="Path to Huggingface MPT checkpoints", + "--input_name_or_path", + type=str, + default=None, + required=True, + help="Path to Huggingface MPT checkpoints", ) parser.add_argument("--output_path", type=str, default=None, required=True, help="Path to output .nemo file.") parser.add_argument( diff --git a/scripts/checkpoint_converters/convert_nemotron_nemo_to_hf.py b/scripts/checkpoint_converters/convert_nemotron_nemo_to_hf.py index fc0f660cbd42..2f66773f8724 100644 --- a/scripts/checkpoint_converters/convert_nemotron_nemo_to_hf.py +++ b/scripts/checkpoint_converters/convert_nemotron_nemo_to_hf.py @@ -19,9 +19,9 @@ from collections import OrderedDict import torch -from pytorch_lightning import Trainer +from lightning.pytorch import Trainer from transformers import LlamaTokenizer, PreTrainedTokenizerFast -from transformers.convert_slow_tokenizer import LlamaConverter +from transformers.convert_slow_tokenizer import LlamaConverter, TikTokenConverter from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel @@ -130,6 +130,20 @@ def convert_hf_config(nemo_config, tokenizer, vocab_size, dtype, hf_output_path, json.dump(hf_config, open(f"{hf_output_path}/config.json", "w"), indent=2) +def convert_tiktoken(vocab_file) -> None: + with open(vocab_file, 'r') as f: + vocab = json.load(f) + os.remove(vocab_file) + + lines = [] + for line in vocab: + lines.append(f"{line['token_bytes']} {line['rank']}") + + for line in lines: + with open(vocab_file, 'a') as f: + f.write(line + '\n') + + def convert(input_nemo_file, output_hf_file, precision=None, cpu_only=False) -> None: """ Convert NeMo weights to HF weights @@ -323,6 +337,28 @@ def extract_nemotron_tokenizer(nemo_file, model_config, output_hf_path, nemo_tok ) tokenizer.save_pretrained(output_hf_path) logging.info(f"Setencepiece tokenizer has been saved to {output_tokenizer}") + elif tokenizer_cfg.library == "tiktoken": + tokenizer_fn = tokenizer_cfg.model[5:] + special_tokens = ["", "", ""] + import tarfile + + archive = tarfile.open(nemo_file, "r") + tokenizer_filename = "./" + tokenizer_fn # exclude 'nemo:' prefix + archive.extract(tokenizer_filename, output_hf_path) + archive.close() + vocab_file = os.path.join(output_hf_path, tokenizer_fn) + convert_tiktoken(vocab_file) + converted_tokenizer = TikTokenConverter( + vocab_file=vocab_file, additional_special_tokens=special_tokens + ).converted() + os.remove(vocab_file) + tokenizer = PreTrainedTokenizerFast( + tokenizer_object=converted_tokenizer, + model_input_names=["input_ids", "attention_mask"], + bos_token="", + eos_token="", + ) + tokenizer.save_pretrained(output_hf_path) elif isinstance(nemo_tokenizer, AutoTokenizer): nemo_tokenizer.tokenizer.save_pretrained(output_hf_path) logging.info(f"HF AutoTokenizer has been saved to {output_hf_path}") diff --git a/scripts/checkpoint_converters/convert_qwen2_hf_to_nemo.py b/scripts/checkpoint_converters/convert_qwen2_hf_to_nemo.py index 223c7af50843..b472a7e5c6f3 100644 --- a/scripts/checkpoint_converters/convert_qwen2_hf_to_nemo.py +++ b/scripts/checkpoint_converters/convert_qwen2_hf_to_nemo.py @@ -25,8 +25,8 @@ from collections import OrderedDict import torch +from lightning.pytorch.trainer.trainer import Trainer from omegaconf import OmegaConf -from pytorch_lightning.trainer.trainer import Trainer from transformers import Qwen2ForCausalLM, Qwen2Tokenizer from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel diff --git a/scripts/checkpoint_converters/convert_qwen2_nemo_to_hf.py b/scripts/checkpoint_converters/convert_qwen2_nemo_to_hf.py index 6080499ffdf8..968caade917c 100644 --- a/scripts/checkpoint_converters/convert_qwen2_nemo_to_hf.py +++ b/scripts/checkpoint_converters/convert_qwen2_nemo_to_hf.py @@ -17,7 +17,7 @@ from collections import OrderedDict import torch -from pytorch_lightning import Trainer +from lightning.pytorch import Trainer from transformers import Qwen2ForCausalLM, Qwen2Tokenizer, Qwen2TokenizerFast, convert_slow_tokenizer from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel @@ -142,7 +142,7 @@ def convert(input_nemo_file, output_hf_file, precision=None, cpu_only=False) -> ffn_hidden_size = model.cfg.ffn_hidden_size num_query_groups = model.cfg.get("num_query_groups", head_num) - head_size = hidden_size // head_num + head_size = model.cfg.get("kv_channels") or (hidden_size // head_num) # equivalent to hf's head_dim heads_per_group = head_num // num_query_groups qkv_total_dim = head_num + 2 * num_query_groups diff --git a/scripts/checkpoint_converters/convert_starcoder2_hf_to_nemo.py b/scripts/checkpoint_converters/convert_starcoder2_hf_to_nemo.py index fc898c797a9e..862777cf52a8 100644 --- a/scripts/checkpoint_converters/convert_starcoder2_hf_to_nemo.py +++ b/scripts/checkpoint_converters/convert_starcoder2_hf_to_nemo.py @@ -28,9 +28,9 @@ import torch import torch.nn +from lightning.pytorch.core.saving import _load_state as ptl_load_state +from lightning.pytorch.trainer.trainer import Trainer from omegaconf import OmegaConf -from pytorch_lightning.core.saving import _load_state as ptl_load_state -from pytorch_lightning.trainer.trainer import Trainer from transformers import AutoModelForCausalLM, AutoTokenizer from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel @@ -168,7 +168,7 @@ def convert(args): scaler = None if precision in [16, '16', '16-mixed']: scaler = GradScaler( - init_scale=nemo_config.get('native_amp_init_scale', 2 ** 32), + init_scale=nemo_config.get('native_amp_init_scale', 2**32), growth_interval=nemo_config.get('native_amp_growth_interval', 1000), hysteresis=nemo_config.get('hysteresis', 2), ) diff --git a/scripts/checkpoint_converters/convert_starcoder2_nemo_to_hf.py b/scripts/checkpoint_converters/convert_starcoder2_nemo_to_hf.py index 4b65533b74ec..c418a714be0a 100644 --- a/scripts/checkpoint_converters/convert_starcoder2_nemo_to_hf.py +++ b/scripts/checkpoint_converters/convert_starcoder2_nemo_to_hf.py @@ -25,7 +25,7 @@ import torch import torch.nn -from pytorch_lightning.trainer.trainer import Trainer +from lightning.pytorch.trainer.trainer import Trainer from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel @@ -141,7 +141,7 @@ def convert(in_file, precision=None, cpu_only=True) -> None: num_layers = model.cfg.num_layers num_query_groups = model.cfg.get("num_query_groups", head_num) # different num_query_groups for 70B - head_size = hidden_size // head_num + head_size = model.cfg.get("kv_channels") or (hidden_size // head_num) # equivalent to hf's head_dim heads_per_group = head_num // num_query_groups qkv_total_dim = head_num + 2 * num_query_groups diff --git a/scripts/checkpoint_converters/convert_starcoder_hf_to_nemo.py b/scripts/checkpoint_converters/convert_starcoder_hf_to_nemo.py index e600c65e6de1..6b9f30ab427b 100644 --- a/scripts/checkpoint_converters/convert_starcoder_hf_to_nemo.py +++ b/scripts/checkpoint_converters/convert_starcoder_hf_to_nemo.py @@ -52,7 +52,7 @@ import os from typing import Dict -import pytorch_lightning as pl +import lightning.pytorch as pl import torch import yaml from omegaconf import OmegaConf diff --git a/scripts/checkpoint_converters/quantize_model_to_nf4.py b/scripts/checkpoint_converters/quantize_model_to_nf4.py index db3a48aaa16d..8fbaeb875f7a 100644 --- a/scripts/checkpoint_converters/quantize_model_to_nf4.py +++ b/scripts/checkpoint_converters/quantize_model_to_nf4.py @@ -16,7 +16,7 @@ from typing import List import torch -from pytorch_lightning import Trainer +from lightning.pytorch import Trainer from torch import nn from nemo.collections.nlp.models.language_modeling.megatron_gpt_sft_model import MegatronGPTSFTModel diff --git a/scripts/confidence_ensembles/build_ensemble.py b/scripts/confidence_ensembles/build_ensemble.py index 4c05e2e4ff3f..dfb3793b42f4 100644 --- a/scripts/confidence_ensembles/build_ensemble.py +++ b/scripts/confidence_ensembles/build_ensemble.py @@ -80,8 +80,8 @@ from typing import Dict, List, Optional, Tuple import joblib +import lightning.pytorch as pl import numpy as np -import pytorch_lightning as pl from omegaconf import MISSING, DictConfig, OmegaConf from sklearn.linear_model import LogisticRegression from sklearn.metrics import confusion_matrix @@ -215,7 +215,12 @@ class BuildEnsembleConfig: preserve_frame_confidence=True, exclude_blank=True, aggregation="mean", - method_cfg=ConfidenceMethodConfig(name="entropy", entropy_type="renyi", alpha=0.25, entropy_norm="lin",), + method_cfg=ConfidenceMethodConfig( + name="entropy", + entropy_type="renyi", + alpha=0.25, + entropy_norm="lin", + ), ) ) temperature: float = 1.0 @@ -499,7 +504,12 @@ def find_best_confidence( dev_features = np.array(list(zip(*cur_dev_confidences))) dev_labels = np.array(dev_labels) pipe, score = train_model_selection( - training_features, training_labels, dev_features, dev_labels, tune_lr, tune_lr_config, + training_features, + training_labels, + dev_features, + dev_labels, + tune_lr, + tune_lr_config, ) if max_score < score: max_score = score @@ -513,7 +523,7 @@ def find_best_confidence( @hydra_runner(config_name="BuildEnsembleConfig", schema=BuildEnsembleConfig) def main(cfg: BuildEnsembleConfig): # silencing all messages from nemo/ptl to avoid dumping tons of configs to the stdout - logging.getLogger('pytorch_lightning').setLevel(logging.CRITICAL) + logging.getLogger('lightning.pytorch').setLevel(logging.CRITICAL) logging.getLogger('nemo_logger').setLevel(logging.CRITICAL) LOG.info(f'Build ensemble config:\n{OmegaConf.to_yaml(cfg)}') diff --git a/scripts/deploy/nlp/deploy_triton.py b/scripts/deploy/nlp/deploy_triton.py index e3394726fa1c..154ffc90dc9c 100755 --- a/scripts/deploy/nlp/deploy_triton.py +++ b/scripts/deploy/nlp/deploy_triton.py @@ -419,13 +419,14 @@ def nemo_deploy(argv): LOGGER.info("Triton deploy function will be called.") nm.deploy() + nm.run() except Exception as error: LOGGER.error("Error message has occurred during deploy function. Error message: " + str(error)) return try: LOGGER.info("Model serving on Triton is will be started.") - if args.start_rest_service == "True": + if args.start_rest_service: try: LOGGER.info("REST service will be started.") uvicorn.run( diff --git a/scripts/deploy/nlp/deploy_vllm_triton.py b/scripts/deploy/nlp/deploy_vllm_triton.py index ab9f13a1b8da..a3cf5e8ec762 100755 --- a/scripts/deploy/nlp/deploy_vllm_triton.py +++ b/scripts/deploy/nlp/deploy_vllm_triton.py @@ -41,7 +41,7 @@ def get_args(argv): "-mt", "--model_type", type=str, - required=False, + required=True, choices=["llama", "mistral", "mixtral", "starcoder2", "gemma"], help="Type of the model", ) diff --git a/scripts/deploy/nlp/query_inframework.py b/scripts/deploy/nlp/query_inframework.py index e77ab72a1f04..a62e09fa071d 100644 --- a/scripts/deploy/nlp/query_inframework.py +++ b/scripts/deploy/nlp/query_inframework.py @@ -15,7 +15,7 @@ import argparse import sys -from nemo.deploy.nlp.query_llm import NemoQueryLLMPyTorch +from nemo.deploy.nlp import NemoQueryLLMPyTorch def get_args(argv): diff --git a/scripts/diffusion_model_lora_merge/merge_lora_weights_into_base_model.py b/scripts/diffusion_model_lora_merge/merge_lora_weights_into_base_model.py index 57d9964cad3d..a80d9d2639e3 100644 --- a/scripts/diffusion_model_lora_merge/merge_lora_weights_into_base_model.py +++ b/scripts/diffusion_model_lora_merge/merge_lora_weights_into_base_model.py @@ -16,7 +16,7 @@ from typing import Any, Dict import torch -from pytorch_lightning import Trainer +from lightning.pytorch import Trainer from nemo.collections.multimodal.models.text_to_image.stable_diffusion.ldm.ddpm import MegatronLatentDiffusion from nemo.collections.multimodal.parts.utils import setup_trainer_and_model_for_inference diff --git a/scripts/export.py b/scripts/export.py index acfd3e3e3450..6e0b9b72e15b 100644 --- a/scripts/export.py +++ b/scripts/export.py @@ -30,8 +30,8 @@ import sys import torch +from lightning.pytorch import Trainer from omegaconf import OmegaConf -from pytorch_lightning import Trainer import nemo from nemo.core import ModelPT diff --git a/scripts/llm/ptq.py b/scripts/llm/ptq.py index c04d32290e5f..2afe38c37b4d 100644 --- a/scripts/llm/ptq.py +++ b/scripts/llm/ptq.py @@ -17,6 +17,8 @@ def get_args(): + """Parses PTQ arguments""" + parser = argparse.ArgumentParser( formatter_class=argparse.ArgumentDefaultsHelpFormatter, description="NeMo PTQ argument parser", @@ -58,6 +60,10 @@ def get_args(): type=str, help='Calibration dataset to be used. Should be \"wikitext\", \"cnn_dailymail\" or path to a local .json file', ) + parser.add_argument( + '--generate_sample', help='Generate sample model output after performing PTQ', action='store_true' + ) + parser.set_defaults(generate_sample=False) args = parser.parse_args() if args.output_path is None: @@ -68,6 +74,8 @@ def get_args(): def main(): + """Example NeMo 2.0 Post Training Quantization workflow""" + args = get_args() quantization_config = quantization.QuantizationConfig( @@ -87,6 +95,7 @@ def main(): inference_tensor_parallel=args.tensor_parallelism_size, inference_pipeline_parallel=args.pipeline_parallelism_size, dtype=args.dtype, + generate_sample=args.generate_sample, ) quantizer = quantization.Quantizer(quantization_config, export_config) diff --git a/scripts/nemo_legacy_import/nlp_checkpoint_port.py b/scripts/nemo_legacy_import/nlp_checkpoint_port.py index b7541ffdb8cd..058f9e072f5f 100644 --- a/scripts/nemo_legacy_import/nlp_checkpoint_port.py +++ b/scripts/nemo_legacy_import/nlp_checkpoint_port.py @@ -30,7 +30,7 @@ import logging import sys -import pytorch_lightning as pl +import lightning.pytorch as pl from omegaconf import OmegaConf, open_dict from nemo.collections.nlp.parts.nlp_overrides import NLPSaveRestoreConnector diff --git a/scripts/nlp_language_modeling/convert_prompt_learning_ckpt_to_nemo.py b/scripts/nlp_language_modeling/convert_prompt_learning_ckpt_to_nemo.py index 334b3415a93b..3e96186552a5 100644 --- a/scripts/nlp_language_modeling/convert_prompt_learning_ckpt_to_nemo.py +++ b/scripts/nlp_language_modeling/convert_prompt_learning_ckpt_to_nemo.py @@ -14,7 +14,7 @@ import os -from pytorch_lightning.trainer.trainer import Trainer +from lightning.pytorch.trainer.trainer import Trainer from nemo.collections.nlp.models.language_modeling.megatron_gpt_prompt_learning_model import ( MegatronGPTPromptLearningModel, diff --git a/scripts/nlp_language_modeling/hf_t5-v1_1_to_nemo.py b/scripts/nlp_language_modeling/hf_t5-v1_1_to_nemo.py index 6a94e8f501bb..2361e000ef7e 100644 --- a/scripts/nlp_language_modeling/hf_t5-v1_1_to_nemo.py +++ b/scripts/nlp_language_modeling/hf_t5-v1_1_to_nemo.py @@ -53,8 +53,8 @@ from argparse import ArgumentParser import torch +from lightning.pytorch import Trainer from omegaconf.omegaconf import OmegaConf, open_dict -from pytorch_lightning import Trainer from transformers import AutoTokenizer, T5ForConditionalGeneration from nemo.collections.nlp.models.language_modeling.megatron_t5_model import MegatronT5Model diff --git a/scripts/nlp_language_modeling/merge_lora_weights/merge.py b/scripts/nlp_language_modeling/merge_lora_weights/merge.py index 55d50502705c..3a6d110997ba 100644 --- a/scripts/nlp_language_modeling/merge_lora_weights/merge.py +++ b/scripts/nlp_language_modeling/merge_lora_weights/merge.py @@ -33,8 +33,8 @@ from typing import Any, Dict, List import torch +from lightning.pytorch.trainer.trainer import Trainer from omegaconf import OmegaConf, open_dict -from pytorch_lightning.trainer.trainer import Trainer from torch.utils.data import DataLoader, Dataset from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel diff --git a/scripts/nlp_language_modeling/prepare_packed_ft_dataset.py b/scripts/nlp_language_modeling/prepare_packed_ft_dataset.py index 7ff2342e4087..19a3e6a78228 100644 --- a/scripts/nlp_language_modeling/prepare_packed_ft_dataset.py +++ b/scripts/nlp_language_modeling/prepare_packed_ft_dataset.py @@ -88,6 +88,14 @@ def tokenize_dataset(cfg: 'DictConfig'): # using the same template as SFT/PEFT script. This may be overkill but guarantees the preprocess settings # are identical to normal SFT training data_cfg = cfg.model.data.train_ds + pad_seq_length_to_mult = 16 + cp_size = cfg.model.get("context_parallel_size", 1) + + # if context parallel is used, each individual data length in one packed dataset sample + # needs to be a multiple of (cp_size * 2): https://github.com/NVIDIA/TransformerEngine/pull/641 + if cp_size > 1: + pad_seq_length_to_mult = max(pad_seq_length_to_mult, cp_size * 2) + if os.path.isdir(cfg.tokenizer_path): # pass in a Hugging Face folder which contains tokenizer.json tokenizer = get_nmt_tokenizer(library="huggingface", model_name=cfg.tokenizer_path, use_fast=True) @@ -99,7 +107,7 @@ def tokenize_dataset(cfg: 'DictConfig'): tokenizer=tokenizer, max_seq_length=data_cfg.max_seq_length, min_seq_length=data_cfg.min_seq_length, - pad_seq_length_to_mult=16, # adds padding in collate_fn so this value is irrelevant here + pad_seq_length_to_mult=pad_seq_length_to_mult, add_bos=data_cfg.get('add_bos', False), add_eos=data_cfg.get('add_eos', True), add_sep=data_cfg.get('add_sep', False), @@ -121,7 +129,40 @@ def tokenize_dataset(cfg: 'DictConfig'): is_test=True, ) - return np.array([dataset[i] for i in range(len(dataset))]) + max_seq_length = dataset.max_seq_length + pad_id = dataset.tokenizer.eos_id + pad_seq_length_to_mult = dataset.pad_seq_length_to_mult + dataset = np.array([dataset[i] for i in range(len(dataset))]) + if cp_size > 1: + + def pre_pad_dataset(data, max_seq_length, max_length_to_pad, pad_id): + ''' + pad each individual data point to the length of max_length + ''' + assert max_seq_length >= max_length_to_pad + for key, val in data.items(): + if key in {'input_ids', 'context_ids'}: + if len(val) <= max_length_to_pad: + # because input_ids are truncated by 1 for inputs and labels, + # we add 1 extra padding here to make sure padded inputs and labels + # are is a multiple of (cp_size * 2) + val = val + [pad_id] * (max_length_to_pad - len(val) + 1) + data[key] = val + elif len(val) > max_seq_length: + logging.info( + f"""The current sequence length {len(val)} for packing is + larger than the max_seq_length specified ({max_seq_length}). + The current seqquence length is truncated to the size of max_seq_length. + Please consider increase the sequence packing size""" + ) + data[key] = val[:max_seq_length] + return + + ceil_to_nearest = lambda n, m: (n + m - 1) // m * m + for data in dataset: + max_length_to_pad = min(max_seq_length, ceil_to_nearest(len(data['input_ids']), pad_seq_length_to_mult)) + pre_pad_dataset(data, max_seq_length, max_length_to_pad, pad_id) + return dataset @dataclass diff --git a/scripts/nlp_language_modeling/service_launch_scripts/start_retro_model_service.py b/scripts/nlp_language_modeling/service_launch_scripts/start_retro_model_service.py index ee32f69bf734..dd7c1a3656be 100644 --- a/scripts/nlp_language_modeling/service_launch_scripts/start_retro_model_service.py +++ b/scripts/nlp_language_modeling/service_launch_scripts/start_retro_model_service.py @@ -15,8 +15,8 @@ import os import torch +from lightning.pytorch import Trainer from omegaconf.omegaconf import OmegaConf, open_dict -from pytorch_lightning import Trainer from nemo.collections.nlp.models.language_modeling.megatron_retrieval_model import MegatronRetrievalModel from nemo.collections.nlp.modules.common.text_generation_server import MegatronServer @@ -66,7 +66,10 @@ def main(cfg) -> None: save_restore_connector.model_extracted_dir = model_path model_cfg = MegatronRetrievalModel.restore_from( - model_path, trainer=trainer, return_config=True, save_restore_connector=save_restore_connector, + model_path, + trainer=trainer, + return_config=True, + save_restore_connector=save_restore_connector, ) with open_dict(model_cfg): @@ -76,7 +79,10 @@ def main(cfg) -> None: model_cfg.activations_checkpoint_method = None model = MegatronRetrievalModel.restore_from( - model_path, trainer=trainer, save_restore_connector=save_restore_connector, override_config_path=model_cfg, + model_path, + trainer=trainer, + save_restore_connector=save_restore_connector, + override_config_path=model_cfg, ) # check whether the DDP is initialized diff --git a/scripts/speech_recognition/confidence/benchmark_asr_confidence.py b/scripts/speech_recognition/confidence/benchmark_asr_confidence.py index 9c42ef6cca5b..7208867ff938 100644 --- a/scripts/speech_recognition/confidence/benchmark_asr_confidence.py +++ b/scripts/speech_recognition/confidence/benchmark_asr_confidence.py @@ -18,7 +18,7 @@ from pathlib import Path from typing import Optional -import pytorch_lightning as pl +import lightning.pytorch as pl import torch from omegaconf import MISSING, OmegaConf from sklearn.model_selection import ParameterGrid diff --git a/scripts/speech_recognition/oomptimizer.py b/scripts/speech_recognition/oomptimizer.py index 3d5eb5a4dbb1..8d215cbc14eb 100755 --- a/scripts/speech_recognition/oomptimizer.py +++ b/scripts/speech_recognition/oomptimizer.py @@ -20,7 +20,7 @@ from typing import Iterable, Literal import click -import pytorch_lightning as pl +import lightning.pytorch as pl import torch from lhotse import compute_num_samples from omegaconf import OmegaConf diff --git a/scripts/vlm/llava_next_finetune.py b/scripts/vlm/llava_next_finetune.py new file mode 100644 index 000000000000..334b360d7c70 --- /dev/null +++ b/scripts/vlm/llava_next_finetune.py @@ -0,0 +1,236 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Example: + torchrun --nproc_per_node=8 scripts/vlm/llava_next_finetune.py \ + --devices=8 --tp=4 --data_type=mock + + torchrun --nproc_per_node=8 scripts/vlm/llava_next_finetune.py \ + --devices=8 --tp=4 --data_type=energon --data_path='' \ + --language_model_path=/root/.cache/nemo/models/lmsys/vicuna-7b-v1.5 +""" + +import argparse + +import torch +from megatron.core.optimizer import OptimizerConfig +from pytorch_lightning.loggers import WandbLogger + +from nemo import lightning as nl +from nemo.collections import llm, vlm +from nemo.lightning.pytorch.optim import CosineAnnealingScheduler +from nemo.lightning.pytorch.optim.megatron import MegatronOptimizerModule +from nemo.utils.exp_manager import TimingCallback + + +def main(args): + # pylint: disable=C0115,C0116 + + # Global and micro batch sizes + gbs = args.gbs + mbs = args.mbs + max_steps = args.max_steps + + decoder_seq_length = 4096 + + if args.data_type == "energon": + from transformers import AutoProcessor + + from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer + from nemo.collections.multimodal.data.energon import SimpleMultiModalDataModule + from nemo.collections.multimodal.data.energon.config import MultiModalSampleConfig + from nemo.collections.vlm import LlavaNextTaskEncoder + + data_path = args.data_path + model_id = "llava-hf/llava-v1.6-vicuna-7b-hf" + processor = AutoProcessor.from_pretrained(model_id) + tokenizer = AutoTokenizer(model_id) + + multimodal_sample_config = MultiModalSampleConfig() + + task_encoder = LlavaNextTaskEncoder( + tokenizer=tokenizer.tokenizer, + image_processor=processor.image_processor, + multimodal_sample_config=multimodal_sample_config, + ) + data = SimpleMultiModalDataModule( + path=data_path, + tokenizer=tokenizer, + image_processor=processor.image_processor, + num_workers=32, + micro_batch_size=mbs, + global_batch_size=gbs, + multimodal_sample_config=multimodal_sample_config, + task_encoder=task_encoder, + ) + + elif args.data_type == "mock": + data = vlm.LlavaNextMockDataModule( + seq_length=decoder_seq_length, + global_batch_size=gbs, + micro_batch_size=mbs, + tokenizer=None, + image_processor=None, + num_workers=4, + ) + else: + raise ValueError(f"Data type {args.data_type} not supported") + + # Submodules configurations + language_transformer_config = llm.Llama2Config7B( + seq_length=decoder_seq_length, + ) + vision_transformer_config = vlm.HFCLIPVisionConfig( + pretrained_model_name_or_path="openai/clip-vit-large-patch14-336" + ) + vision_projection_config = vlm.MultimodalProjectorConfig( + projector_type=args.projector_type, + input_size=vision_transformer_config.hidden_size, + hidden_size=language_transformer_config.hidden_size, + ffn_hidden_size=language_transformer_config.hidden_size, + ) + + # Llava Next model configuration + llava_next_config = vlm.LlavaNextConfig( + language_transformer_config=language_transformer_config, + vision_transformer_config=vision_transformer_config, + vision_projection_config=vision_projection_config, + language_model_from_pretrained=args.language_model_path, + freeze_language_model=False, + freeze_vision_model=True, + ) + + model = vlm.LlavaNextModel(llava_next_config, tokenizer=data.tokenizer) + + # Training strategy setup + strategy = nl.MegatronStrategy( + tensor_model_parallel_size=args.tp_size, + pipeline_model_parallel_size=args.pp_size, + encoder_pipeline_model_parallel_size=args.encoder_pp_size, + pipeline_dtype=torch.bfloat16, + sequence_parallel=False, + ) + + # Checkpoint callback setup + checkpoint_callback = nl.ModelCheckpoint( + save_last=True, + monitor="reduced_train_loss", + save_top_k=2, + every_n_train_steps=1000, + dirpath=args.log_dir, + ) + + # Trainer setup + trainer = nl.Trainer( + num_nodes=args.num_nodes, + devices=args.devices, + max_steps=max_steps, + accelerator="gpu", + strategy=strategy, + plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"), + callbacks=[checkpoint_callback, TimingCallback()], + val_check_interval=500, + limit_val_batches=gbs, + log_every_n_steps=1, + num_sanity_val_steps=0, + ) + + # Logger setup + nemo_logger = nl.NeMoLogger( + log_dir=args.log_dir, + name=args.name, + wandb=WandbLogger(project=args.wandb_project, name=args.name) if args.wandb_project is not None else None, + ) + + # Auto resume setup + resume = nl.AutoResume( + resume_if_exists=True, + resume_ignore_no_checkpoint=True, + resume_from_directory=args.log_dir, + restore_config=nl.RestoreConfig(path=args.restore_path) if args.restore_path is not None else None, + ) + + # Optimizer and scheduler setup + opt_config = OptimizerConfig( + optimizer='adam', + lr=args.lr, + adam_beta1=0.9, + adam_beta2=0.95, + use_distributed_optimizer=True, + bf16=True, + ) + sched = CosineAnnealingScheduler( + max_steps=trainer.max_steps, + warmup_steps=150, + constant_steps=0, + min_lr=2.0e-07, + ) + opt = MegatronOptimizerModule(opt_config, sched) + + # PEFT setup + if args.peft == 'lora': + peft = vlm.peft.LoRA( + target_modules=[ + "linear_qkv", + "linear_proj", + "linear_fc1", + "linear_fc2", + ] + ) + else: + peft = None + + llm.finetune( + model=model, + data=data, + trainer=trainer, + peft=peft, + log=nemo_logger, + optim=opt, + resume=resume, + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Llava Next Finetuning Script") + + # Argument parsing + parser.add_argument("--data_type", type=str, required=False, default="mock", help="mock | energon") + parser.add_argument("--data_path", type=str, required=False, default=None, help="Path to the dataset JSON file") + parser.add_argument( + "--log_dir", type=str, required=False, default="/results", help="Directory for logging and checkpoints" + ) + parser.add_argument( + "--language_model_path", type=str, required=False, default=None, help="Path to the pretrained language model" + ) + parser.add_argument( + "--restore_path", type=str, required=False, default=None, help="Path to restore model from checkpoint" + ) + parser.add_argument("--devices", type=int, required=False, default=1) + parser.add_argument("--num_nodes", type=int, required=False, default=1) + parser.add_argument("--max_steps", type=int, required=False, default=5190) + parser.add_argument("--tp_size", type=int, required=False, default=4) + parser.add_argument("--pp_size", type=int, required=False, default=1) + parser.add_argument("--encoder_pp_size", type=int, required=False, default=0) + parser.add_argument("--projector_type", type=str, required=False, default="mlp2x_gelu") + parser.add_argument("--name", type=str, required=False, default="llava_next_finetune") + parser.add_argument("--peft", type=str, default='none', help="none | lora") + parser.add_argument("--wandb_project", type=str, required=False, default=None) + parser.add_argument("--gbs", type=int, required=False, default=64, help="Global batch size") + parser.add_argument("--mbs", type=int, required=False, default=4, help="Micro batch size") + parser.add_argument("--lr", type=float, required=False, default=2.0e-05, help="Learning rate") + + args = parser.parse_args() + main(args) diff --git a/scripts/vlm/llava_next_generation.py b/scripts/vlm/llava_next_generation.py new file mode 100644 index 000000000000..1b3d8a46b955 --- /dev/null +++ b/scripts/vlm/llava_next_generation.py @@ -0,0 +1,166 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse + +import requests +import torch +from PIL import Image +from transformers import AutoProcessor + +from nemo import lightning as nl +from nemo.collections import vlm +from nemo.utils import logging + + +def load_image(image_url: str) -> Image.Image: + # pylint: disable=C0115,C0116 + try: + response = requests.get(image_url, stream=True) + response.raise_for_status() + image = Image.open(response.raw) + return image + except requests.exceptions.RequestException as e: + print(f"Error loading image from {image_url}: {e}") + return None + + +def generate(model, processor, raw_image, text): + # pylint: disable=C0115,C0116 + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "What are these?"}, + {"type": "image"}, + ], + } + ] + + input_text = processor.apply_chat_template(messages, add_generation_prompt=True) + + input_ids = processor.tokenizer(input_text, return_tensors='pt').input_ids.cuda() + inputs = processor(input_text, raw_image, return_tensors='pt').to(0, torch.float32) + + input_ids[input_ids == 32000] = -200 + media = inputs['pixel_values'].cuda() + media = media.reshape(media.size(1), 3, 336, 336) + position_ids = ( + torch.arange(input_ids.size(1), dtype=torch.long, device=input_ids.device).unsqueeze(0).expand_as(input_ids) + ) + + generated_ids = input_ids.clone() + width, height = raw_image.size + image_sizes = torch.tensor([[height, width]], dtype=torch.long).cuda() + + for _ in range(20): + with torch.no_grad(): + attention_mask = (input_ids != 0).long().cuda() + output = model( + media=media, + input_ids=input_ids, + position_ids=position_ids, + image_sizes=image_sizes, + num_media_tiles=[media.size(0)], + attention_mask=attention_mask, + ) + next_token_ids = torch.argmax(output[:, -1], dim=-1, keepdim=True) + + generated_ids = torch.cat([generated_ids, next_token_ids], dim=-1) + + input_ids = generated_ids + position_ids = ( + torch.arange(input_ids.size(1), dtype=torch.long, device=input_ids.device) + .unsqueeze(0) + .expand_as(input_ids) + ) + print(f"next_token_ids {next_token_ids}") + + # If the generated token is the end of sequence token, stop generating + if next_token_ids.item() == processor.tokenizer.eos_token_id: + print(f"breaking") + break + generated_ids[generated_ids == -200] = 0 + generated_texts = processor.tokenizer.batch_decode(generated_ids, skip_special_tokens=False) + logging.info("======== GENERATED TEXT OUTPUT ========") + logging.info(f"{generated_texts}") + logging.info("=======================================") + + +def main(args) -> None: + # pylint: disable=C0115,C0116 + model_id = 'llava-hf/llava-v1.6-vicuna-7b-hf' + strategy = nl.MegatronStrategy( + tensor_model_parallel_size=args.tp_size, + ckpt_load_optimizer=False, + ckpt_save_optimizer=False, + ) + trainer = nl.Trainer( + devices=args.tp_size, + max_steps=1000, + accelerator="gpu", + strategy=strategy, + plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"), + val_check_interval=1000, + limit_val_batches=50, + ) + + processor = AutoProcessor.from_pretrained(model_id) + tokenizer = processor.tokenizer + + fabric = trainer.to_fabric() + + if args.load_from_hf: + model = fabric.import_model("hf://llava-hf/llava-v1.6-vicuna-7b-hf", vlm.LlavaNextModel) + else: + model = vlm.LlavaNextModel(vlm.LlavaNextConfig7B(), tokenizer=tokenizer) + model = fabric.load_model(args.local_model_path, model) + + model = model.module.cuda() + model.eval() + model = model.to(torch.bfloat16) + + # Load the image + raw_image = load_image(args.image_url) + if raw_image is None: + return # Exit if the image can't be loaded + + generate(model, processor, raw_image=raw_image, text="What are these?") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Llava Next Generation example") + parser.add_argument( + "--load_from_hf", + action="store_true", + help="Flag to indicate whether to load the model from Hugging Face hub.", + ) + parser.add_argument( + "--local_model_path", + type=str, + default=None, + help="Local path to the model if not loading from Hugging Face.", + ) + parser.add_argument( + "--image_url", + type=str, + # pylint: disable=line-too-long + default="http://images.cocodataset.org/val2017/000000039769.jpg", + help="URL of the image to use for inference.", + ) + parser.add_argument("--devices", type=int, required=False, default=1) + parser.add_argument("--tp_size", type=int, required=False, default=1) + + args = parser.parse_args() + main(args) diff --git a/scripts/vlm/llava_next_nemo_run.py b/scripts/vlm/llava_next_nemo_run.py new file mode 100644 index 000000000000..3193b05e10fc --- /dev/null +++ b/scripts/vlm/llava_next_nemo_run.py @@ -0,0 +1,71 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import nemo_run as run + +from nemo.collections import vlm + + +def configure_recipe(nodes: int = 1, gpus_per_node: int = 8, pretrain=False): + """Configure the recipe""" + if pretrain: + recipe = vlm.llava_next_7b.pretrain_recipe( + dir="./outputs/checkpoints/llava", # Path to store checkpoints + name="llava_pretrain", + num_nodes=nodes, + num_gpus_per_node=gpus_per_node, + ) + else: + recipe = vlm.llava_next_7b.finetune_recipe( + dir="./outputs/checkpoints/llava", # Path to store checkpoints + name="llava_finetune", + num_nodes=nodes, + num_gpus_per_node=gpus_per_node, + ) + recipe.trainer.max_steps = 100 + recipe.trainer.val_check_interval = 100 + recipe.model.config.freeze_vision_model = True + return recipe + + +def local_executor_torchrun(nodes: int = 1, devices: int = 8) -> run.LocalExecutor: + # pylint: disable=C0115,C0116 + # Env vars for jobs are configured here + env_vars = {} + + executor = run.LocalExecutor(ntasks_per_node=devices, launcher="torchrun", env_vars=env_vars) + + return executor + + +def run_pretraining(): + # pylint: disable=C0115,C0116 + recipe = configure_recipe(pretrain=True) + executor = local_executor_torchrun(nodes=recipe.trainer.num_nodes, devices=recipe.trainer.devices) + + run.run(recipe, executor=executor) + + +def run_finetuning(): + # pylint: disable=C0115,C0116 + recipe = configure_recipe(pretrain=False) + executor = local_executor_torchrun(nodes=recipe.trainer.num_nodes, devices=recipe.trainer.devices) + + run.run(recipe, executor=executor) + + +# This condition is necessary for the script to be compatible with Python's multiprocessing module. +if __name__ == "__main__": + run_pretraining() + # run_finetuning() diff --git a/scripts/vlm/llava_next_pretrain.py b/scripts/vlm/llava_next_pretrain.py new file mode 100644 index 000000000000..bb84e3dae1e5 --- /dev/null +++ b/scripts/vlm/llava_next_pretrain.py @@ -0,0 +1,223 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Example: + torchrun --nproc_per_node=8 scripts/vlm/llava_next_pretrain.py \ + --devices=8 --tp=4 --data_type=mock + + torchrun --nproc_per_node=8 scripts/vlm/llava_next_pretrain.py \ + --devices=8 --tp=4 --data_type=energon --data_path='' \ + --language_model_path=/root/.cache/nemo/models/lmsys/vicuna-7b-v1.5 +""" + +import argparse + +import torch +from megatron.core.optimizer import OptimizerConfig +from pytorch_lightning.loggers import WandbLogger + +from nemo import lightning as nl +from nemo.collections import llm, vlm +from nemo.lightning.pytorch.optim import CosineAnnealingScheduler +from nemo.lightning.pytorch.optim.megatron import MegatronOptimizerModule +from nemo.utils.exp_manager import TimingCallback + + +def main(args): + # pylint: disable=C0115,C0116 + + # Global and micro batch sizes + gbs = args.gbs + mbs = args.mbs + max_steps = args.max_steps + + decoder_seq_length = 4096 + + if args.data_type == "energon": + from transformers import AutoProcessor + + from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer + from nemo.collections.multimodal.data.energon import SimpleMultiModalDataModule + from nemo.collections.multimodal.data.energon.config import MultiModalSampleConfig + from nemo.collections.vlm import LlavaNextTaskEncoder + + data_path = args.data_path + model_id = "llava-hf/llava-v1.6-vicuna-7b-hf" + processor = AutoProcessor.from_pretrained(model_id) + tokenizer = AutoTokenizer(model_id) + + multimodal_sample_config = MultiModalSampleConfig() + # Setting system prompt to empty string + multimodal_sample_config.conversation_template_config.system = '' + + task_encoder = LlavaNextTaskEncoder( + tokenizer=tokenizer.tokenizer, + image_processor=processor.image_processor, + multimodal_sample_config=multimodal_sample_config, + ) + data = SimpleMultiModalDataModule( + path=data_path, + tokenizer=tokenizer, + image_processor=processor.image_processor, + num_workers=32, + micro_batch_size=mbs, + global_batch_size=gbs, + multimodal_sample_config=multimodal_sample_config, + task_encoder=task_encoder, + ) + + elif args.data_type == "mock": + data = vlm.LlavaNextMockDataModule( + seq_length=decoder_seq_length, + global_batch_size=gbs, + micro_batch_size=mbs, + tokenizer=None, + image_processor=None, + num_workers=4, + ) + else: + raise ValueError(f"Data type {args.data_type} not supported") + + # Submodules configurations + language_transformer_config = llm.Llama2Config7B( + seq_length=decoder_seq_length, + ) + vision_transformer_config = vlm.HFCLIPVisionConfig( + pretrained_model_name_or_path="openai/clip-vit-large-patch14-336" + ) + vision_projection_config = vlm.MultimodalProjectorConfig( + projector_type=args.projector_type, + input_size=vision_transformer_config.hidden_size, + hidden_size=language_transformer_config.hidden_size, + ffn_hidden_size=language_transformer_config.hidden_size, + ) + + # Llava Next model configuration + llava_next_config = vlm.LlavaNextConfig( + language_transformer_config=language_transformer_config, + vision_transformer_config=vision_transformer_config, + vision_projection_config=vision_projection_config, + language_model_from_pretrained=args.language_model_path, + freeze_language_model=True, + freeze_vision_model=True, + ) + + model = vlm.LlavaNextModel(llava_next_config, tokenizer=data.tokenizer) + + # Training strategy setup + strategy = nl.MegatronStrategy( + tensor_model_parallel_size=args.tp_size, + pipeline_model_parallel_size=args.pp_size, + encoder_pipeline_model_parallel_size=args.encoder_pp_size, + pipeline_dtype=torch.bfloat16, + sequence_parallel=False, + ) + + # Checkpoint callback setup + checkpoint_callback = nl.ModelCheckpoint( + save_last=True, + monitor="reduced_train_loss", + save_top_k=2, + every_n_train_steps=1000, + dirpath=args.log_dir, + ) + + # Trainer setup + trainer = nl.Trainer( + num_nodes=args.num_nodes, + devices=args.devices, + max_steps=max_steps, + accelerator="gpu", + strategy=strategy, + plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"), + callbacks=[checkpoint_callback, TimingCallback()], + val_check_interval=500, + limit_val_batches=gbs, + log_every_n_steps=1, + num_sanity_val_steps=0, + ) + + # Logger setup + nemo_logger = nl.NeMoLogger( + log_dir=args.log_dir, + name=args.name, + wandb=WandbLogger(project=args.wandb_project, name=args.name) if args.wandb_project is not None else None, + ) + + # Auto resume setup + resume = nl.AutoResume( + resume_if_exists=True, + resume_ignore_no_checkpoint=True, + resume_from_directory=args.log_dir, + restore_config=nl.RestoreConfig(path=args.restore_path) if args.restore_path is not None else None, + ) + + # Optimizer and scheduler setup + opt_config = OptimizerConfig( + optimizer='adam', + lr=args.lr, + adam_beta1=0.9, + adam_beta2=0.95, + use_distributed_optimizer=True, + bf16=True, + ) + sched = CosineAnnealingScheduler( + max_steps=trainer.max_steps, + warmup_steps=150, + constant_steps=0, + min_lr=2.0e-05, + ) + opt = MegatronOptimizerModule(opt_config, sched) + + llm.pretrain( + model=model, + data=data, + trainer=trainer, + log=nemo_logger, + optim=opt, + resume=resume, + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Llava Next Pretraining Script") + + # Argument parsing + parser.add_argument("--data_type", type=str, required=False, default="mock", help="mock | energon") + parser.add_argument("--data_path", type=str, required=False, default=None, help="Path to the dataset JSON file") + parser.add_argument( + "--log_dir", type=str, required=False, default="/results", help="Directory for logging and checkpoints" + ) + parser.add_argument( + "--language_model_path", type=str, required=False, default=None, help="Path to the pretrained language model" + ) + parser.add_argument( + "--restore_path", type=str, required=False, default=None, help="Path to restore model from checkpoint" + ) + parser.add_argument("--devices", type=int, required=False, default=1) + parser.add_argument("--num_nodes", type=int, required=False, default=1) + parser.add_argument("--max_steps", type=int, required=False, default=2100) + parser.add_argument("--tp_size", type=int, required=False, default=2) + parser.add_argument("--pp_size", type=int, required=False, default=1) + parser.add_argument("--encoder_pp_size", type=int, required=False, default=0) + parser.add_argument("--projector_type", type=str, required=False, default="mlp2x_gelu") + parser.add_argument("--name", type=str, required=False, default="llava_next_pretrain") + parser.add_argument("--wandb_project", type=str, required=False, default=None) + parser.add_argument("--gbs", type=int, required=False, default=32, help="Global batch size") + parser.add_argument("--mbs", type=int, required=False, default=4, help="Micro batch size") + parser.add_argument("--lr", type=float, required=False, default=0.001, help="Learning rate") + + args = parser.parse_args() + main(args) diff --git a/scripts/vlm/mllama_finetune.py b/scripts/vlm/mllama_finetune.py new file mode 100644 index 000000000000..2b6990a03aa5 --- /dev/null +++ b/scripts/vlm/mllama_finetune.py @@ -0,0 +1,212 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse + +import torch +from megatron.core.optimizer import OptimizerConfig +from pytorch_lightning.loggers import WandbLogger +from transformers import AutoProcessor + +from nemo import lightning as nl +from nemo.collections import llm, vlm +from nemo.collections.vlm import ImageDataConfig +from nemo.collections.vlm.mllama.data.lazy import MLlamaLazyDataModule +from nemo.lightning.pytorch.optim import CosineAnnealingScheduler +from nemo.lightning.pytorch.optim.megatron import MegatronOptimizerModule +from nemo.utils.exp_manager import TimingCallback + + +def main(args): + """ + Main function for setting up and training the MLLama model. + + This function prepares the data module, model, training strategy, + logger, checkpointing, and optimizer configuration. It then starts + the training loop using PyTorch Lightning's trainer. + + Args: + args (argparse.Namespace): The command-line arguments passed to the script. + """ + # Setting gbs, mbs, and max_steps from arguments + gbs = args.gbs + mbs = args.mbs + max_steps = args.max_steps + + # encoder (vision) seq length + # ((img_res / patch_size) ** 2 + cls_token) * num_tiles, = ((560 / 14) ** 2 + 1) * 4 = 6404 + seq_length = 6404 + decoder_seq_length = 1024 # decoder (llm) seq length + + if args.restore_path is not None and args.restore_path.startswith("nemo://"): + model_id = args.restore_path[len("nemo://") :] + else: + model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct" + + processor = AutoProcessor.from_pretrained(model_id) + image_processor = processor.image_processor + tokenizer = processor.tokenizer + + # Data configuration + data_config = ImageDataConfig( + image_folder=args.image_folder, + conv_template="mllama", + ) + + # Data module setup + data = MLlamaLazyDataModule( + paths=args.data_path, + data_config=data_config, + seq_length=seq_length, + decoder_seq_length=decoder_seq_length, + global_batch_size=gbs, + micro_batch_size=mbs, + tokenizer=tokenizer, + image_processor=image_processor, + num_workers=16, + ) + + model_configs = { + "meta-llama/Llama-3.2-11B-Vision": vlm.MLlamaConfig11B, + "meta-llama/Llama-3.2-11B-Vision-Instruct": vlm.MLlamaConfig11BInstruct, + "meta-llama/Llama-3.2-90B-Vision": vlm.MLlamaConfig90B, + "meta-llama/Llama-3.2-90B-Vision-Instruct": vlm.MLlamaConfig90BInstruct, + } + conf = model_configs[model_id]() + if args.pp_size > 1: + conf.language_model_config.first_pipeline_num_layers = 0 + model = vlm.MLlamaModel(conf, tokenizer=tokenizer) + + # Training strategy setup + strategy = nl.MegatronStrategy( + tensor_model_parallel_size=args.tp_size, + pipeline_model_parallel_size=args.pp_size, + encoder_pipeline_model_parallel_size=args.encoder_pp_size, + pipeline_dtype=torch.bfloat16, + ) + + # Checkpoint callback setup + checkpoint_callback = nl.ModelCheckpoint( + save_last=True, + monitor="reduced_train_loss", + save_top_k=6, + every_n_train_steps=100, + dirpath=args.log_dir, + ) + + # Trainer setup + trainer = nl.Trainer( + num_nodes=args.num_nodes, + devices=args.devices, + max_steps=max_steps, + accelerator="gpu", + strategy=strategy, + plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"), + callbacks=[checkpoint_callback, TimingCallback()], + val_check_interval=500, + limit_val_batches=gbs, + log_every_n_steps=1, + num_sanity_val_steps=0, + ) + + # Logger setup + nemo_logger = nl.NeMoLogger( + log_dir=args.log_dir, + name=args.name, + wandb=WandbLogger(project=args.wandb_project, name=args.name) if args.wandb_project is not None else None, + ) + + # Auto resume setup + resume = nl.AutoResume( + resume_if_exists=True, + resume_ignore_no_checkpoint=True, + resume_from_directory=args.log_dir, + restore_config=nl.RestoreConfig(path=args.restore_path) if args.restore_path is not None else None, + ) + + # Optimizer and scheduler setup + opt_config = OptimizerConfig( + optimizer='adam', + lr=args.lr, + adam_beta1=0.9, + adam_beta2=0.95, + use_distributed_optimizer=True, + bf16=True, + ) + sched = CosineAnnealingScheduler( + max_steps=trainer.max_steps, + warmup_steps=100, + constant_steps=0, + min_lr=args.lr, + ) + opt = MegatronOptimizerModule(opt_config, sched) + + # PEFT setup + if args.peft == 'lora': + peft = vlm.peft.LoRA( + freeze_vision_model=True, + target_modules=[ + "linear_qkv", + "linear_q", + "linear_kv", + ], + dim=8, + alpha=32, + dropout=0.05, + dropout_position="pre", + ) + else: + peft = None + + llm.finetune( + model=model, + data=data, + trainer=trainer, + peft=peft, + log=nemo_logger, + optim=opt, + resume=resume, + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Mllama Model Training Script") + + parser.add_argument( + "--restore_path", type=str, required=False, default=None, help="Path to restore model from checkpoint" + ) + parser.add_argument("--data_path", type=str, required=True, help="Path to the dataset") + parser.add_argument("--image_folder", type=str, required=True, help="Path to the image folder") + parser.add_argument( + "--log_dir", + type=str, + required=False, + default="/results", + help="Directory for logging and checkpoints", + ) + parser.add_argument("--devices", type=int, required=False, default=1) + parser.add_argument("--num_nodes", type=int, required=False, default=1) + parser.add_argument("--max_steps", type=int, required=False, default=5190) + parser.add_argument("--tp_size", type=int, required=False, default=1) + parser.add_argument("--pp_size", type=int, required=False, default=1) + parser.add_argument("--encoder_pp_size", type=int, required=False, default=0) + parser.add_argument("--name", type=str, required=False, default="neva_pretrain") + parser.add_argument("--peft", type=str, default='none', help="none | lora") + parser.add_argument("--wandb_project", type=str, required=False, default=None) + parser.add_argument("--gbs", type=int, required=False, default=64, help="Global batch size") + parser.add_argument("--mbs", type=int, required=False, default=2, help="Micro batch size") + parser.add_argument("--lr", type=float, required=False, default=2.0e-06, help="Learning rate") + + args = parser.parse_args() + main(args) diff --git a/scripts/vlm/mllama_generation.py b/scripts/vlm/mllama_generation.py new file mode 100644 index 000000000000..4ebf2d0055ad --- /dev/null +++ b/scripts/vlm/mllama_generation.py @@ -0,0 +1,164 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse + +import requests +import torch +from PIL import Image +from transformers import AutoProcessor + +from nemo import lightning as nl +from nemo.collections import vlm +from nemo.collections.vlm.mllama.model.utils import create_vision_mask_tensor + +model_id = "meta-llama/Llama-3.2-11B-Vision-Instruct" + + +def load_image(image_url: str) -> Image.Image: + # pylint: disable=C0115,C0116 + try: + response = requests.get(image_url, stream=True) + response.raise_for_status() + image = Image.open(response.raw) + return image + except requests.exceptions.RequestException as e: + print(f"Error loading image from {image_url}: {e}") + return None + + +def generate(model, processor, image, text): + # pylint: disable=C0115,C0116 + tokenizer = processor.tokenizer + + messages = [ + { + "role": "user", + "content": [{"type": "text", "text": text}], + } + ] + input_text = processor.apply_chat_template(messages, add_generation_prompt=True) + batch = processor(image, input_text, add_special_tokens=False, return_tensors="pt") + + input_ids = batch["input_ids"].cuda(non_blocking=True) + position_ids = ( + torch.arange(input_ids.size(1), dtype=torch.long, device=input_ids.device).unsqueeze(0).expand_as(input_ids) + ) + num_tiles = processor.image_processor.preprocess(image, return_tensors='pt')["num_tiles"] + + min_prompt_len = position_ids.shape[-1] + + input_ids = input_ids[:, :min_prompt_len] + generated_ids = input_ids.clone() + + from tqdm import tqdm + + for cur_pos in tqdm(range(min_prompt_len, min_prompt_len + 100)): + with torch.no_grad(): + position_ids = torch.arange(0, cur_pos, dtype=torch.long, device="cuda").reshape(1, -1) + batch_masks = create_vision_mask_tensor(generated_ids[0]) + + output = model( + batch_images=batch["pixel_values"].cuda(non_blocking=True), + batch_masks=[batch_masks], + num_chunks=torch.tensor(num_tiles), + aspect_ratio_ids=batch["aspect_ratio_ids"].cuda(non_blocking=True), + tokens=generated_ids, + position_ids=position_ids, + ) + + next_token_ids = torch.argmax(output[:, -1], dim=-1, keepdim=True) + # Broadcast the tensor from rank 0 to all other ranks + torch.distributed.broadcast(next_token_ids, src=0) + generated_ids = torch.cat([generated_ids, next_token_ids], dim=-1) + if (next_token_ids == tokenizer.eos_token_id).all(): + break + + generated_ids = generated_ids.tolist() + generated_texts = tokenizer.decode(generated_ids[0][min_prompt_len:]) + + if torch.distributed.get_rank() == 0: + print("======== GENERATED TEXT OUTPUT ========") + print(f"{generated_texts}") + print("=======================================") + return generated_texts + + +def main(args) -> None: + # pylint: disable=C0115,C0116 + strategy = nl.MegatronStrategy( + tensor_model_parallel_size=args.tp_size, + ckpt_load_optimizer=False, + ckpt_save_optimizer=False, + ) + trainer = nl.Trainer( + devices=args.tp_size, + max_steps=1000, + accelerator="gpu", + strategy=strategy, + plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"), + val_check_interval=1000, + limit_val_batches=50, + ) + + processor = AutoProcessor.from_pretrained(model_id) + tokenizer = processor.tokenizer + + fabric = trainer.to_fabric() + + if args.load_from_hf: + model = fabric.import_model(f"hf://{model_id}", vlm.MLlamaModel) + else: + model = vlm.MLlamaModel(vlm.MLlamaConfig11BInstruct(), tokenizer=tokenizer) + model = fabric.load_model(args.local_model_path, model) + + model = model.module.cuda() + model.eval() + model = model.to(torch.bfloat16) + + # Load the image + raw_image = load_image(args.image_url) + if raw_image is None: + return # Exit if the image can't be loaded + + generate(model, processor, image=raw_image, text="<|image|>\nDescribe the image.") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="") + parser.add_argument( + "--load_from_hf", + action="store_true", + help="Flag to indicate whether to load the model from Hugging Face hub.", + ) + parser.add_argument( + "--local_model_path", + type=str, + default=None, + help="Local path to the model if not loading from Hugging Face.", + ) + parser.add_argument( + "--image_url", + type=str, + # pylint: disable=line-too-long + default="https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg", + help="URL of the image to use for inference.", + ) + parser.add_argument("--devices", type=int, required=False, default=1) + parser.add_argument("--tp_size", type=int, required=False, default=1) + parser.add_argument("--pp_size", type=int, required=False, default=1) + parser.add_argument("--encoder_pp_size", type=int, required=False, default=0) + + args = parser.parse_args() + main(args) diff --git a/tests/collections/asr/confidence/test_asr_confidence.py b/tests/collections/asr/confidence/test_asr_confidence.py index 015264a9debe..89beb61f50bf 100644 --- a/tests/collections/asr/confidence/test_asr_confidence.py +++ b/tests/collections/asr/confidence/test_asr_confidence.py @@ -19,8 +19,8 @@ import numpy as np import pytest +from lightning.pytorch import Trainer from omegaconf import OmegaConf -from pytorch_lightning import Trainer from nemo.collections.asr.models import ASRModel, EncDecCTCModelBPE, EncDecRNNTBPEModel from nemo.collections.asr.parts.submodules.ctc_decoding import CTCDecodingConfig diff --git a/tests/collections/asr/numba/rnnt_loss/test_rnnt_pytorch.py b/tests/collections/asr/numba/rnnt_loss/test_rnnt_pytorch.py index 1a29a14f540d..9f38bf6dbe8a 100644 --- a/tests/collections/asr/numba/rnnt_loss/test_rnnt_pytorch.py +++ b/tests/collections/asr/numba/rnnt_loss/test_rnnt_pytorch.py @@ -33,6 +33,7 @@ if torch.cuda.is_available(): DEVICES.append('cuda') +CUDA_ONLY_DEVICE = ['cuda'] DTYPES = [np.float32] if numba_utils.is_numba_cuda_fp16_supported(): @@ -542,65 +543,86 @@ def test_case_randomized_act_label(self, device): class TestTDTLoss: @pytest.mark.unit - @pytest.mark.parametrize('device', DEVICES) + @pytest.mark.parametrize('device', CUDA_ONLY_DEVICE) def test_case_randomized_act_label(self, device): - if device == 'cuda': - numba_utils.skip_numba_cuda_test_if_unsupported(__NUMBA_MINIMUM_VERSION__) + numba_utils.skip_numba_cuda_test_if_unsupported(__NUMBA_MINIMUM_VERSION__) - B, T, U, V = 4, 8, 4, 8 # here V is number of non blank labels - durations = [0, 1, 2, 3, 4, 5] - sigma = 0.05 + B, T, U, V = 4, 8, 4, 8 # here V is number of non blank labels + durations = [0, 1, 2, 3, 4, 5] + sigma = 0.05 - acts = torch.rand([B, T, U, V + 1 + len(durations)]) - labels = [[random.randrange(0, V) for i in range(U - 1)] for j in range(B)] + acts = torch.rand([B, T, U, V + 1 + len(durations)]) + labels = [[random.randrange(0, V) for i in range(U - 1)] for j in range(B)] - fn_pt = TDTLossNumba(blank=V, reduction='sum', durations=durations, sigma=sigma) - pt_cost, pt_grads = wrap_and_call(fn_pt, acts, labels, device) + fn_pt = TDTLossNumba(blank=V, reduction='sum', durations=durations, sigma=sigma) + pt_cost, pt_grads = wrap_and_call(fn_pt, acts, labels, device) - fn_ag = TDTLossPytorch( - blank=V, reduction='sum', durations=durations, sigma=sigma - ) # ag for automatic gradient computation - ag_cost, ag_grads = wrap_and_call(fn_ag, acts, labels, device) + fn_ag = TDTLossPytorch( + blank=V, reduction='sum', durations=durations, sigma=sigma + ) # ag for automatic gradient computation + ag_cost, ag_grads = wrap_and_call(fn_ag, acts, labels, device) - assert np.allclose(pt_cost, ag_cost, rtol=1e-6), "tdt costs mismatch." - assert np.allclose(pt_grads, ag_grads, rtol=1e-2), "td gradient mismatch." + assert np.allclose(pt_cost, ag_cost, rtol=1e-6), "tdt costs mismatch." + assert np.allclose(pt_grads, ag_grads, rtol=1e-2), "td gradient mismatch." @pytest.mark.unit - @pytest.mark.parametrize('device', DEVICES) + @pytest.mark.parametrize('device', CUDA_ONLY_DEVICE) + def test_case_randomized_act_label_no_0_duration(self, device): + numba_utils.skip_numba_cuda_test_if_unsupported(__NUMBA_MINIMUM_VERSION__) + + B, T, U, V = 4, 8, 4, 8 # here V is number of non blank labels + durations = [1, 2, 3, 4, 5] + sigma = 0.05 + + acts = torch.rand([B, T, U, V + 1 + len(durations)]) + labels = [[random.randrange(0, V) for i in range(U - 1)] for j in range(B)] + + fn_pt = TDTLossNumba(blank=V, reduction='sum', durations=durations, sigma=sigma) + pt_cost, pt_grads = wrap_and_call(fn_pt, acts, labels, device) + + fn_ag = TDTLossPytorch( + blank=V, reduction='sum', durations=durations, sigma=sigma + ) # ag for automatic gradient computation + ag_cost, ag_grads = wrap_and_call(fn_ag, acts, labels, device) + + assert np.allclose(pt_cost, ag_cost, rtol=1e-6), "tdt costs mismatch." + assert np.allclose(pt_grads, ag_grads, rtol=1e-2), "td gradient mismatch." + + @pytest.mark.unit + @pytest.mark.parametrize('device', CUDA_ONLY_DEVICE) def test_case_fixed_case_act_label(self, device): - if device == 'cuda': - numba_utils.skip_numba_cuda_test_if_unsupported(__NUMBA_MINIMUM_VERSION__) + numba_utils.skip_numba_cuda_test_if_unsupported(__NUMBA_MINIMUM_VERSION__) - B, T, U, V = 1, 3, 2, 3 # here V is number of non blank labels - durations = [0, 1, 2] - sigma = 0.05 + B, T, U, V = 1, 3, 2, 3 # here V is number of non blank labels + durations = [0, 1, 2] + sigma = 0.05 - acts = torch.zeros([B, T, U, V + 1 + len(durations)]) - labels = [[(i + j) % (V - 1) for i in range(U - 1)] for j in range(B)] + acts = torch.zeros([B, T, U, V + 1 + len(durations)]) + labels = [[(i + j) % (V - 1) for i in range(U - 1)] for j in range(B)] - fn_pt = TDTLossNumba(blank=V, reduction='sum', durations=durations, sigma=sigma) - pt_cost, pt_grads = wrap_and_call(fn_pt, acts, labels, device) + fn_pt = TDTLossNumba(blank=V, reduction='sum', durations=durations, sigma=sigma) + pt_cost, pt_grads = wrap_and_call(fn_pt, acts, labels, device) - expected_cost = 4.155739 - expected_grads = [ + expected_cost = 4.155739 + expected_grads = [ + [ [ - [ - [-0.64962804, 0.25, 0.25, 0.14962798, 0.2672583, -0.16792619, -0.09933221], - [0.01651875, 0.01651875, 0.01651875, -0.04955626, 0.022025, -0.01227201, -0.009753], - ], - [ - [-0.04892651, 0.01714851, 0.01714851, 0.01462949, -0.01143234, -0.01143234, 0.02286467], - [0.12531489, 0.12531489, 0.12531489, -0.37594467, 0.16708651, 0.13027048, -0.29735702], - ], - [ - [-0.02572276, 0.00857425, 0.00857425, 0.00857425, -0.02286468, 0.01143234, 0.01143234], - [0.13388914, 0.13388914, 0.13388914, -0.40166742, 0.17851885, -0.35703772, 0.17851885], - ], - ] + [-0.64962804, 0.25, 0.25, 0.14962798, 0.2672583, -0.16792619, -0.09933221], + [0.01651875, 0.01651875, 0.01651875, -0.04955626, 0.022025, -0.01227201, -0.009753], + ], + [ + [-0.04892651, 0.01714851, 0.01714851, 0.01462949, -0.01143234, -0.01143234, 0.02286467], + [0.12531489, 0.12531489, 0.12531489, -0.37594467, 0.16708651, 0.13027048, -0.29735702], + ], + [ + [-0.02572276, 0.00857425, 0.00857425, 0.00857425, -0.02286468, 0.01143234, 0.01143234], + [0.13388914, 0.13388914, 0.13388914, -0.40166742, 0.17851885, -0.35703772, 0.17851885], + ], ] + ] - assert np.allclose(pt_cost, expected_cost, rtol=1e-6), "tdt costs mismatch." - assert np.allclose(pt_grads, expected_grads, rtol=1e-2), "td gradient mismatch." + assert np.allclose(pt_cost, expected_cost, rtol=1e-6), "tdt costs mismatch." + assert np.allclose(pt_grads, expected_grads, rtol=1e-2), "td gradient mismatch." if __name__ == "__main__": diff --git a/tests/collections/asr/test_asr_context_biasing.py b/tests/collections/asr/test_asr_context_biasing.py index 0fa76fdfb95d..b23b12655a8d 100644 --- a/tests/collections/asr/test_asr_context_biasing.py +++ b/tests/collections/asr/test_asr_context_biasing.py @@ -19,7 +19,7 @@ import numpy as np import pytest import torch -from pytorch_lightning import Trainer +from lightning.pytorch import Trainer from nemo.collections.asr.models import EncDecCTCModelBPE from nemo.collections.asr.parts import context_biasing @@ -105,25 +105,43 @@ def test_merge_alignment_with_ws_hyps(self, conformer_ctc_bpe_model): # ctc argmax predictions preds = np.array([120, 29, blank_idx, blank_idx]) pred_text, raw_text = context_biasing.merge_alignment_with_ws_hyps( - preds, asr_model, ws_results, decoder_type="ctc", blank_idx=blank_idx, + preds, + asr_model, + ws_results, + decoder_type="ctc", + blank_idx=blank_idx, ) assert raw_text == "gp" assert pred_text == "gpu" # rnnt token predictions preds = rnnt_utils.Hypothesis( - y_sequence=torch.tensor([120, 29]), score=0.0, timestep=torch.tensor([0, 1, 2, 3]), + y_sequence=torch.tensor([120, 29]), + score=0.0, + timestep=torch.tensor([0, 1, 2, 3]), ) pred_text, raw_text = context_biasing.merge_alignment_with_ws_hyps( - preds, asr_model, ws_results, decoder_type="rnnt", blank_idx=blank_idx, + preds, + asr_model, + ws_results, + decoder_type="rnnt", + blank_idx=blank_idx, ) assert raw_text == "gp" assert pred_text == "gpu" # rnnt empty token predictions - preds = rnnt_utils.Hypothesis(y_sequence=[], score=0.0, timestep=[],) + preds = rnnt_utils.Hypothesis( + y_sequence=[], + score=0.0, + timestep=[], + ) pred_text, raw_text = context_biasing.merge_alignment_with_ws_hyps( - preds, asr_model, ws_results, decoder_type="rnnt", blank_idx=blank_idx, + preds, + asr_model, + ws_results, + decoder_type="rnnt", + blank_idx=blank_idx, ) assert raw_text == "" assert pred_text == "gpu" diff --git a/tests/collections/asr/test_asr_ctc_encoder_model_bpe.py b/tests/collections/asr/test_asr_ctc_encoder_model_bpe.py index 247906247091..02442291a918 100644 --- a/tests/collections/asr/test_asr_ctc_encoder_model_bpe.py +++ b/tests/collections/asr/test_asr_ctc_encoder_model_bpe.py @@ -19,9 +19,12 @@ import pytest import torch +from lhotse import CutSet, MonoCut +from lhotse.testing.dummies import DummyManifest from omegaconf import DictConfig from nemo.collections.asr.data import audio_to_text +from nemo.collections.asr.data.audio_to_text_lhotse import LhotseSpeechToTextBpeDataset from nemo.collections.asr.models import configs from nemo.collections.asr.models.ctc_bpe_models import EncDecCTCModelBPE from nemo.collections.asr.parts.submodules import ctc_beam_decoding as beam_decode @@ -118,6 +121,18 @@ def test_forward(self, asr_model): diff = torch.max(torch.abs(logprobs_instance - logprobs_batch)) assert diff <= 1e-6 + @pytest.mark.unit + def test_predict_step(self, asr_model): + asr_model = asr_model.eval() + cuts = DummyManifest(CutSet, begin_id=0, end_id=1, with_data=True) + dataset = LhotseSpeechToTextBpeDataset(tokenizer=asr_model.tokenizer, return_cuts=True) + batch = dataset[cuts] + outputs = asr_model.predict_step(batch, 0) + assert len(outputs) == 1 + assert len(outputs[0]) == 2 + assert isinstance(outputs[0][0], MonoCut) + assert isinstance(outputs[0][1], str) + @pytest.mark.with_downloads() @pytest.mark.unit def test_save_restore_artifact(self, asr_model): @@ -333,6 +348,15 @@ def test_ASRDatasetConfig_for_AudioToBPEDataset(self): 'bucketing_strategy', 'bucketing_weights', 'channel_selector', + 'use_lhotse', + 'tarred_random_access', + 'use_bucketing', + 'batch_duration', + 'quadratic_duration', + 'bucket_batch_size', + 'bucket_duration_bins', + 'num_buckets', + 'pin_memory', ] REMAP_ARGS = {'trim_silence': 'trim', 'labels': 'tokenizer'} @@ -372,6 +396,15 @@ def test_ASRDatasetConfig_for_TarredAudioToBPEDataset(self): 'bucketing_strategy', 'bucketing_weights', 'max_utts', + 'use_lhotse', + 'tarred_random_access', + 'use_bucketing', + 'batch_duration', + 'quadratic_duration', + 'bucket_batch_size', + 'bucket_duration_bins', + 'num_buckets', + 'pin_memory', ] REMAP_ARGS = { diff --git a/tests/collections/asr/test_asr_ctcencdec_model.py b/tests/collections/asr/test_asr_ctcencdec_model.py index 28a07fd54663..55451758578f 100644 --- a/tests/collections/asr/test_asr_ctcencdec_model.py +++ b/tests/collections/asr/test_asr_ctcencdec_model.py @@ -15,12 +15,16 @@ import pytest import torch +from lhotse import CutSet, MonoCut +from lhotse.testing.dummies import DummyManifest from omegaconf import DictConfig, OmegaConf, open_dict import nemo.collections.asr as nemo_asr from nemo.collections.asr.data import audio_to_text +from nemo.collections.asr.data.audio_to_text_lhotse import LhotseSpeechToTextBpeDataset from nemo.collections.asr.models import EncDecCTCModel, configs from nemo.collections.asr.parts.submodules.ctc_decoding import CTCDecoding, CTCDecodingConfig +from nemo.collections.common.parts.preprocessing.parsers import make_parser from nemo.utils.config_utils import assert_dataclass_signature_match, update_model_config @@ -131,6 +135,19 @@ def test_forward(self, asr_model): diff = torch.max(torch.abs(logprobs_instance - logprobs_batch)) assert diff <= 1e-6 + @pytest.mark.unit + def test_predict_step(self, asr_model): + token_list = [" ", "a", "b", "c"] + asr_model = asr_model.eval() + cuts = DummyManifest(CutSet, begin_id=0, end_id=1, with_data=True) + dataset = LhotseSpeechToTextBpeDataset(tokenizer=make_parser(labels=token_list), return_cuts=True) + batch = dataset[cuts] + outputs = asr_model.predict_step(batch, 0) + assert len(outputs) == 1 + assert len(outputs[0]) == 2 + assert isinstance(outputs[0][0], MonoCut) + assert isinstance(outputs[0][1], str) + @pytest.mark.unit def test_vocab_change(self, asr_model): old_vocab = copy.deepcopy(asr_model.decoder.vocabulary) @@ -274,6 +291,15 @@ def test_ASRDatasetConfig_for_AudioToCharDataset(self): 'bucketing_strategy', 'bucketing_weights', 'channel_selector', + 'use_lhotse', + 'tarred_random_access', + 'use_bucketing', + 'batch_duration', + 'quadratic_duration', + 'bucket_batch_size', + 'bucket_duration_bins', + 'num_buckets', + 'pin_memory', ] REMAP_ARGS = {'trim_silence': 'trim'} @@ -307,6 +333,15 @@ def test_ASRDatasetConfig_for_TarredAudioToCharDataset(self): 'bucketing_strategy', 'bucketing_weights', 'max_utts', + 'use_lhotse', + 'tarred_random_access', + 'use_bucketing', + 'batch_duration', + 'quadratic_duration', + 'bucket_batch_size', + 'bucket_duration_bins', + 'num_buckets', + 'pin_memory', ] REMAP_ARGS = { diff --git a/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_bpe.py b/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_bpe.py index 1743acc6878c..d13c879e47f9 100644 --- a/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_bpe.py +++ b/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_bpe.py @@ -18,8 +18,11 @@ import pytest import torch +from lhotse import CutSet, MonoCut +from lhotse.testing.dummies import DummyManifest from omegaconf import DictConfig +from nemo.collections.asr.data.audio_to_text_lhotse import LhotseSpeechToTextBpeDataset from nemo.collections.asr.models.hybrid_rnnt_ctc_bpe_models import EncDecHybridRNNTCTCBPEModel from nemo.collections.asr.parts.submodules import rnnt_beam_decoding as beam_decode from nemo.collections.asr.parts.submodules import rnnt_greedy_decoding as greedy_decode @@ -166,6 +169,18 @@ def test_forward(self, hybrid_asr_model): diff = torch.max(torch.abs(logits_instance - logprobs_batch)) assert diff <= 1e-6 + @pytest.mark.unit + def test_predict_step(self, hybrid_asr_model): + hybrid_asr_model = hybrid_asr_model.eval() + cuts = DummyManifest(CutSet, begin_id=0, end_id=1, with_data=True) + dataset = LhotseSpeechToTextBpeDataset(tokenizer=hybrid_asr_model.tokenizer, return_cuts=True) + batch = dataset[cuts] + outputs = hybrid_asr_model.predict_step(batch, 0) + assert len(outputs) == 1 + assert len(outputs[0]) == 2 + assert isinstance(outputs[0][0], MonoCut) + assert isinstance(outputs[0][1], str) + @pytest.mark.with_downloads() @pytest.mark.skipif( not NUMBA_RNNT_LOSS_AVAILABLE, diff --git a/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_char.py b/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_char.py index 5362966e2e9e..b5c34e197237 100644 --- a/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_char.py +++ b/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_char.py @@ -16,14 +16,18 @@ import pytest import torch +from lhotse import CutSet, MonoCut +from lhotse.testing.dummies import DummyManifest from omegaconf import DictConfig, ListConfig +from nemo.collections.asr.data.audio_to_text_lhotse import LhotseSpeechToTextBpeDataset from nemo.collections.asr.models import EncDecHybridRNNTCTCModel from nemo.collections.asr.modules import RNNTDecoder, RNNTJoint, SampledRNNTJoint, StatelessTransducerDecoder from nemo.collections.asr.parts.submodules import rnnt_beam_decoding as beam_decode from nemo.collections.asr.parts.submodules import rnnt_greedy_decoding as greedy_decode from nemo.collections.asr.parts.submodules.ctc_decoding import CTCDecoding, CTCDecodingConfig from nemo.collections.asr.parts.utils import rnnt_utils +from nemo.collections.common.parts.preprocessing.parsers import make_parser from nemo.core.utils import numba_utils from nemo.core.utils.numba_utils import __NUMBA_MINIMUM_VERSION__ from nemo.utils.config_utils import assert_dataclass_signature_match @@ -164,6 +168,19 @@ def test_forward(self, hybrid_asr_model): diff = torch.max(torch.abs(logprobs_instance - logprobs_batch)) assert diff <= 1e-6 + @pytest.mark.unit + def test_predict_step(self, hybrid_asr_model): + token_list = [" ", "a", "b", "c"] + hybrid_asr_model = hybrid_asr_model.eval() + cuts = DummyManifest(CutSet, begin_id=0, end_id=1, with_data=True) + dataset = LhotseSpeechToTextBpeDataset(tokenizer=make_parser(labels=token_list), return_cuts=True) + batch = dataset[cuts] + outputs = hybrid_asr_model.predict_step(batch, 0) + assert len(outputs) == 1 + assert len(outputs[0]) == 2 + assert isinstance(outputs[0][0], MonoCut) + assert isinstance(outputs[0][1], str) + @pytest.mark.skipif( not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', diff --git a/tests/collections/asr/test_asr_interctc_models.py b/tests/collections/asr/test_asr_interctc_models.py index 8d5e4b0b689c..a8d7101033ab 100644 --- a/tests/collections/asr/test_asr_interctc_models.py +++ b/tests/collections/asr/test_asr_interctc_models.py @@ -13,8 +13,8 @@ # limitations under the License. from typing import Dict +import lightning.pytorch as pl import pytest -import pytorch_lightning as pl import torch from omegaconf import DictConfig, ListConfig @@ -68,7 +68,8 @@ def squeezeformer_encoder_config() -> Dict: class TestInterCTCLoss: @pytest.mark.unit @pytest.mark.parametrize( - "model_class", [EncDecCTCModel, EncDecHybridRNNTCTCModel], + "model_class", + [EncDecCTCModel, EncDecHybridRNNTCTCModel], ) @pytest.mark.parametrize( "encoder_config", @@ -241,10 +242,12 @@ def __getitem__(self, idx): trainer.fit( asr_model, train_dataloaders=torch.utils.data.DataLoader( - DummyDataset([input_signal, input_length, target, target_length]), collate_fn=lambda x: x[0], + DummyDataset([input_signal, input_length, target, target_length]), + collate_fn=lambda x: x[0], ), val_dataloaders=torch.utils.data.DataLoader( - DummyDataset([input_signal, input_length, target, target_length]), collate_fn=lambda x: x[0], + DummyDataset([input_signal, input_length, target, target_length]), + collate_fn=lambda x: x[0], ), ) required_metrics = ['final_loss'] if len(loss_weights) > 0 else [] @@ -264,7 +267,8 @@ def __getitem__(self, idx): trainer.test( asr_model, dataloaders=torch.utils.data.DataLoader( - DummyDataset([input_signal, input_length, target, target_length]), collate_fn=lambda x: x[0], + DummyDataset([input_signal, input_length, target, target_length]), + collate_fn=lambda x: x[0], ), ) required_metrics = [f'inter_ctc_loss_l{idx}' for idx in apply_at_layers] diff --git a/tests/collections/asr/test_asr_lhotse_dataset.py b/tests/collections/asr/test_asr_lhotse_dataset.py index 5a1450e606ac..c131fac70310 100644 --- a/tests/collections/asr/test_asr_lhotse_dataset.py +++ b/tests/collections/asr/test_asr_lhotse_dataset.py @@ -65,3 +65,35 @@ def test_lhotse_asr_dataset(tokenizer): assert tokens[2].tolist() == [1, 7, 10, 19, 20, 21, 1, 20, 6, 4, 16, 15, 5] assert token_lens.tolist() == [11, 11, 13] + + +def test_lhotse_asr_dataset_metadata(tokenizer): + + cuts = DummyManifest(CutSet, begin_id=0, end_id=2, with_data=True) + + cuts[0].id = "cuts0" + cuts[1].id = "cuts1" + cuts[0].supervisions = [ + SupervisionSegment(id="cuts0-sup0", recording_id=cuts[0].recording_id, start=0.2, duration=0.5, text="first"), + ] + cuts[1].supervisions = [ + SupervisionSegment(id="cuts1-sup0", recording_id=cuts[1].recording_id, start=0, duration=1, text=""), + ] + + datasets_metadata = LhotseSpeechToTextBpeDataset(tokenizer=tokenizer, return_cuts=True) + batch = datasets_metadata[cuts] + assert isinstance(batch, tuple) + assert len(batch) == 5 + + _, _, _, _, cuts_metadata = batch + + assert cuts_metadata[0].supervisions[0].text == "first" + assert cuts_metadata[1].supervisions[0].text == "" + assert cuts_metadata[0].id == "cuts0" + assert cuts_metadata[1].id == "cuts1" + + assert cuts_metadata[0].supervisions[0].duration == 0.5 + assert cuts_metadata[0].supervisions[0].start == 0.2 + + assert cuts_metadata[1].supervisions[0].duration == 1 + assert cuts_metadata[1].supervisions[0].start == 0.0 diff --git a/tests/collections/asr/test_asr_local_attn.py b/tests/collections/asr/test_asr_local_attn.py index 257dc0949af3..3013c0efbddf 100644 --- a/tests/collections/asr/test_asr_local_attn.py +++ b/tests/collections/asr/test_asr_local_attn.py @@ -15,8 +15,8 @@ import shutil import tempfile +import lightning.pytorch as pl import pytest -import pytorch_lightning as pl import torch from omegaconf import DictConfig @@ -89,10 +89,12 @@ def test_change_save_restore(self): @pytest.mark.unit @pytest.mark.parametrize( - "global_tokens", [0, 1, 4], + "global_tokens", + [0, 1, 4], ) @pytest.mark.parametrize( - "global_tokens_spacing", [1, 4], + "global_tokens_spacing", + [1, 4], ) def test_train(self, global_tokens, global_tokens_spacing): preprocessor_config = {'_target_': 'nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor'} @@ -178,15 +180,18 @@ def __getitem__(self, idx): trainer.fit( asr_model, train_dataloaders=torch.utils.data.DataLoader( - DummyDataset([input_signal, input_length, target, target_length]), collate_fn=lambda x: x[0], + DummyDataset([input_signal, input_length, target, target_length]), + collate_fn=lambda x: x[0], ), val_dataloaders=torch.utils.data.DataLoader( - DummyDataset([input_signal, input_length, target, target_length]), collate_fn=lambda x: x[0], + DummyDataset([input_signal, input_length, target, target_length]), + collate_fn=lambda x: x[0], ), ) trainer.test( asr_model, dataloaders=torch.utils.data.DataLoader( - DummyDataset([input_signal, input_length, target, target_length]), collate_fn=lambda x: x[0], + DummyDataset([input_signal, input_length, target, target_length]), + collate_fn=lambda x: x[0], ), ) diff --git a/tests/collections/asr/test_asr_rnnt_encdec_model.py b/tests/collections/asr/test_asr_rnnt_encdec_model.py index d68088fce376..5e810243c919 100644 --- a/tests/collections/asr/test_asr_rnnt_encdec_model.py +++ b/tests/collections/asr/test_asr_rnnt_encdec_model.py @@ -17,13 +17,17 @@ import pytest import torch import torch.nn.functional as F +from lhotse import CutSet, MonoCut +from lhotse.testing.dummies import DummyManifest from omegaconf import DictConfig, ListConfig +from nemo.collections.asr.data.audio_to_text_lhotse import LhotseSpeechToTextBpeDataset from nemo.collections.asr.models import EncDecRNNTModel from nemo.collections.asr.modules import HATJoint, RNNTDecoder, RNNTJoint, SampledRNNTJoint, StatelessTransducerDecoder from nemo.collections.asr.parts.submodules import rnnt_beam_decoding as beam_decode from nemo.collections.asr.parts.submodules import rnnt_greedy_decoding as greedy_decode from nemo.collections.asr.parts.utils import rnnt_utils +from nemo.collections.common.parts.preprocessing.parsers import make_parser from nemo.core.utils import numba_utils from nemo.core.utils.numba_utils import __NUMBA_MINIMUM_VERSION__ from nemo.utils.config_utils import assert_dataclass_signature_match @@ -296,6 +300,19 @@ def test_forward(self, asr_model): diff = torch.max(torch.abs(logprobs_instance - logprobs_batch)) assert diff <= 1e-6 + @pytest.mark.unit + def test_predict_step(self, asr_model): + token_list = [" ", "a", "b", "c"] + asr_model = asr_model.eval() + cuts = DummyManifest(CutSet, begin_id=0, end_id=1, with_data=True) + dataset = LhotseSpeechToTextBpeDataset(tokenizer=make_parser(labels=token_list), return_cuts=True) + batch = dataset[cuts] + outputs = asr_model.predict_step(batch, 0) + assert len(outputs) == 1 + assert len(outputs[0]) == 2 + assert isinstance(outputs[0][0], MonoCut) + assert isinstance(outputs[0][1], str) + @pytest.mark.skipif( not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', diff --git a/tests/collections/asr/test_asr_rnnt_encoder_model_bpe.py b/tests/collections/asr/test_asr_rnnt_encoder_model_bpe.py index 960445061e24..aba364868e88 100644 --- a/tests/collections/asr/test_asr_rnnt_encoder_model_bpe.py +++ b/tests/collections/asr/test_asr_rnnt_encoder_model_bpe.py @@ -18,8 +18,11 @@ import pytest import torch +from lhotse import CutSet, MonoCut +from lhotse.testing.dummies import DummyManifest from omegaconf import DictConfig +from nemo.collections.asr.data.audio_to_text_lhotse import LhotseSpeechToTextBpeDataset from nemo.collections.asr.models import ASRModel from nemo.collections.asr.models.rnnt_bpe_models import EncDecRNNTBPEModel from nemo.collections.asr.parts.submodules import rnnt_beam_decoding as beam_decode @@ -64,12 +67,18 @@ def asr_model(test_data_dir): decoder = { '_target_': 'nemo.collections.asr.modules.RNNTDecoder', - 'prednet': {'pred_hidden': model_defaults['pred_hidden'], 'pred_rnn_layers': 1,}, + 'prednet': { + 'pred_hidden': model_defaults['pred_hidden'], + 'pred_rnn_layers': 1, + }, } joint = { '_target_': 'nemo.collections.asr.modules.RNNTJoint', - 'jointnet': {'joint_hidden': 32, 'activation': 'relu',}, + 'jointnet': { + 'joint_hidden': 32, + 'activation': 'relu', + }, } decoding = {'strategy': 'greedy_batch', 'greedy': {'max_symbols': 30}} @@ -123,7 +132,8 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): class TestEncDecRNNTBPEModel: @pytest.mark.skipif( - not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + not NUMBA_RNNT_LOSS_AVAILABLE, + reason='RNNTLoss has not been compiled with appropriate numba version.', ) @pytest.mark.with_downloads() @pytest.mark.unit @@ -137,7 +147,8 @@ def test_constructor(self, asr_model): @pytest.mark.with_downloads() @pytest.mark.skipif( - not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + not NUMBA_RNNT_LOSS_AVAILABLE, + reason='RNNTLoss has not been compiled with appropriate numba version.', ) @pytest.mark.unit def test_forward(self, asr_model): @@ -170,9 +181,22 @@ def test_forward(self, asr_model): diff = torch.max(torch.abs(logits_instance - logprobs_batch)) assert diff <= 1e-6 + @pytest.mark.unit + def test_predict_step(self, asr_model): + asr_model = asr_model.eval() + cuts = DummyManifest(CutSet, begin_id=0, end_id=1, with_data=True) + dataset = LhotseSpeechToTextBpeDataset(tokenizer=asr_model.tokenizer, return_cuts=True) + batch = dataset[cuts] + outputs = asr_model.predict_step(batch, 0) + assert len(outputs) == 1 + assert len(outputs[0]) == 2 + assert isinstance(outputs[0][0], MonoCut) + assert isinstance(outputs[0][1], str) + @pytest.mark.with_downloads() @pytest.mark.skipif( - not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + not NUMBA_RNNT_LOSS_AVAILABLE, + reason='RNNTLoss has not been compiled with appropriate numba version.', ) @pytest.mark.unit def test_save_restore_artifact(self, asr_model): @@ -190,7 +214,8 @@ def test_save_restore_artifact(self, asr_model): @pytest.mark.with_downloads() @pytest.mark.skipif( - not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + not NUMBA_RNNT_LOSS_AVAILABLE, + reason='RNNTLoss has not been compiled with appropriate numba version.', ) @pytest.mark.unit def test_save_restore_artifact_spe(self, asr_model, test_data_dir): @@ -236,7 +261,8 @@ def test_save_restore_artifact_agg(self, asr_model, test_data_dir): @pytest.mark.with_downloads() @pytest.mark.skipif( - not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + not NUMBA_RNNT_LOSS_AVAILABLE, + reason='RNNTLoss has not been compiled with appropriate numba version.', ) @pytest.mark.unit def test_vocab_change(self, test_data_dir, asr_model): @@ -266,7 +292,8 @@ def test_vocab_change(self, test_data_dir, asr_model): @pytest.mark.with_downloads() @pytest.mark.skipif( - not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + not NUMBA_RNNT_LOSS_AVAILABLE, + reason='RNNTLoss has not been compiled with appropriate numba version.', ) @pytest.mark.unit def test_decoding_change(self, asr_model): @@ -309,7 +336,8 @@ def test_decoding_change(self, asr_model): @pytest.mark.with_downloads() @pytest.mark.unit @pytest.mark.skipif( - not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.', + not NUMBA_RNNT_LOSS_AVAILABLE, + reason='RNNTLoss has not been compiled with appropriate numba version.', ) def test_save_restore_nested_model(self): with tempfile.TemporaryDirectory() as tmp_dir: @@ -330,7 +358,7 @@ def test_save_restore_nested_model(self): # Check size of the checkpoint, which contains weights from pretrained model + linear layer fp_weights = os.path.join(tmp_dir, 'model_weights.ckpt') - assert os.path.getsize(fp_weights) > 50 * (2 ** 20) # Assert the weights are more than 50 MB + assert os.path.getsize(fp_weights) > 50 * (2**20) # Assert the weights are more than 50 MB # Check if param after restoration is exact match original_state_dict = model.inner_model.state_dict() diff --git a/tests/collections/asr/test_diar_utils.py b/tests/collections/asr/test_diar_utils.py index f48292d27981..cb364675fcf4 100644 --- a/tests/collections/asr/test_diar_utils.py +++ b/tests/collections/asr/test_diar_utils.py @@ -48,7 +48,7 @@ get_online_subsegments_from_buffer, get_speech_labels_for_update, get_sub_range_list, - get_subsegments, + get_subsegments_scriptable, get_target_sig, int2fl, is_overlap, @@ -82,8 +82,7 @@ def matrix(mat, use_tensor=True, dtype=torch.long): def generate_orthogonal_embs(total_spks, perturb_sigma, emb_dim): - """Generate a set of artificial orthogonal embedding vectors from random numbers - """ + """Generate a set of artificial orthogonal embedding vectors from random numbers""" gaus = torch.randn(emb_dim, emb_dim) _svd = torch.linalg.svd(gaus) orth = _svd[0] @ _svd[2] @@ -110,7 +109,7 @@ def generate_toy_data( random_orthogonal_embs = generate_orthogonal_embs(n_spks, perturb_sigma, emb_dim) for scale_idx, (window, shift) in enumerate(zip(ms_window, ms_shift)): for spk_idx, (offset, dur) in enumerate(spk_timestamps): - segments_stt_dur = get_subsegments(offset=offset, window=window, shift=shift, duration=dur) + segments_stt_dur = get_subsegments_scriptable(offset=offset, window=window, shift=shift, duration=dur) segments = [[x[0], x[0] + x[1]] for x in segments_stt_dur] emb_cent = random_orthogonal_embs[spk_idx, :] emb = emb_cent.tile((len(segments), 1)) + 0.1 * torch.rand(len(segments), emb_dim) @@ -130,8 +129,7 @@ def generate_toy_data( class TestDiarizationSequneceUtilFunctions: - """Tests diarization and speaker-task related utils. - """ + """Tests diarization and speaker-task related utils.""" @pytest.mark.unit @pytest.mark.parametrize("Y", [[3, 3, 3, 4, 4, 5], [100, 100, 100, 104, 104, 1005]]) @@ -278,7 +276,10 @@ def test_embedding_reducer(self, n_spks, target_speaker_index, merge_quantity): em, ts, mc, mw, spk_ts, gt = generate_toy_data(n_spks=n_spks, spk_dur=10) em_s, ts_s = split_input_data(em, ts, mc) merged_embs, merged_clus_labels, _ = run_reducer( - pre_embs=em_s[-1], target_spk_idx=target_speaker_index, merge_quantity=merge_quantity, pre_clus_labels=gt, + pre_embs=em_s[-1], + target_spk_idx=target_speaker_index, + merge_quantity=merge_quantity, + pre_clus_labels=gt, ) assert (torch.sum(gt == target_speaker_index).item() - merge_quantity) == merged_clus_labels.shape[0] @@ -287,7 +288,11 @@ def test_embedding_reducer(self, n_spks, target_speaker_index, merge_quantity): @pytest.mark.parametrize("pcl", [torch.tensor([0] * 70 + [1] * 32)]) @pytest.mark.parametrize("mspb", [25]) def test_merge_scheduler_2clus(self, ntbr, pcl, mspb): - class_target_vol = get_merge_quantity(num_to_be_removed=ntbr, pre_clus_labels=pcl, min_count_per_cluster=mspb,) + class_target_vol = get_merge_quantity( + num_to_be_removed=ntbr, + pre_clus_labels=pcl, + min_count_per_cluster=mspb, + ) assert all(class_target_vol == torch.tensor([3, 0])) @pytest.mark.unit @@ -295,7 +300,11 @@ def test_merge_scheduler_2clus(self, ntbr, pcl, mspb): @pytest.mark.parametrize("pcl", [torch.tensor([0] * 80 + [1] * 35 + [2] * 32)]) @pytest.mark.parametrize("mspb", [0, 25]) def test_merge_scheduler_3clus(self, ntbr, pcl, mspb): - class_target_vol = get_merge_quantity(num_to_be_removed=ntbr, pre_clus_labels=pcl, min_count_per_cluster=mspb,) + class_target_vol = get_merge_quantity( + num_to_be_removed=ntbr, + pre_clus_labels=pcl, + min_count_per_cluster=mspb, + ) assert all(class_target_vol == torch.tensor([3, 0, 0])) @pytest.mark.unit @@ -303,7 +312,11 @@ def test_merge_scheduler_3clus(self, ntbr, pcl, mspb): @pytest.mark.parametrize("pcl", [torch.tensor([2] * 70 + [0] * 32 + [1] * 27 + [3] * 3)]) @pytest.mark.parametrize("mspb", [3, 10]) def test_merge_scheduler_4clus_shuff(self, ntbr, pcl, mspb): - class_target_vol = get_merge_quantity(num_to_be_removed=ntbr, pre_clus_labels=pcl, min_count_per_cluster=mspb,) + class_target_vol = get_merge_quantity( + num_to_be_removed=ntbr, + pre_clus_labels=pcl, + min_count_per_cluster=mspb, + ) assert all(class_target_vol == torch.tensor([18, 13, 56, 0])) @pytest.mark.unit @@ -311,7 +324,11 @@ def test_merge_scheduler_4clus_shuff(self, ntbr, pcl, mspb): @pytest.mark.parametrize("pcl", [torch.tensor([0] * 5 + [1] * 4 + [2] * 3)]) @pytest.mark.parametrize("mspb", [0, 2]) def test_merge_scheduler_3clus(self, ntbr, pcl, mspb): - class_target_vol = get_merge_quantity(num_to_be_removed=ntbr, pre_clus_labels=pcl, min_count_per_cluster=mspb,) + class_target_vol = get_merge_quantity( + num_to_be_removed=ntbr, + pre_clus_labels=pcl, + min_count_per_cluster=mspb, + ) assert all(class_target_vol == torch.tensor([2, 1, 0])) @pytest.mark.unit @@ -319,7 +336,11 @@ def test_merge_scheduler_3clus(self, ntbr, pcl, mspb): @pytest.mark.parametrize("pcl", [torch.tensor([0] * 7 + [1] * 5 + [2] * 3 + [3] * 5)]) @pytest.mark.parametrize("mspb", [2]) def test_merge_scheduler_3clus_repeat(self, ntbr, pcl, mspb): - class_target_vol = get_merge_quantity(num_to_be_removed=ntbr, pre_clus_labels=pcl, min_count_per_cluster=mspb,) + class_target_vol = get_merge_quantity( + num_to_be_removed=ntbr, + pre_clus_labels=pcl, + min_count_per_cluster=mspb, + ) assert all(class_target_vol == torch.tensor([2, 0, 0, 0])) @@ -414,13 +435,21 @@ def test_is_overlap_false(self, rangeA, rangeB): @pytest.mark.parametrize("x", [1.0, 2.3456]) @pytest.mark.parametrize("decimals", [1, 2, 3, 4]) def test_fl2int(self, x, decimals): - assert fl2int(x, decimals) == round(x * 10 ** decimals, 0) + assert fl2int(x, decimals) == round(x * 10**decimals, 0) @pytest.mark.unit @pytest.mark.parametrize("x", [1234]) - @pytest.mark.parametrize("decimals", [1, 2, 3, 4,]) + @pytest.mark.parametrize( + "decimals", + [ + 1, + 2, + 3, + 4, + ], + ) def test_int2fl(self, x, decimals): - assert abs(int2fl(x, decimals) - round(x / (10 ** decimals), decimals)) < (10 ** -(decimals + 1)) + assert abs(int2fl(x, decimals) - round(x / (10**decimals), decimals)) < (10 ** -(decimals + 1)) @pytest.mark.unit def test_merge_float_intervals_edge_margin_test(self): @@ -462,7 +491,11 @@ def test_get_speech_labels_for_update(self): vad_timestamps = torch.tensor([[0.9600, 4.8400]]) cursor_for_old_segments = 1.0 speech_labels_for_update, cumulative_speech_labels = get_speech_labels_for_update( - frame_start, buffer_end, cumulative_speech_labels, vad_timestamps, cursor_for_old_segments, + frame_start, + buffer_end, + cumulative_speech_labels, + vad_timestamps, + cursor_for_old_segments, ) assert (speech_labels_for_update - torch.tensor([[1.0000, 3.7600]])).sum() < 1e-8 assert (cumulative_speech_labels - torch.tensor([[0.9600, 4.8400]])).sum() < 1e-8 @@ -532,7 +565,10 @@ def test_tensor_to_list(self, source_range_list): @pytest.mark.unit @pytest.mark.parametrize( "buffer_start, buffer_end, subsegments, ind_offset, window, sample_rate", - [(0.0, 2.0, [[0.5, 1.0], [1.5, 2.0]], 0, 0.1, 16000), (0.0, 5.0, [[0.5, 2.5], [2.7, 5.0]], 0, 1.0, 16000),], + [ + (0.0, 2.0, [[0.5, 1.0], [1.5, 2.0]], 0, 0.1, 16000), + (0.0, 5.0, [[0.5, 2.5], [2.7, 5.0]], 0, 1.0, 16000), + ], ) def test_get_online_segments_from_slices( self, buffer_start, buffer_end, subsegments, ind_offset, window, sample_rate @@ -665,7 +701,13 @@ def test_offline_speaker_clustering_cpu(self, n_spks, total_sec, SSV, perturb_si @pytest.mark.parametrize("SSV, enhanced_count_thres, min_samples_for_nmesc", [(5, 40, 6)]) @pytest.mark.parametrize("seed", [0]) def test_offline_speaker_clustering_very_short_cpu( - self, n_spks, spk_dur, SSV, enhanced_count_thres, min_samples_for_nmesc, seed, + self, + n_spks, + spk_dur, + SSV, + enhanced_count_thres, + min_samples_for_nmesc, + seed, ): em, ts, mc, mw, spk_ts, gt = generate_toy_data( n_spks=n_spks, spk_dur=spk_dur, perturb_sigma=0.1, torch_seed=seed @@ -697,7 +739,13 @@ def test_offline_speaker_clustering_very_short_cpu( @pytest.mark.parametrize("n_spks, SSV, enhanced_count_thres, min_samples_for_nmesc", [(1, 5, 40, 6)]) @pytest.mark.parametrize("seed", [0]) def test_offline_speaker_clustering_very_short_gpu( - self, n_spks, spk_dur, SSV, enhanced_count_thres, min_samples_for_nmesc, seed, + self, + n_spks, + spk_dur, + SSV, + enhanced_count_thres, + min_samples_for_nmesc, + seed, ): em, ts, mc, mw, spk_ts, gt = generate_toy_data( n_spks=n_spks, spk_dur=spk_dur, perturb_sigma=0.1, torch_seed=seed @@ -908,7 +956,7 @@ def test_linear_sum_assignment_algorithm_cost_matrix(self, cost_matrix): Test the linear sum assignment algorithm with a cost matrix Compare with the scipy implementation and make sure the final cost is the same. - NOTE: There could be multiple solutions with the same cost in linear sum assignment problem. + NOTE: There could be multiple solutions with the same cost in linear sum assignment problem. This test only checks if the cost is the same. """ row_ind_nm, col_ind_nm = nemo_linear_sum_assignment(cost_matrix) diff --git a/tests/collections/common/test_ema.py b/tests/collections/common/test_ema.py index 98f733f1c568..18ee04e371e2 100644 --- a/tests/collections/common/test_ema.py +++ b/tests/collections/common/test_ema.py @@ -15,13 +15,13 @@ import os.path from typing import Any, Dict, Union +import lightning.pytorch as pl import pytest -import pytorch_lightning as pl import torch +from lightning.pytorch import Callback, Trainer +from lightning.pytorch.utilities.exceptions import MisconfigurationException +from lightning.pytorch.utilities.types import STEP_OUTPUT from omegaconf import DictConfig, OmegaConf -from pytorch_lightning import Callback, Trainer -from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.types import STEP_OUTPUT from nemo.collections.common.callbacks import EMA from nemo.collections.common.callbacks.ema import EMAOptimizer @@ -349,7 +349,12 @@ class TestEMATrain: @pytest.mark.parametrize("validate_original_weights", [True, False]) @pytest.mark.run_only_on('GPU') def test_ema_run_cuda( - self, test_data_dir, precision, accumulate_grad_batches, validate_original_weights, tmpdir, + self, + test_data_dir, + precision, + accumulate_grad_batches, + validate_original_weights, + tmpdir, ): self.run_training_test( accumulate_grad_batches=accumulate_grad_batches, diff --git a/tests/collections/llm/common.py b/tests/collections/llm/common.py index 95b8bc0de584..c17243936bd1 100644 --- a/tests/collections/llm/common.py +++ b/tests/collections/llm/common.py @@ -14,7 +14,7 @@ import os -import pytorch_lightning as pl +import lightning.pytorch as pl import torch from nemo import lightning as nl diff --git a/tests/collections/llm/gpt/model/megatron_ssm_pretraining.py b/tests/collections/llm/gpt/model/megatron_ssm_pretraining.py index d7ecaafaaf8c..55bea59d6274 100644 --- a/tests/collections/llm/gpt/model/megatron_ssm_pretraining.py +++ b/tests/collections/llm/gpt/model/megatron_ssm_pretraining.py @@ -16,9 +16,11 @@ ## There are no guarantees that this script is up-to-date with latest NeMo. import argparse + import torch +from lightning.pytorch.loggers import TensorBoardLogger from megatron.core.optimizer import OptimizerConfig -from pytorch_lightning.loggers import TensorBoardLogger + from nemo import lightning as nl from nemo.collections import llm from nemo.collections.llm.api import train diff --git a/tests/collections/llm/gpt_finetuning.py b/tests/collections/llm/gpt_finetuning.py index 7eaa7744729c..91796585bf96 100644 --- a/tests/collections/llm/gpt_finetuning.py +++ b/tests/collections/llm/gpt_finetuning.py @@ -94,8 +94,8 @@ def get_args(): ), ) - if args.peft == 'lora': - peft = llm.peft.LoRA() + if args.peft in ['lora', 'dora']: + peft = llm.peft.PEFT_STR2CLS[args.peft]() else: peft = None diff --git a/tests/collections/llm/lora_mistralai.py b/tests/collections/llm/lora_mistralai.py index 09a52668e3ee..0415569304ac 100644 --- a/tests/collections/llm/lora_mistralai.py +++ b/tests/collections/llm/lora_mistralai.py @@ -14,7 +14,7 @@ import argparse -import pytorch_lightning as pl +import lightning.pytorch as pl import torch from megatron.core.optimizer import OptimizerConfig diff --git a/tests/collections/llm/megatron_gpt_pretraining.py b/tests/collections/llm/megatron_gpt_pretraining.py index a73b2a694c76..9722ba9d6c68 100644 --- a/tests/collections/llm/megatron_gpt_pretraining.py +++ b/tests/collections/llm/megatron_gpt_pretraining.py @@ -18,8 +18,8 @@ import argparse import torch +from lightning.pytorch.loggers import TensorBoardLogger from megatron.core.optimizer import OptimizerConfig -from pytorch_lightning.loggers import TensorBoardLogger from nemo import lightning as nl from nemo.collections import llm diff --git a/tests/collections/llm/megatron_t5_finetuning.py b/tests/collections/llm/megatron_t5_finetuning.py index e8f4947c9674..976ad5c48053 100644 --- a/tests/collections/llm/megatron_t5_finetuning.py +++ b/tests/collections/llm/megatron_t5_finetuning.py @@ -18,8 +18,8 @@ import argparse import torch +from lightning.pytorch.loggers import WandbLogger from megatron.core.optimizer import OptimizerConfig -from pytorch_lightning.loggers import WandbLogger from nemo import lightning as nl from nemo.collections import llm @@ -103,7 +103,7 @@ def get_args(): optimizer='adam', lr=2.0e-5, use_distributed_optimizer=False, - bf16=False, + bf16=True, weight_decay=0.1, ) opt = MegatronOptimizerModule( @@ -124,7 +124,7 @@ def get_args(): log_every_n_steps=1, limit_val_batches=2, val_check_interval=50, - plugins=nl.MegatronMixedPrecision(precision="32"), + plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"), ) if args.wandb_project is not None: diff --git a/tests/collections/llm/megatron_t5_pretraining.py b/tests/collections/llm/megatron_t5_pretraining.py index a5460be3d154..ad63ae88fb73 100644 --- a/tests/collections/llm/megatron_t5_pretraining.py +++ b/tests/collections/llm/megatron_t5_pretraining.py @@ -18,8 +18,8 @@ import argparse import torch +from lightning.pytorch.loggers import WandbLogger from megatron.core.optimizer import OptimizerConfig -from pytorch_lightning.loggers import WandbLogger from nemo import lightning as nl from nemo.collections import llm diff --git a/tests/collections/llm/peft/lora_merge.py b/tests/collections/llm/peft/lora_merge.py new file mode 100644 index 000000000000..2ca7390ea7e6 --- /dev/null +++ b/tests/collections/llm/peft/lora_merge.py @@ -0,0 +1,42 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +from dataclasses import dataclass + +from nemo.collections import llm + + +@dataclass +class Llama3ConfigCI(llm.Llama3Config8B): + seq_length: int = 2048 + num_layers: int = 2 + hidden_size: int = 768 + ffn_hidden_size: int = 3072 + num_attention_heads: int = 8 + + +def get_args(): + parser = argparse.ArgumentParser(description='Merge LoRA weights with base LLM') + parser.add_argument('--lora_checkpoint_path', type=str, help="Path to finetuned LORA checkpoint") + parser.add_argument('--output_path', type=str, help="Path to save merged checkpoint") + return parser.parse_args() + + +if __name__ == '__main__': + args = get_args() + + llm.peft.merge_lora( + lora_checkpoint_path=args.lora_checkpoint_path, + output_path=args.output_path, + ) diff --git a/tests/collections/llm/test_mnist_model_nemo2.py b/tests/collections/llm/test_mnist_model_nemo2.py index a5c2aa96fc03..92cffc2a35bb 100644 --- a/tests/collections/llm/test_mnist_model_nemo2.py +++ b/tests/collections/llm/test_mnist_model_nemo2.py @@ -23,16 +23,16 @@ from pathlib import Path from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Tuple, TypedDict, TypeVar, Union +import lightning.pytorch as pl import megatron.core.num_microbatches_calculator import pytest -import pytorch_lightning as pl import torch import torch.distributed +from lightning.pytorch.loggers import TensorBoardLogger from megatron.core import ModelParallelConfig, parallel_state from megatron.core.optimizer import OptimizerConfig from megatron.core.transformer.enums import ModelType from megatron.core.transformer.module import MegatronModule -from pytorch_lightning.loggers import TensorBoardLogger from torch import Tensor, nn from torch.utils.data import DataLoader from torchvision import transforms diff --git a/tests/collections/llm/test_mnist_model_nemo2_fsdp.py b/tests/collections/llm/test_mnist_model_nemo2_fsdp.py index 8a6c1f993d28..9418ee7e5e90 100644 --- a/tests/collections/llm/test_mnist_model_nemo2_fsdp.py +++ b/tests/collections/llm/test_mnist_model_nemo2_fsdp.py @@ -23,16 +23,16 @@ from pathlib import Path from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Tuple, TypedDict, TypeVar, Union +import lightning.pytorch as pl import megatron.core.num_microbatches_calculator import pytest -import pytorch_lightning as pl import torch import torch.distributed +from lightning.pytorch.loggers import TensorBoardLogger from megatron.core import ModelParallelConfig, parallel_state from megatron.core.optimizer import OptimizerConfig from megatron.core.transformer.enums import ModelType from megatron.core.transformer.module import MegatronModule -from pytorch_lightning.loggers import TensorBoardLogger from torch import Tensor, nn from torch.optim import Adam from torch.utils.data import DataLoader @@ -525,6 +525,7 @@ def run_train_mnist_litautoencoder_with_fsdp_strategy_single_gpu(): every_n_train_steps=5, # Enables the .nemo file-like checkpointing where all IOMixins are under SerDe always_save_context=True, + filename="{model_name}--{val_loss:.2f}-{step}-{consumed_samples}", ) root_dir = tmpdir save_dir = root_dir / name @@ -572,6 +573,7 @@ def run_train_mnist_litautoencoder_with_fsdp_strategy_single_gpu(): global_batch_size=2, output_log=False, # Disable logs to support predict_step ), + ckpt_load_optimizer=False, ) predict_trainer = nl.Trainer( accelerator="gpu", diff --git a/tests/collections/multimodal/test_speechllm_models.py b/tests/collections/multimodal/test_speechllm_models.py index 8698fed205ea..09149064b657 100644 --- a/tests/collections/multimodal/test_speechllm_models.py +++ b/tests/collections/multimodal/test_speechllm_models.py @@ -16,13 +16,13 @@ import tempfile from pathlib import Path +import lightning.pytorch as pl import numpy as np import pytest -import pytorch_lightning as pl import torch +from lightning.pytorch.plugins.environments import TorchElasticEnvironment from megatron.core import parallel_state from omegaconf import DictConfig, OmegaConf -from pytorch_lightning.plugins.environments import TorchElasticEnvironment from nemo.collections.multimodal.speech_llm.models import modular_models from nemo.collections.multimodal.speech_llm.parts.utils.data_utils import shift_tokens_by_multi_audios diff --git a/tests/collections/nlp/test_falcon_model.py b/tests/collections/nlp/test_falcon_model.py index 23430ad36300..62a4591092a9 100644 --- a/tests/collections/nlp/test_falcon_model.py +++ b/tests/collections/nlp/test_falcon_model.py @@ -14,8 +14,8 @@ import pytest import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig -from pytorch_lightning import Trainer from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel diff --git a/tests/collections/nlp/test_flash_attention.py b/tests/collections/nlp/test_flash_attention.py index f5585ddc1636..c8309b34b433 100644 --- a/tests/collections/nlp/test_flash_attention.py +++ b/tests/collections/nlp/test_flash_attention.py @@ -16,8 +16,8 @@ import pytest import torch +from lightning.pytorch.trainer.trainer import Trainer from megatron.core import ModelParallelConfig -from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.nlp.modules.common.megatron.attention import CoreAttention from nemo.collections.nlp.modules.common.megatron.megatron_init import initialize_model_parallel_for_nemo diff --git a/tests/collections/nlp/test_gpt_eval.py b/tests/collections/nlp/test_gpt_eval.py index fb3f9fda5ac3..020185ec7385 100644 --- a/tests/collections/nlp/test_gpt_eval.py +++ b/tests/collections/nlp/test_gpt_eval.py @@ -16,7 +16,7 @@ import numpy as np import pytest -from pytorch_lightning.trainer.trainer import Trainer +from lightning.pytorch.trainer.trainer import Trainer from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel from nemo.collections.nlp.modules.common.transformer.text_generation import LengthParam, SamplingParam diff --git a/tests/collections/nlp/test_gpt_model.py b/tests/collections/nlp/test_gpt_model.py index 7b6c02f948a4..334167f3dcf8 100644 --- a/tests/collections/nlp/test_gpt_model.py +++ b/tests/collections/nlp/test_gpt_model.py @@ -16,8 +16,8 @@ import pytest import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig -from pytorch_lightning import Trainer from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel diff --git a/tests/collections/nlp/test_nlp_exportables.py b/tests/collections/nlp/test_nlp_exportables.py index b404764e7eed..6da0f8c93cc0 100644 --- a/tests/collections/nlp/test_nlp_exportables.py +++ b/tests/collections/nlp/test_nlp_exportables.py @@ -14,9 +14,9 @@ import os import tempfile +import lightning.pytorch as pl import onnx import pytest -import pytorch_lightning as pl import torch import wget from omegaconf import DictConfig, OmegaConf diff --git a/tests/collections/nlp/test_pretrained_models_performance.py b/tests/collections/nlp/test_pretrained_models_performance.py index 82ff6ed103f1..b51f00681f57 100644 --- a/tests/collections/nlp/test_pretrained_models_performance.py +++ b/tests/collections/nlp/test_pretrained_models_performance.py @@ -17,8 +17,8 @@ from shutil import rmtree from unittest import TestCase +import lightning.pytorch as pl import pytest -import pytorch_lightning as pl from omegaconf import OmegaConf import nemo.collections.nlp.models as models diff --git a/tests/collections/nlp/test_rampup_batch_size.py b/tests/collections/nlp/test_rampup_batch_size.py index c7efb5f57f4c..763dfaaf3c51 100644 --- a/tests/collections/nlp/test_rampup_batch_size.py +++ b/tests/collections/nlp/test_rampup_batch_size.py @@ -16,8 +16,8 @@ import pytest import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig -from pytorch_lightning import Trainer from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy diff --git a/tests/collections/nlp/test_retrieval_module.py b/tests/collections/nlp/test_retrieval_module.py index 426e393c85bf..381d009f0e02 100644 --- a/tests/collections/nlp/test_retrieval_module.py +++ b/tests/collections/nlp/test_retrieval_module.py @@ -16,7 +16,7 @@ import pytest import torch from einops import rearrange -from pytorch_lightning.trainer.trainer import Trainer +from lightning.pytorch.trainer.trainer import Trainer from nemo.collections.nlp.modules.common.megatron.attention import ParallelChunkedCrossAttention from nemo.collections.nlp.modules.common.megatron.layer_type import LayerType @@ -73,7 +73,13 @@ def setup_class(cls): MB_SIZE = 4 GB_SIZE = 8 SEED = 1234 - trainer = Trainer(strategy=NLPDDPStrategy(), devices=GPUS, accelerator='gpu', num_nodes=1, logger=None,) + trainer = Trainer( + strategy=NLPDDPStrategy(), + devices=GPUS, + accelerator='gpu', + num_nodes=1, + logger=None, + ) initialize_model_parallel_for_nemo( world_size=trainer.world_size, @@ -134,7 +140,9 @@ def test_cross_attn(self, model_parallel_config): dec_attn_mask = rearrange(hidden_mask, '(k n) b -> (b k) n', k=chunks) context_attn_mask = rearrange(context_mask, 'k r n b -> (b k) (r n)') enc_dec_attn_mask_3d = build_attention_mask_3d( - source_mask=dec_attn_mask, target_mask=context_attn_mask, attn_mask_type=AttnMaskType.padding, + source_mask=dec_attn_mask, + target_mask=context_attn_mask, + attn_mask_type=AttnMaskType.padding, ) enc_dec_attn_mask_3d = enc_dec_attn_mask_3d[:, None, :, :] diff --git a/tests/collections/nlp/test_retrieval_module_inference.py b/tests/collections/nlp/test_retrieval_module_inference.py index ccb426ce4ab1..a7da05340708 100644 --- a/tests/collections/nlp/test_retrieval_module_inference.py +++ b/tests/collections/nlp/test_retrieval_module_inference.py @@ -17,7 +17,7 @@ import torch import torch.nn.functional as F from einops import rearrange -from pytorch_lightning.trainer.trainer import Trainer +from lightning.pytorch.trainer.trainer import Trainer from nemo.collections.nlp.modules.common.megatron.attention import ParallelChunkedCrossAttention from nemo.collections.nlp.modules.common.megatron.layer_type import LayerType @@ -73,7 +73,13 @@ def setup_class(cls): MB_SIZE = 4 GB_SIZE = 8 SEED = 1234 - trainer = Trainer(strategy=NLPDDPStrategy(), devices=GPUS, accelerator='gpu', num_nodes=1, logger=None,) + trainer = Trainer( + strategy=NLPDDPStrategy(), + devices=GPUS, + accelerator='gpu', + num_nodes=1, + logger=None, + ) initialize_model_parallel_for_nemo( world_size=trainer.world_size, @@ -176,15 +182,33 @@ def test_retrieval_encoder_inference(self, model_parallel_config): neighbors=neighbors, ) assert (encoder.encoder_output - hidden_emb[:, :64]).abs().max().item() < 1e-5 - assert (out_gt[:, 0,] - out_2[:, 0]).abs().max().item() < 1e-2 + assert ( + out_gt[ + :, + 0, + ] + - out_2[:, 0] + ).abs().max().item() < 1e-2 out_test = encoder( retrieved_emb[:, :1], context_mask[:, :1], context_attn_mask=hidden_mask[:, :64], encoder_output=hidden_emb[:, :64], ) - assert (out_gt[:, 0,] - out_test[:, 0]).abs().max().item() < 1e-2 - assert (out_gt[:, 0,] - out_2[:, 0]).abs().max().item() < 1e-2 + assert ( + out_gt[ + :, + 0, + ] + - out_test[:, 0] + ).abs().max().item() < 1e-2 + assert ( + out_gt[ + :, + 0, + ] + - out_2[:, 0] + ).abs().max().item() < 1e-2 for i in range(64, 127): out_3 = encoder( @@ -207,7 +231,13 @@ def test_retrieval_encoder_inference(self, model_parallel_config): neighbors=neighbors, ) assert (encoder.encoder_output - hidden_emb[:, 64:128]).abs().max().item() < 1e-5 - assert (out_gt[:, :2,] - out_3).abs().max().item() < 1e-2 + assert ( + out_gt[ + :, + :2, + ] + - out_3 + ).abs().max().item() < 1e-2 # test inference for i in range(128, 191): out_4 = encoder( @@ -231,7 +261,13 @@ def test_retrieval_encoder_inference(self, model_parallel_config): ) assert (encoder.encoder_output - hidden_emb[:, 128:192]).abs().max().item() < 1e-5 - assert (out_gt[:, :3,] - out_4).abs().max().item() < 1e-2 + assert ( + out_gt[ + :, + :3, + ] + - out_4 + ).abs().max().item() < 1e-2 out_2 = encoder( retrieved_emb[:, :2], @@ -263,7 +299,13 @@ def test_retrieval_encoder_inference(self, model_parallel_config): neighbors=neighbors, ) assert (encoder.encoder_output - hidden_emb[:, 128:192]).abs().max().item() < 1e-5 - assert (out_gt[:, :3,] - out_4).abs().max().item() < 1e-2 + assert ( + out_gt[ + :, + :3, + ] + - out_4 + ).abs().max().item() < 1e-2 @pytest.mark.unit def test_cross_attn_inference(self, model_parallel_config): @@ -309,7 +351,9 @@ def get_attn_mask_3d(hidden_mask, context_mask, chunks): dec_attn_mask = rearrange(hidden_mask, '(k n) b -> (b k) n', k=chunks) context_attn_mask = rearrange(context_mask, 'k r n b -> (b k) (r n)') enc_dec_attn_mask_3d = build_attention_mask_3d( - source_mask=dec_attn_mask, target_mask=context_attn_mask, attn_mask_type=AttnMaskType.padding, + source_mask=dec_attn_mask, + target_mask=context_attn_mask, + attn_mask_type=AttnMaskType.padding, ) enc_dec_attn_mask_3d = enc_dec_attn_mask_3d[:, None, :, :] return enc_dec_attn_mask_3d diff --git a/tests/collections/nlp/test_retro_model.py b/tests/collections/nlp/test_retro_model.py index b96016c8d7ec..e91590915ba5 100644 --- a/tests/collections/nlp/test_retro_model.py +++ b/tests/collections/nlp/test_retro_model.py @@ -16,8 +16,8 @@ import pytest import torch +from lightning.pytorch import Trainer from omegaconf import DictConfig -from pytorch_lightning import Trainer from nemo.collections.nlp.models.language_modeling.megatron_retro_model import MegatronRetroModel from nemo.collections.nlp.modules.common.megatron.utils import get_ltor_masks_and_position_ids diff --git a/tests/core/test_config_utils.py b/tests/core/test_config_utils.py index bb0a0f177dfb..9716fc160629 100644 --- a/tests/core/test_config_utils.py +++ b/tests/core/test_config_utils.py @@ -15,9 +15,9 @@ from dataclasses import dataclass from typing import Any +import lightning.pytorch as ptl import pytest -import pytorch_lightning as ptl -from pytorch_lightning.callbacks.early_stopping import EarlyStopping +from lightning.pytorch.callbacks.early_stopping import EarlyStopping from nemo.core.config.pytorch_lightning import TrainerConfig from nemo.utils import config_utils @@ -126,7 +126,9 @@ def test_ptl_config(self): assert dataclass_subset is None @pytest.mark.unit - def test_early_stopping_config(self,): + def test_early_stopping_config( + self, + ): result = config_utils.assert_dataclass_signature_match(EarlyStopping, EarlyStoppingParams) signatures_match, cls_subset, dataclass_subset = result diff --git a/tests/core/test_dist_ckpt.py b/tests/core/test_dist_ckpt.py index 0a483c0f58ab..6c066d1856a2 100644 --- a/tests/core/test_dist_ckpt.py +++ b/tests/core/test_dist_ckpt.py @@ -17,11 +17,11 @@ from pathlib import Path from typing import Any, Dict +import lightning.pytorch as pl import pytest -import pytorch_lightning as pl import torch -from lightning_fabric.plugins import TorchCheckpointIO -from pytorch_lightning.demos.boring_classes import BoringModel +from lightning.fabric.plugins import TorchCheckpointIO +from lightning.pytorch.demos.boring_classes import BoringModel from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy from nemo.utils.callbacks.dist_ckpt_io import ( diff --git a/tests/core/test_exp_manager.py b/tests/core/test_exp_manager.py index d4b1d37c1938..32d401b2051f 100644 --- a/tests/core/test_exp_manager.py +++ b/tests/core/test_exp_manager.py @@ -18,13 +18,13 @@ from pathlib import Path from typing import Any +import lightning.pytorch as pl import pytest -import pytorch_lightning as pl import torch +from lightning.pytorch import Callback +from lightning.pytorch.loops import _TrainingEpochLoop from omegaconf import OmegaConf from omegaconf.errors import OmegaConfBaseException -from pytorch_lightning import Callback -from pytorch_lightning.loops import _TrainingEpochLoop from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy from nemo.constants import NEMO_ENV_VARNAME_VERSION diff --git a/tests/core/test_fault_tolerance.py b/tests/core/test_fault_tolerance.py index 5b4e0ecba4aa..f916a7b44454 100644 --- a/tests/core/test_fault_tolerance.py +++ b/tests/core/test_fault_tolerance.py @@ -13,8 +13,8 @@ # limitations under the License. import os +import lightning.pytorch as pl import pytest -import pytorch_lightning as pl from nemo.utils.exp_manager import exp_manager diff --git a/tests/core/test_optimizers_schedulers.py b/tests/core/test_optimizers_schedulers.py index 5e5d1ee20c83..419db309a918 100644 --- a/tests/core/test_optimizers_schedulers.py +++ b/tests/core/test_optimizers_schedulers.py @@ -15,12 +15,12 @@ import math import random +import lightning.pytorch as pl import omegaconf import pytest -import pytorch_lightning as pl import torch import torch.optim -from pytorch_lightning.utilities import rank_zero_only +from lightning.pytorch.utilities import rank_zero_only from nemo.core import config, optim from nemo.core.optim.lr_scheduler import AVAILABLE_SCHEDULERS @@ -936,7 +936,13 @@ def train( enable_progress_bar=False, ) max_steps = optim.lr_scheduler.compute_max_steps( - max_epochs, accumulate_grad_batches, limit_train_batches, devices, dataset_len, batch_size, drop_last, + max_epochs, + accumulate_grad_batches, + limit_train_batches, + devices, + dataset_len, + batch_size, + drop_last, ) model = ExampleModel(batch_size, dataset_len, drop_last, max_steps) trainer.callbacks.append(Callback()) @@ -991,7 +997,13 @@ def train( dataset_len = random.randint(20, devices * 500) batch_size = random.randint(math.ceil(5.0 / devices), min(dataset_len // devices, 128)) train( - max_epochs, accumulate_grad_batches, limit_train_batches, devices, batch_size, dataset_len, drop_last, + max_epochs, + accumulate_grad_batches, + limit_train_batches, + devices, + batch_size, + dataset_len, + drop_last, ) @pytest.mark.unit diff --git a/tests/core/test_straggler_det.py b/tests/core/test_straggler_det.py index ee5222854889..1f938214d792 100644 --- a/tests/core/test_straggler_det.py +++ b/tests/core/test_straggler_det.py @@ -14,8 +14,8 @@ import sys +import lightning.pytorch as pl import pytest -import pytorch_lightning as pl import torch from omegaconf import OmegaConf diff --git a/tests/core_ptl/check_for_ranks.py b/tests/core_ptl/check_for_ranks.py index a1eae66790c4..dfbc05166c5a 100644 --- a/tests/core_ptl/check_for_ranks.py +++ b/tests/core_ptl/check_for_ranks.py @@ -16,9 +16,9 @@ import shutil import torch +from lightning.pytorch import Trainer +from lightning.pytorch.utilities import rank_zero_only from omegaconf import OmegaConf -from pytorch_lightning import Trainer -from pytorch_lightning.utilities import rank_zero_only from nemo.core import ModelPT from nemo.utils import logging diff --git a/tests/core_ptl/check_manual_upload_to_hf_hub.py b/tests/core_ptl/check_manual_upload_to_hf_hub.py index f411ee72332c..912eabb805bf 100644 --- a/tests/core_ptl/check_manual_upload_to_hf_hub.py +++ b/tests/core_ptl/check_manual_upload_to_hf_hub.py @@ -14,7 +14,7 @@ import shutil from huggingface_hub import HfApi -from pytorch_lightning.utilities import rank_zero_only +from lightning.pytorch.utilities import rank_zero_only from nemo.core import ModelPT from nemo.utils import AppState, logging @@ -40,7 +40,9 @@ def load_model_from_unpacked_hf_dir(repo_id): def upload_model_as_single_nemo_file(model: ModelPT, repo_id, token): # Upload the model to HF Hub model.push_to_hf_hub( - repo_id=repo_id, pack_nemo_file=True, token=token, + repo_id=repo_id, + pack_nemo_file=True, + token=token, ) @@ -48,7 +50,9 @@ def upload_model_as_single_nemo_file(model: ModelPT, repo_id, token): def upload_model_as_single_nemo_file(model: ModelPT, repo_id, token): # Upload the model to HF Hub model.push_to_hf_hub( - repo_id=repo_id, pack_nemo_file=True, token=token, + repo_id=repo_id, + pack_nemo_file=True, + token=token, ) @@ -56,7 +60,9 @@ def upload_model_as_single_nemo_file(model: ModelPT, repo_id, token): def upload_model_as_unpacked_files(model: ModelPT, repo_id, token): # Upload the model to HF Hub model.push_to_hf_hub( - repo_id=repo_id, pack_nemo_file=False, token=token, + repo_id=repo_id, + pack_nemo_file=False, + token=token, ) diff --git a/tests/core_ptl/test_ptl_stateless_timer.py b/tests/core_ptl/test_ptl_stateless_timer.py index 25f354a23c0d..5cfbbda39bbf 100644 --- a/tests/core_ptl/test_ptl_stateless_timer.py +++ b/tests/core_ptl/test_ptl_stateless_timer.py @@ -17,8 +17,8 @@ import pytest import torch +from lightning.pytorch import Trainer from omegaconf import OmegaConf -from pytorch_lightning import Trainer from nemo.core import ModelPT from nemo.utils import logging diff --git a/tests/deploy/nemo_deploy.py b/tests/deploy/nemo_deploy.py index 23db7c4f01f3..45f2bae3425e 100644 --- a/tests/deploy/nemo_deploy.py +++ b/tests/deploy/nemo_deploy.py @@ -21,7 +21,7 @@ import torch -from nemo.deploy.nlp.megatronllm_deployable import MegatronLLMDeployable +from nemo.deploy.nlp import MegatronLLMDeployable from tests.infer_data_path import get_infer_test_data run_export_tests = True diff --git a/tests/export/nemo_export.py b/tests/export/nemo_export.py index e929f2601022..cb2b3619e4d3 100644 --- a/tests/export/nemo_export.py +++ b/tests/export/nemo_export.py @@ -43,7 +43,8 @@ from nemo.deploy.nlp import MegatronLLMDeployable, NemoQueryLLMPyTorch except Exception as e: LOGGER.warning( - f"Cannot import MegatronLLMDeployable, in-framework inference will not be available. {type(e).__name__}: {e}" + "Cannot import MegatronLLMDeployable or NemoQueryLLMPyTorch," + f" in-framework inference will not be available. {type(e).__name__}: {e}" ) in_framework_supported = False @@ -103,7 +104,8 @@ def get_accuracy_with_lambada(model, nq, task_ids, lora_uids, test_data_path): expected_output = record["last_word"].strip().lower() all_expected_outputs.append(expected_output) if model is not None: - if isinstance(model, MegatronLLMDeployable): + + if in_framework_supported and isinstance(model, MegatronLLMDeployable): model_output = model.generate( inputs=[prompt], length_params={"min_length": 1, "max_length": 1}, @@ -147,7 +149,7 @@ def get_accuracy_with_lambada(model, nq, task_ids, lora_uids, test_data_path): correct_answers_relaxed += 1 if nq is not None: - if isinstance(nq, NemoQueryLLMPyTorch): + if in_framework_supported and isinstance(nq, NemoQueryLLMPyTorch): deployed_output = nq.query_llm( prompts=[prompt], max_length=1, @@ -849,6 +851,9 @@ def run_inference_tests(args): "Use the same value for --min_tps and --max_tps." ) + if args.debug: + LOGGER.setLevel(logging.DEBUG) + result_dic: Dict[int, Tuple[FunctionalResult, Optional[AccuracyResult]]] = {} if args.existing_test_models: diff --git a/tests/export/test_trt_compile.py b/tests/export/test_trt_compile.py new file mode 100644 index 000000000000..09a77004678f --- /dev/null +++ b/tests/export/test_trt_compile.py @@ -0,0 +1,139 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import tempfile +import unittest +from typing import List + +import torch + +TEST_CASE_1 = ["fp32"] +TEST_CASE_2 = ["fp16"] + + +class ListAdd(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x: List[torch.Tensor], y: torch.Tensor, z: torch.Tensor, bs: float = 0.1): + y1 = y.clone() + x1 = x.copy() + z1 = z + y + for xi in x: + y1 = y1 + xi + bs + return x1, [y1, z1], y1 + z1 + + +@unittest.skip +class TestTRTCompile(unittest.TestCase): + + def setUp(self): + self.gpu_device = torch.cuda.current_device() + + def tearDown(self): + current_device = torch.cuda.current_device() + if current_device != self.gpu_device: + torch.cuda.set_device(self.gpu_device) + + def test_torch_trt(self): + + model = torch.nn.Sequential(*[torch.nn.PReLU(), torch.nn.PReLU()]) + data1 = model.state_dict() + data1["0.weight"] = torch.tensor([0.1]) + data1["1.weight"] = torch.tensor([0.2]) + model.load_state_dict(data1) + model.cuda() + x = torch.randn(1, 16).to("cuda") + + with tempfile.TemporaryDirectory() as tempdir: + args = { + "method": "torch_trt", + "dynamic_batchsize": [1, 4, 8], + } + input_example = (x,) + output_example = model(*input_example) + trt_compile( + model, + f"{tempdir}/test_lists", + args=args, + ) + self.assertIsNone(model._trt_compiler.engine) + trt_output = model(*input_example) + # Check that lazy TRT build succeeded + self.assertIsNotNone(model._trt_compiler.engine) + torch.testing.assert_close(trt_output, output_example, rtol=0.01, atol=0.01) + + def test_profiles(self): + model = ListAdd().cuda() + + with torch.no_grad(), tempfile.TemporaryDirectory() as tmpdir: + args = { + "export_args": { + "dynamo": False, + }, + "input_profiles": [ + { + "x_0": [[1, 8], [2, 16], [2, 32]], + "x_1": [[1, 8], [2, 16], [2, 32]], + "x_2": [[1, 8], [2, 16], [2, 32]], + "y": [[1, 8], [2, 16], [2, 32]], + "z": [[1, 8], [1, 16], [1, 32]], + } + ], + "output_lists": [[-1], [2], []], + } + x = torch.randn(1, 16).to("cuda") + y = torch.randn(1, 16).to("cuda") + z = torch.randn(1, 16).to("cuda") + input_example = ([x, y, z], y.clone(), z.clone()) + output_example = model(*input_example) + trt_compile( + model, + f"{tmpdir}/test_dynamo_trt", + args=args, + ) + self.assertIsNone(model._trt_compiler.engine) + trt_output = model(*input_example) + # Check that lazy TRT build succeeded + self.assertIsNotNone(model._trt_compiler.engine) + torch.testing.assert_close(trt_output, output_example, rtol=0.01, atol=0.01) + + def test_lists(self): + model = ListAdd().cuda() + + with torch.no_grad(), tempfile.TemporaryDirectory() as tmpdir: + args = { + "export_args": { + "dynamo": True, + }, + "output_lists": [[-1], [2], []], + } + x = torch.randn(1, 16).to("cuda") + y = torch.randn(1, 16).to("cuda") + z = torch.randn(1, 16).to("cuda") + input_example = ([x, y, z], y.clone(), z.clone()) + output_example = model(*input_example) + trt_compile( + model, + f"{tmpdir}/test_lists", + args=args, + ) + self.assertIsNone(model._trt_compiler.engine) + trt_output = model(*input_example) + # Check that lazy TRT build succeeded + self.assertIsNotNone(model._trt_compiler.engine) + torch.testing.assert_close(trt_output, output_example, rtol=0.01, atol=0.01) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/lightning/_fabric/test_conversion.py b/tests/lightning/_fabric/test_conversion.py index e690557ec2eb..e97e766c86a7 100644 --- a/tests/lightning/_fabric/test_conversion.py +++ b/tests/lightning/_fabric/test_conversion.py @@ -13,10 +13,10 @@ # limitations under the License. import pytest -from lightning_fabric import plugins as fl_plugins -from lightning_fabric import strategies as fl_strategies -from pytorch_lightning import plugins as pl_plugins -from pytorch_lightning import strategies as pl_strategies +from lightning.fabric import plugins as fl_plugins +from lightning.fabric import strategies as fl_strategies +from lightning.pytorch import plugins as pl_plugins +from lightning.pytorch import strategies as pl_strategies from nemo import lightning as nl from nemo.lightning.fabric.conversion import to_fabric diff --git a/tests/lightning/_io/test_api.py b/tests/lightning/_io/test_api.py index a4d458cef17b..e0aaac1a6aa2 100644 --- a/tests/lightning/_io/test_api.py +++ b/tests/lightning/_io/test_api.py @@ -19,7 +19,7 @@ import fiddle as fdl import pytest import yaml -from pytorch_lightning.loggers import TensorBoardLogger +from lightning.pytorch.loggers import TensorBoardLogger from nemo import lightning as nl from nemo.collections import llm diff --git a/tests/lightning/pytorch/callbacks/test_model_checkpoint.py b/tests/lightning/pytorch/callbacks/test_model_checkpoint.py index 802f2b28c25c..edaa8a6f4ec9 100644 --- a/tests/lightning/pytorch/callbacks/test_model_checkpoint.py +++ b/tests/lightning/pytorch/callbacks/test_model_checkpoint.py @@ -17,12 +17,12 @@ from pathlib import Path from typing import Iterator, Optional, Sequence, Tuple +import lightning.pytorch as pl import megatron import pytest -import pytorch_lightning as pl import torch +from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS from megatron.core import ModelParallelConfig, parallel_state -from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS from torch import Tensor import nemo.lightning as nl diff --git a/tests/lightning/pytorch/callbacks/test_model_transform.py b/tests/lightning/pytorch/callbacks/test_model_transform.py index c59a82895125..cfae55cf99a9 100644 --- a/tests/lightning/pytorch/callbacks/test_model_transform.py +++ b/tests/lightning/pytorch/callbacks/test_model_transform.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import lightning.pytorch as pl import pytest -import pytorch_lightning as pl from torch import nn from nemo.lightning.pytorch.callbacks.model_transform import ModelTransform diff --git a/tests/lightning/pytorch/callbacks/test_peft.py b/tests/lightning/pytorch/callbacks/test_peft.py index 49a6aa0784aa..fb6728acee8f 100644 --- a/tests/lightning/pytorch/callbacks/test_peft.py +++ b/tests/lightning/pytorch/callbacks/test_peft.py @@ -15,7 +15,7 @@ from unittest.mock import MagicMock, call, patch import torch.nn as nn -from pytorch_lightning.trainer.states import TrainerFn +from lightning.pytorch.trainer.states import TrainerFn from nemo.collections.llm import fn from nemo.lightning.pytorch.callbacks.peft import PEFT, WrappedAdapterIO from nemo.utils.callbacks.dist_ckpt_io import AsyncFinalizableCheckpointIO diff --git a/tests/lightning/pytorch/callbacks/test_preemption.py b/tests/lightning/pytorch/callbacks/test_preemption.py index 4152f7fcce59..802d898c5a2b 100644 --- a/tests/lightning/pytorch/callbacks/test_preemption.py +++ b/tests/lightning/pytorch/callbacks/test_preemption.py @@ -17,7 +17,7 @@ import pytest import torch -from pytorch_lightning import Trainer +from lightning.pytorch import Trainer from nemo.lightning.pytorch.callbacks.preemption import PreemptionCallback diff --git a/tests/lightning/test_dist_ckpt.py b/tests/lightning/test_dist_ckpt.py index 886b1085ed55..107d15061792 100644 --- a/tests/lightning/test_dist_ckpt.py +++ b/tests/lightning/test_dist_ckpt.py @@ -21,8 +21,8 @@ def set_env(): from pathlib import Path +import lightning.pytorch as pl import pytest -import pytorch_lightning as pl import torch import nemo.lightning as nl diff --git a/tests/lightning/test_nemo_logger.py b/tests/lightning/test_nemo_logger.py index a5a5ec32c886..8a63a92f0ee6 100644 --- a/tests/lightning/test_nemo_logger.py +++ b/tests/lightning/test_nemo_logger.py @@ -19,8 +19,8 @@ from unittest.mock import patch import pytest -from pytorch_lightning.callbacks import ModelCheckpoint as PTLModelCheckpoint -from pytorch_lightning.loggers import WandbLogger +from lightning.pytorch.callbacks import ModelCheckpoint as PTLModelCheckpoint +from lightning.pytorch.loggers import WandbLogger from nemo import lightning as nl from nemo.constants import NEMO_ENV_VARNAME_VERSION diff --git a/tests/lightning/test_precision_plugin.py b/tests/lightning/test_precision_plugin.py index 44ffa5939fab..960e658187c5 100644 --- a/tests/lightning/test_precision_plugin.py +++ b/tests/lightning/test_precision_plugin.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import lightning.pytorch as pl import pytest -import pytorch_lightning as pl import torch from megatron.core.optimizer import OptimizerConfig diff --git a/tests/lightning/test_state_restoration.py b/tests/lightning/test_state_restoration.py index ccc0eed64d56..59c5cc2234f7 100644 --- a/tests/lightning/test_state_restoration.py +++ b/tests/lightning/test_state_restoration.py @@ -17,8 +17,8 @@ import pytest import torch +from lightning.pytorch.callbacks import Callback from megatron.core.optimizer import OptimizerConfig -from pytorch_lightning.callbacks import Callback from nemo import lightning as nl from nemo.collections import llm diff --git a/tests/utils/test_trainer_utils.py b/tests/utils/test_trainer_utils.py index 55eee92a523c..251e59d4b648 100644 --- a/tests/utils/test_trainer_utils.py +++ b/tests/utils/test_trainer_utils.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from lightning.pytorch.strategies import DDPStrategy from omegaconf import OmegaConf -from pytorch_lightning.strategies import DDPStrategy from nemo.utils.trainer_utils import resolve_trainer_cfg @@ -25,7 +25,7 @@ def test_resolve_trainer_cfg_strategy(): assert ans["strategy"] == "ddp" cfg = OmegaConf.create( - {"strategy": {"_target_": "pytorch_lightning.strategies.DDPStrategy", "gradient_as_bucket_view": True}} + {"strategy": {"_target_": "lightning.pytorch.strategies.DDPStrategy", "gradient_as_bucket_view": True}} ) ans = resolve_trainer_cfg(cfg) assert isinstance(ans, dict) diff --git a/tutorials/01_NeMo_Models.ipynb b/tutorials/01_NeMo_Models.ipynb index 4255a6656b8a..eb76e00cd981 100644 --- a/tutorials/01_NeMo_Models.ipynb +++ b/tutorials/01_NeMo_Models.ipynb @@ -984,7 +984,7 @@ "id": "0TsfmCYthMux" }, "source": [ - "import pytorch_lightning as ptl\n", + "import lightning.pytorch as ptl\n", "from nemo.core import ModelPT\n", "from omegaconf import OmegaConf" ], diff --git a/tutorials/asr/ASR_CTC_Language_Finetuning.ipynb b/tutorials/asr/ASR_CTC_Language_Finetuning.ipynb index a02ee4f99714..6ad3307da496 100644 --- a/tutorials/asr/ASR_CTC_Language_Finetuning.ipynb +++ b/tutorials/asr/ASR_CTC_Language_Finetuning.ipynb @@ -1292,7 +1292,7 @@ }, "source": [ "import torch\n", - "import pytorch_lightning as ptl\n", + "import lightning.pytorch as ptl\n", "\n", "if torch.cuda.is_available():\n", " accelerator = 'gpu'\n", @@ -2088,7 +2088,7 @@ }, "source": [ "import torch\n", - "import pytorch_lightning as ptl\n", + "import lightning.pytorch as ptl\n", "\n", "if torch.cuda.is_available():\n", " accelerator = 'gpu'\n", diff --git a/tutorials/asr/ASR_TTS_Tutorial.ipynb b/tutorials/asr/ASR_TTS_Tutorial.ipynb index 709f96d14ba5..544255f76d06 100644 --- a/tutorials/asr/ASR_TTS_Tutorial.ipynb +++ b/tutorials/asr/ASR_TTS_Tutorial.ipynb @@ -172,7 +172,7 @@ "import tempfile\n", "\n", "from omegaconf import OmegaConf\n", - "import pytorch_lightning as pl\n", + "import lightning.pytorch as pl\n", "import torch\n", "from tqdm.auto import tqdm\n", "import wget\n", diff --git a/tutorials/asr/ASR_with_NeMo.ipynb b/tutorials/asr/ASR_with_NeMo.ipynb index bd95c7194655..bb62e2f5eb9d 100644 --- a/tutorials/asr/ASR_with_NeMo.ipynb +++ b/tutorials/asr/ASR_with_NeMo.ipynb @@ -619,7 +619,7 @@ "id": "GUfR6tAK0k2u" }, "source": [ - "import pytorch_lightning as pl\n", + "import lightning.pytorch as pl\n", "trainer = pl.Trainer(devices=1, accelerator='gpu', max_epochs=50)" ], "execution_count": null, diff --git a/tutorials/asr/ASR_with_Subword_Tokenization.ipynb b/tutorials/asr/ASR_with_Subword_Tokenization.ipynb index ff15a5f75532..7a69735ae542 100644 --- a/tutorials/asr/ASR_with_Subword_Tokenization.ipynb +++ b/tutorials/asr/ASR_with_Subword_Tokenization.ipynb @@ -765,7 +765,7 @@ "id": "3rslHEKeq9qy" }, "source": [ - "import pytorch_lightning as pl\r\n", + "import lightning.pytorch as pl\r\n", "trainer = pl.Trainer(devices=1, accelerator='gpu', max_epochs=50)" ], "execution_count": null, diff --git a/tutorials/asr/ASR_with_Transducers.ipynb b/tutorials/asr/ASR_with_Transducers.ipynb index d20042b9b970..95eecbfb8916 100644 --- a/tutorials/asr/ASR_with_Transducers.ipynb +++ b/tutorials/asr/ASR_with_Transducers.ipynb @@ -754,7 +754,7 @@ "outputs": [], "source": [ "import torch\n", - "from pytorch_lightning import Trainer\n", + "from lightning.pytorch import Trainer\n", "\n", "if torch.cuda.is_available():\n", " accelerator = 'gpu'\n", diff --git a/tutorials/asr/Confidence_Ensembles.ipynb b/tutorials/asr/Confidence_Ensembles.ipynb index 734ddc9a0604..5a999df304b0 100644 --- a/tutorials/asr/Confidence_Ensembles.ipynb +++ b/tutorials/asr/Confidence_Ensembles.ipynb @@ -214,7 +214,7 @@ "# check out https://github.com/NVIDIA/NeMo/blob/main/tutorials/asr/ASR_CTC_Language_Finetuning.ipynb\n", "# to learn more about finetuning NeMo ASR models\n", "from omegaconf import open_dict, OmegaConf\n", - "from pytorch_lightning import Trainer\n", + "from lightning.pytorch import Trainer\n", "\n", "from nemo.collections.asr.models.ctc_bpe_models import EncDecCTCModelBPE\n", "import nemo.utils.exp_manager as exp_manager\n", diff --git a/tutorials/asr/Multilang_ASR.ipynb b/tutorials/asr/Multilang_ASR.ipynb index 612271a8baab..800f8a2d2ded 100644 --- a/tutorials/asr/Multilang_ASR.ipynb +++ b/tutorials/asr/Multilang_ASR.ipynb @@ -1527,7 +1527,7 @@ "outputs": [], "source": [ "import torch\n", - "import pytorch_lightning as ptl" + "import lightning.pytorch as ptl" ] }, { diff --git a/tutorials/asr/Self_Supervised_Pre_Training.ipynb b/tutorials/asr/Self_Supervised_Pre_Training.ipynb index c2e1e7362b3e..0506bafb56e3 100644 --- a/tutorials/asr/Self_Supervised_Pre_Training.ipynb +++ b/tutorials/asr/Self_Supervised_Pre_Training.ipynb @@ -433,7 +433,7 @@ }, "outputs": [], "source": [ - "import pytorch_lightning as pl\n", + "import lightning.pytorch as pl\n", "from omegaconf import OmegaConf\n", "\n", "from nemo.collections.asr.models.ssl_models import SpeechEncDecSelfSupervisedModel\n", diff --git a/tutorials/asr/Speech_Commands.ipynb b/tutorials/asr/Speech_Commands.ipynb index 438533f0f03a..c8a54e5135b2 100644 --- a/tutorials/asr/Speech_Commands.ipynb +++ b/tutorials/asr/Speech_Commands.ipynb @@ -408,7 +408,7 @@ }, "source": [ "import torch\n", - "import pytorch_lightning as pl" + "import lightning.pytorch as pl" ], "execution_count": null, "outputs": [] diff --git a/tutorials/asr/Transducers_with_HF_Datasets.ipynb b/tutorials/asr/Transducers_with_HF_Datasets.ipynb index a47cd00a0b9a..82f17fe8c1ac 100644 --- a/tutorials/asr/Transducers_with_HF_Datasets.ipynb +++ b/tutorials/asr/Transducers_with_HF_Datasets.ipynb @@ -554,7 +554,7 @@ "outputs": [], "source": [ "import torch\n", - "from pytorch_lightning import Trainer\n", + "from lightning.pytorch import Trainer\n", "\n", "if torch.cuda.is_available():\n", " accelerator = 'gpu'\n", diff --git a/tutorials/asr/Voice_Activity_Detection.ipynb b/tutorials/asr/Voice_Activity_Detection.ipynb index 123a03efc28e..fb3cef1b44ea 100644 --- a/tutorials/asr/Voice_Activity_Detection.ipynb +++ b/tutorials/asr/Voice_Activity_Detection.ipynb @@ -425,7 +425,7 @@ "outputs": [], "source": [ "import torch\n", - "import pytorch_lightning as pl" + "import lightning.pytorch as pl" ] }, { diff --git a/tutorials/asr/asr_adapters/ASR_with_Adapters.ipynb b/tutorials/asr/asr_adapters/ASR_with_Adapters.ipynb index c9c547a8383e..c3334a59b0d2 100644 --- a/tutorials/asr/asr_adapters/ASR_with_Adapters.ipynb +++ b/tutorials/asr/asr_adapters/ASR_with_Adapters.ipynb @@ -260,7 +260,7 @@ "source": [ "import torch\n", "from omegaconf import OmegaConf, open_dict\n", - "from pytorch_lightning import Trainer\n", + "from lightning.pytorch import Trainer\n", "\n", "import nemo.collections.asr as nemo_asr" ], diff --git a/tutorials/asr/asr_adapters/Multi_Task_Adapters.ipynb b/tutorials/asr/asr_adapters/Multi_Task_Adapters.ipynb index cb364ab7396d..0d35feb11a9a 100644 --- a/tutorials/asr/asr_adapters/Multi_Task_Adapters.ipynb +++ b/tutorials/asr/asr_adapters/Multi_Task_Adapters.ipynb @@ -908,7 +908,7 @@ "\n", "from torch.utils.data import DataLoader, Dataset\n", "\n", - "import pytorch_lightning as L\n", + "import lightning.pytorch as L\n", "\n", "from transformers import T5Tokenizer, T5ForConditionalGeneration\n", "\n", diff --git a/tutorials/audio/speech_enhancement/Speech_Enhancement_with_NeMo.ipynb b/tutorials/audio/speech_enhancement/Speech_Enhancement_with_NeMo.ipynb index 5c697840ba09..faef27d18abf 100644 --- a/tutorials/audio/speech_enhancement/Speech_Enhancement_with_NeMo.ipynb +++ b/tutorials/audio/speech_enhancement/Speech_Enhancement_with_NeMo.ipynb @@ -91,7 +91,7 @@ "import IPython.display as ipd\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", - "import pytorch_lightning as pl\n", + "import lightning.pytorch as pl\n", "import soundfile as sf\n", "\n", "from omegaconf import OmegaConf, open_dict\n", diff --git a/tutorials/audio/speech_enhancement/Speech_Enhancement_with_Online_Augmentation.ipynb b/tutorials/audio/speech_enhancement/Speech_Enhancement_with_Online_Augmentation.ipynb index ff6970d98522..e8b734537a41 100644 --- a/tutorials/audio/speech_enhancement/Speech_Enhancement_with_Online_Augmentation.ipynb +++ b/tutorials/audio/speech_enhancement/Speech_Enhancement_with_Online_Augmentation.ipynb @@ -93,7 +93,7 @@ "import IPython.display as ipd\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", - "import pytorch_lightning as pl\n", + "import lightning.pytorch as pl\n", "import soundfile as sf\n", "from pathlib import Path\n", "from omegaconf import OmegaConf, open_dict\n", @@ -981,4 +981,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} \ No newline at end of file +} diff --git a/tutorials/llm/llama-3/README.rst b/tutorials/llm/llama-3/README.rst index 3bb1a0896b82..1d12b8847c0d 100755 --- a/tutorials/llm/llama-3/README.rst +++ b/tutorials/llm/llama-3/README.rst @@ -2,7 +2,7 @@ Getting Started with Llama 3 and Llama 3.1 ========================================== -This repository contains jupyter notebook tutorials using NeMo Framework for Llama-3 and Llama-3.1 models by Meta. +This repository contains Jupyter Notebook tutorials using the NeMo Framework for Llama-3 and Llama-3.1 models by Meta. .. list-table:: :widths: 100 25 100 @@ -16,7 +16,7 @@ This repository contains jupyter notebook tutorials using NeMo Framework for Lla - Perform LoRA PEFT on Llama 3 8B Instruct using a dataset for bio-medical domain question answering. Deploy multiple LoRA adapters with NVIDIA NIM. * - `Llama 3.1 Law-Domain LoRA Fine-Tuning and Deployment with NeMo Framework and NVIDIA NIM <./sdg-law-title-generation>`_ - `Law StackExchange `_ - - Perform LoRA PEFT on Llama 3.1 8B Instruct using a synthetically augmented version of Law StackExchange with NeMo Framework, followed by deployment with NVIDIA NIM. As a pre-requisite, follow the tutorial for `data curation using NeMo Curator `__. + - Perform LoRA PEFT on Llama 3.1 8B Instruct using a synthetically augmented version of Law StackExchange with NeMo Framework, followed by deployment with NVIDIA NIM. As a prerequisite, follow the tutorial for `data curation using NeMo Curator `_. * - `Llama 3.1 Pruning and Distillation with NeMo Framework <./pruning-distillation>`_ - `WikiText-103-v1 `_ - Perform pruning and distillation on Llama 3.1 8B using the WikiText-103-v1 dataset with NeMo Framework. diff --git a/tutorials/llm/llama-3/nemo2-sft-peft/README.rst b/tutorials/llm/llama-3/nemo2-sft-peft/README.rst new file mode 100644 index 000000000000..74a581c52f44 --- /dev/null +++ b/tutorials/llm/llama-3/nemo2-sft-peft/README.rst @@ -0,0 +1,61 @@ +Llama 3 Supervised Fine-Tuning and Parameter Efficient Fine-Tuning with NeMo 2.0 +================================================================================ + +`Llama 3 `_ is an open-source large language model by Meta that delivers state-of-the-art performance on popular industry benchmarks. It has been pretrained on over 15 trillion tokens and supports an 8K token context length. It is available in two sizes, 8B and 70B, and each size has two variants—base pretrained and instruction tuned. + +Supervised Fine-Tuning (SFT) refers to unfreezing all the weights and layers in our model and training on a newly labeled set of examples. We can fine-tune to incorporate new, domain-specific knowledge, or teach the foundation model what type of response to provide. + +`Low-Rank Adaptation (LoRA) `__ has emerged as a popular Parameter-Efficient Fine-Tuning (PEFT) technique that tunes a very small number of additional parameters as compared to full fine-tuning, thereby reducing the compute required. + +`NVIDIA NeMo +Framework `__ provides tools to perform SFT and LoRA on Llama 3 to fit your use case. + +Requirements +------------ + +* System Configuration + * For SFT: access to at least 2 NVIDIA GPUs with a cumulative memory of at least 80GB, for example: 2 x H100-80GB or 2 x A100-80GB. + * For LoRA: access to at least 1 NVIDIA GPUs with a cumulative memory of at least 80GB, for example: 1 x H100-80GB or 1 x A100-80GB. + * A Docker-enabled environment, with `NVIDIA Container Runtime `_ installed, which will make the container GPU-aware. + +* Software Requirements + * Use the latest [NeMo Framework Container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo/tags) . Note that you must be logged in to the container registry to view this page. + * This notebook uses the container: `nvcr.io/nvidia/nemo:dev`. + * Get your Hugging Face [access token](https://huggingface.co/docs/hub/en/security-tokens), which will be used to obtain the tokenizer required during training. + +* NeMo 2.0 and NeMo-Run + * We will use NeMo 2.0 and NeMo-Run to perform SFT and LoRA on Llama 3. Both are already available in the NeMo Framework Container. + + +Start the NeMo Framework Container +---------------------------------- + +1. You can start and enter the dev container by: + +.. code:: bash + + docker run \ + --gpus device=1 \ + --shm-size=2g \ + --net=host \ + --ulimit memlock=-1 \ + --rm -it \ + -v ${PWD}:/workspace \ + -w /workspace \ + nvcr.io/nvidia/nemo:dev bash + + +2. You need to request download permission from Meta and Hugging Face. Then, from within the container, log in through `huggingface-cli` using your Hugging Face token. + +.. code:: bash + + huggingface-cli login + + +3. From within the container, start the Jupyter lab: + +.. code:: bash + + jupyter lab --ip 0.0.0.0 --port=8888 --allow-root + +4. Then, navigate to `the SFT notebook <./nemo2-sft.ipynb>`__ or `the LoRA notebook <./nemo2-peft.ipynb>`__ to perform SFT or LoRA on Llama 3, respectively. diff --git a/tutorials/llm/llama-3/nemo2-sft-peft/nemo2-peft.ipynb b/tutorials/llm/llama-3/nemo2-sft-peft/nemo2-peft.ipynb new file mode 100644 index 000000000000..aa463e2b84be --- /dev/null +++ b/tutorials/llm/llama-3/nemo2-sft-peft/nemo2-peft.ipynb @@ -0,0 +1,572 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Learning Goals\n", + "\n", + "## Optimizing Foundation Models with Parameter-Efficient Fine-Tuning (PEFT)\n", + "\n", + "This notebook aims to demonstrate how to adapt or customize foundation models to improve performance on specific tasks using NeMo 2.0.\n", + "\n", + "This optimization process is known as fine-tuning, which involves adjusting the weights of a pre-trained foundation model with custom data.\n", + "\n", + "Considering that foundation models can be significantly large, a variant of fine-tuning has gained traction recently known as PEFT. PEFT encompasses several methods, including P-Tuning, LoRA, Adapters, IA3, etc. NeMo 2.0 currently supports [Low-Rank Adaptation (LoRA)](https://arxiv.org/pdf/2106.09685) method.\n", + "\n", + "NeMo 2.0 introduces Python-based configurations, PyTorch Lightning’s modular abstractions, and NeMo-Run for scaling experiments across multiple GPUs. In this notebook, we will use NeMo-Run to streamline the configuration and execution of our experiments.\n", + "\n", + "## Data\n", + "This notebook uses the SQuAD dataset. For more details about the data, refer to [SQuAD: 100,000+ Questions for Machine Comprehension of Text](https://arxiv.org/abs/1606.05250)\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "## Step 1. Import the Hugging Face Checkpoint\n", + "We use the `llm.import_ckpt` API to download the specified model using the \"hf://\" URL format. It will then convert the model into NeMo 2.0 format. For all model supported in NeMo 2.0, refer to [Large Language Models](https://docs.nvidia.com/nemo-framework/user-guide/24.09/llms/index.html#large-language-models) section of NeMo Framework User Guide." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import nemo_run as run\n", + "from nemo import lightning as nl\n", + "from nemo.collections import llm\n", + "from megatron.core.optimizer import OptimizerConfig\n", + "from nemo.collections.llm.peft.lora import LoRA\n", + "import torch\n", + "import pytorch_lightning as pl\n", + "from pathlib import Path\n", + "from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed\n", + "\n", + "\n", + "# llm.import_ckpt is the nemo2 API for converting Hugging Face checkpoint to NeMo format\n", + "# example usage:\n", + "# llm.import_ckpt(model=llm.llama3_8b.model(), source=\"hf://meta-llama/Meta-Llama-3-8B\")\n", + "#\n", + "# We use run.Partial to configure this function\n", + "def configure_checkpoint_conversion():\n", + " return run.Partial(\n", + " llm.import_ckpt,\n", + " model=llm.llama3_8b.model(),\n", + " source=\"hf://meta-llama/Meta-Llama-3-8B\",\n", + " overwrite=False,\n", + " )\n", + "\n", + "# configure your function\n", + "import_ckpt = configure_checkpoint_conversion()\n", + "# define your executor\n", + "local_executor = run.LocalExecutor()\n", + "\n", + "# run your experiment\n", + "run.run(import_ckpt, executor=local_executor)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 2. Prepare the Data\n", + "\n", + "We will be using SQuAD for this notebook. NeMo 2.0 already provides a `SquadDataModule`. Example usage:" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "def squad() -> run.Config[pl.LightningDataModule]:\n", + " return run.Config(llm.SquadDataModule, seq_length=2048, micro_batch_size=1, global_batch_size=8, num_workers=0)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To learn how to use your own data to create a custom `DataModule` for performing PEFT, refer to [NeMo 2.0 SFT notebook](./nemo2-sft.ipynb)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 3.1: Configure PEFT with NeMo 2.0 API and NeMo-Run\n", + "\n", + "The following Python script utilizes the NeMo 2.0 API to perform PEFT. In this script, we are configuring the following components for training. These components are similar between SFT and PEFT. SFT and PEFT both use `llm.finetune` API. To switch from SFT to PEFT, you just need to add `peft` with the LoRA adapter to the API parameter.\n", + "\n", + "### Configure the Trainer\n", + "The NeMo 2.0 Trainer works similarly to the PyTorch Lightning trainer.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "def trainer() -> run.Config[nl.Trainer]:\n", + " strategy = run.Config(\n", + " nl.MegatronStrategy,\n", + " tensor_model_parallel_size=1\n", + " )\n", + " trainer = run.Config(\n", + " nl.Trainer,\n", + " devices=1,\n", + " max_steps=20,\n", + " accelerator=\"gpu\",\n", + " strategy=strategy,\n", + " plugins=bf16_mixed(),\n", + " log_every_n_steps=1,\n", + " limit_val_batches=2,\n", + " val_check_interval=2,\n", + " num_sanity_val_steps=0,\n", + " )\n", + " return trainer\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Configure the Logger\n", + "Configure your training steps, output directories and logging through `NeMoLogger`. In the following example, the experiment output will be saved at `./results/nemo2_peft`.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "def logger() -> run.Config[nl.NeMoLogger]:\n", + " ckpt = run.Config(\n", + " nl.ModelCheckpoint,\n", + " save_last=True,\n", + " every_n_train_steps=10,\n", + " monitor=\"reduced_train_loss\",\n", + " save_top_k=1,\n", + " save_on_train_epoch_end=True,\n", + " save_optim_on_train_end=True,\n", + " )\n", + "\n", + " return run.Config(\n", + " nl.NeMoLogger,\n", + " name=\"nemo2_peft\",\n", + " log_dir=\"./results\",\n", + " use_datetime_version=False,\n", + " ckpt=ckpt,\n", + " wandb=None\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "\n", + "### Configure the Optimizer\n", + "In the following example, we will be using the distributed adam optimizer and pass in the optimizer configuration through `OptimizerConfig`: " + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "def adam_with_cosine_annealing() -> run.Config[nl.OptimizerModule]:\n", + " opt_cfg = run.Config(\n", + " OptimizerConfig,\n", + " optimizer=\"adam\",\n", + " lr=0.0001,\n", + " adam_beta2=0.98,\n", + " use_distributed_optimizer=True,\n", + " clip_grad=1.0,\n", + " bf16=True,\n", + " )\n", + " return run.Config(\n", + " nl.MegatronOptimizerModule,\n", + " config=opt_cfg\n", + " )\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Pass in the LoRA Adapter\n", + "We need to pass in the LoRA adapter to our fine-tuning API to perform LoRA fine-tuning. We can configure the adapter as follows. The target module we support includes: `linear_qkv`, `linear_proj`, `linear_fc1` and `linear_fc2`. In the final script, we used the default configurations for LoRA (`llm.peft.LoRA()`), which will use the full list with `dim=32`." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "def lora() -> run.Config[nl.pytorch.callbacks.PEFT]:\n", + " return run.Config(LoRA)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Configure the Base Model\n", + "We will perform PEFT on top of Llama-3-8b, so we create a `LlamaModel` to pass to the NeMo 2.0 finetune API." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "def llama3_8b() -> run.Config[pl.LightningModule]:\n", + " return run.Config(llm.LlamaModel, config=run.Config(llm.Llama3Config8B))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Auto Resume\n", + "In NeMo 2.0, we can directly pass in the Llama3-8b Hugging Face ID to start PEFT without manually converting it into the NeMo checkpoint, as required in NeMo 1.0." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "def resume() -> run.Config[nl.AutoResume]:\n", + " return run.Config(\n", + " nl.AutoResume,\n", + " restore_config=run.Config(nl.RestoreConfig,\n", + " path=\"nemo://meta-llama/Meta-Llama-3-8B\"\n", + " ),\n", + " resume_if_exists=True,\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "### Configure the NeMo 2.0 finetune API\n", + "Using all the components we created above, we can call the NeMo 2.0 finetune API. The python example usage is as below:\n", + "```\n", + "llm.finetune(\n", + " model=llama3_8b(),\n", + " data=squad(),\n", + " trainer=trainer(),\n", + " peft=lora(),\n", + " log=logger(),\n", + " optim=adam_with_cosine_annealing(),\n", + " resume=resume(),\n", + ")\n", + "```\n", + "We configure the `llm.finetune` API as below:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def configure_finetuning_recipe():\n", + " return run.Partial(\n", + " llm.finetune,\n", + " model=llama3_8b(),\n", + " trainer=trainer(),\n", + " data=squad(),\n", + " log=logger(),\n", + " peft=lora(),\n", + " optim=adam_with_cosine_annealing(),\n", + " resume=resume(),\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 3.2: Run PEFT with NeMo 2.0 API and NeMo-Run\n", + "\n", + "We use `LocalExecutor` for executing our configured finetune function. For more details on the NeMo-Run executor, refer to [Execute NeMo Run](https://github.com/NVIDIA/NeMo-Run/blob/main/docs/source/guides/execution.md) of NeMo-Run Guides. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "def local_executor_torchrun(nodes: int = 1, devices: int = 1) -> run.LocalExecutor:\n", + " # Env vars for jobs are configured here\n", + " env_vars = {\n", + " \"TORCH_NCCL_AVOID_RECORD_STREAMS\": \"1\",\n", + " \"NCCL_NVLS_ENABLE\": \"0\",\n", + " \"NVTE_DP_AMAX_REDUCE_INTERVAL\": \"0\",\n", + " \"NVTE_ASYNC_AMAX_REDUCTION\": \"1\",\n", + " \"NVTE_FUSED_ATTN\": \"0\",\n", + " }\n", + "\n", + " executor = run.LocalExecutor(ntasks_per_node=devices, launcher=\"torchrun\", env_vars=env_vars)\n", + "\n", + " return executor\n", + "\n", + "if __name__ == '__main__':\n", + " run.run(configure_finetuning_recipe(), executor=local_executor_torchrun())\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 4. Generate Results from Trained PEFT Checkpoints \n", + "\n", + "We use the `llm.generate` API in NeMo 2.0 to generate results from the trained PEFT checkpoint. Find your last saved checkpoint from your experiment dir: `results/nemo2_peft/checkpoints`. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "peft_ckpt_path=str(next((d for d in Path(\"./results/nemo2_peft/checkpoints/\").iterdir() if d.is_dir() and d.name.endswith(\"-last\")), None))\n", + "print(\"We will load PEFT checkpoint from:\", peft_ckpt_path)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The SQuAD test set contains over 10,000 samples. For a quick demonstration, we will use the first 100 lines as an example input. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "%%bash\n", + "head -n 100 /root/.cache/nemo/datasets/squad/test.jsonl > toy_testset.jsonl\n", + "head -n 3 /root/.cache/nemo/datasets/squad/test.jsonl" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You should see something like:\n", + "```\n", + "{\"input\": \"Context: Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24\\u201310 to earn their third Super Bowl title. The game was played on February 7, 2016, at Levi's Stadium in the San Francisco Bay Area at Santa Clara, California. As this was the 50th Super Bowl, the league emphasized the \\\"golden anniversary\\\" with various gold-themed initiatives, as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals (under which the game would have been known as \\\"Super Bowl L\\\"), so that the logo could prominently feature the Arabic numerals 50. Question: Which NFL team represented the AFC at Super Bowl 50? Answer:\", \"output\": \"Denver Broncos\", \"original_answers\": [\"Denver Broncos\", \"Denver Broncos\", \"Denver Broncos\"]}\n", + "{\"input\": \"Context: Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24\\u201310 to earn their third Super Bowl title. The game was played on February 7, 2016, at Levi's Stadium in the San Francisco Bay Area at Santa Clara, California. As this was the 50th Super Bowl, the league emphasized the \\\"golden anniversary\\\" with various gold-themed initiatives, as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals (under which the game would have been known as \\\"Super Bowl L\\\"), so that the logo could prominently feature the Arabic numerals 50. Question: Which NFL team represented the NFC at Super Bowl 50? Answer:\", \"output\": \"Carolina Panthers\", \"original_answers\": [\"Carolina Panthers\", \"Carolina Panthers\", \"Carolina Panthers\"]}\n", + "{\"input\": \"Context: Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24\\u201310 to earn their third Super Bowl title. The game was played on February 7, 2016, at Levi's Stadium in the San Francisco Bay Area at Santa Clara, California. As this was the 50th Super Bowl, the league emphasized the \\\"golden anniversary\\\" with various gold-themed initiatives, as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals (under which the game would have been known as \\\"Super Bowl L\\\"), so that the logo could prominently feature the Arabic numerals 50. Question: Where did Super Bowl 50 take place? Answer:\", \"output\": \"Santa Clara, California\", \"original_answers\": [\"Santa Clara, California\", \"Levi's Stadium\", \"Levi's Stadium in the San Francisco Bay Area at Santa Clara, California.\"]}\n", + "\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We will pass the string `toy_testset.jsonl` to the `input_dataset` parameter of `llm.generate`. To evaluate the entire test set, you can instead pass the SQuAD data module directly, using `input_dataset=squad()`. The input JSONL file should follow the format shown above, containing `input` and `output` fields (additional keys are optional)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from megatron.core.inference.common_inference_params import CommonInferenceParams\n", + "\n", + "\n", + "def trainer() -> run.Config[nl.Trainer]:\n", + " strategy = run.Config(\n", + " nl.MegatronStrategy,\n", + " tensor_model_parallel_size=1\n", + " )\n", + " trainer = run.Config(\n", + " nl.Trainer,\n", + " accelerator=\"gpu\",\n", + " devices=1,\n", + " num_nodes=1,\n", + " strategy=strategy,\n", + " plugins=bf16_mixed(),\n", + " )\n", + " return trainer\n", + "\n", + "def configure_inference():\n", + " return run.Partial(\n", + " llm.generate,\n", + " path=str(peft_ckpt_path),\n", + " trainer=trainer(),\n", + " input_dataset=\"toy_testset.jsonl\",\n", + " inference_params=CommonInferenceParams(num_tokens_to_generate=20, top_k=1),\n", + " output_path=\"peft_prediction.jsonl\",\n", + " )\n", + "\n", + "\n", + "def local_executor_torchrun(nodes: int = 1, devices: int = 1) -> run.LocalExecutor:\n", + " # Env vars for jobs are configured here\n", + " env_vars = {\n", + " \"TORCH_NCCL_AVOID_RECORD_STREAMS\": \"1\",\n", + " \"NCCL_NVLS_ENABLE\": \"0\",\n", + " \"NVTE_DP_AMAX_REDUCE_INTERVAL\": \"0\",\n", + " \"NVTE_ASYNC_AMAX_REDUCTION\": \"1\",\n", + " \"NVTE_FUSED_ATTN\": \"0\",\n", + " }\n", + "\n", + " executor = run.LocalExecutor(ntasks_per_node=devices, launcher=\"torchrun\", env_vars=env_vars)\n", + "\n", + " return executor\n", + "\n", + "if __name__ == '__main__':\n", + " run.run(configure_inference(), executor=local_executor_torchrun())\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "After the inference is complete, you will see results similar to the following:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "%%bash\n", + "head -n 3 peft_prediction.jsonl" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You should see outputs similar to the following:\n", + "```\n", + "{\"input\": \"Context: Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24\\u201310 to earn their third Super Bowl title. The game was played on February 7, 2016, at Levi's Stadium in the San Francisco Bay Area at Santa Clara, California. As this was the 50th Super Bowl, the league emphasized the \\\"golden anniversary\\\" with various gold-themed initiatives, as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals (under which the game would have been known as \\\"Super Bowl L\\\"), so that the logo could prominently feature the Arabic numerals 50. Question: Which NFL team represented the AFC at Super Bowl 50? Answer:\", \"original_answers\": [\"Denver Broncos\", \"Denver Broncos\", \"Denver Broncos\"], \"label\": \"Denver Broncos\", \"prediction\": \" Denver Broncos\"}\n", + "{\"input\": \"Context: Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24\\u201310 to earn their third Super Bowl title. The game was played on February 7, 2016, at Levi's Stadium in the San Francisco Bay Area at Santa Clara, California. As this was the 50th Super Bowl, the league emphasized the \\\"golden anniversary\\\" with various gold-themed initiatives, as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals (under which the game would have been known as \\\"Super Bowl L\\\"), so that the logo could prominently feature the Arabic numerals 50. Question: Which NFL team represented the NFC at Super Bowl 50? Answer:\", \"original_answers\": [\"Carolina Panthers\", \"Carolina Panthers\", \"Carolina Panthers\"], \"label\": \"Carolina Panthers\", \"prediction\": \" Carolina Panthers\"}\n", + "{\"input\": \"Context: Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24\\u201310 to earn their third Super Bowl title. The game was played on February 7, 2016, at Levi's Stadium in the San Francisco Bay Area at Santa Clara, California. As this was the 50th Super Bowl, the league emphasized the \\\"golden anniversary\\\" with various gold-themed initiatives, as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals (under which the game would have been known as \\\"Super Bowl L\\\"), so that the logo could prominently feature the Arabic numerals 50. Question: Where did Super Bowl 50 take place? Answer:\", \"original_answers\": [\"Santa Clara, California\", \"Levi's Stadium\", \"Levi's Stadium in the San Francisco Bay Area at Santa Clara, California.\"], \"label\": \"Santa Clara, California\", \"prediction\": \" Levi's Stadium\"}\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 5. Calculate Evaluation Metrics\n", + "\n", + "We can evaluate the model's predictions by calculating the Exact Match (EM) and F1 scores.\n", + "- Exact Match is a binary measure (0 or 1) checking if the model outputs match one of the\n", + "ground truth answer exactly.\n", + "- F1 score is the harmonic mean of precision and recall for the answer words.\n", + "\n", + "Below is a script that computes these metrics. The sample scores can be improved by training the model further and performing hyperparameter tuning. In this notebook, we only train for 20 steps.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "!python /opt/NeMo/scripts/metric_calculation/peft_metric_calc.py --pred_file peft_prediction.jsonl --label_field \"original_answers\" --pred_field \"prediction\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# NeMo Tools and Resources\n", + "1. [NeMo GitHub repo](https://github.com/NVIDIA/NeMo)\n", + "\n", + "2. [NeMo-Run GitHub repo](https://github.com/NVIDIA/NeMo-Run/)\n", + "\n", + "3. NeMo Framework Container: `nvcr.io/nvidia/nemo:dev`\n", + "\n", + "\n", + "\n", + "# Educational Resources\n", + "1. Blog: [Mastering LLM Techniques: Customization](https://developer.nvidia.com/blog/selecting-large-language-model-customization-techniques/)\n", + "\n", + "2. Whitepaper: [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685)\n", + "\n", + "3. [NeMo 2.0 Overview](https://docs.nvidia.com/nemo-framework/user-guide/latest/nemo-2.0/index.html)\n", + "\n", + "4. Blog: [Tune and Deploy LoRA LLMs with NVIDIA TensorRT-LLM](https://developer.nvidia.com/blog/tune-and-deploy-lora-llms-with-nvidia-tensorrt-llm/)\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/tutorials/llm/llama-3/nemo2-sft-peft/nemo2-sft.ipynb b/tutorials/llm/llama-3/nemo2-sft-peft/nemo2-sft.ipynb new file mode 100644 index 000000000000..e84ff916fc4e --- /dev/null +++ b/tutorials/llm/llama-3/nemo2-sft-peft/nemo2-sft.ipynb @@ -0,0 +1,657 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Learning Goals\n", + "\n", + "## Optimizing Foundation Models with Supervised Fine-Tuning (SFT)\n", + "\n", + "Often, we want to adapt or customize foundation models to be more performant on our specific task. Fine-tuning refers to how we can modify the weights of a pre-trained foundation model with additional custom data. Supervised Fine-Tuning (SFT) refers to unfreezing all the weights and layers in our model and training on a newly labeled set of examples. We can fine-tune to incorporate new, domain-specific knowledge, or teach the foundation model what type of response to provide. One specific type of SFT is also referred to as “instruction tuning” where we use SFT to teach a model to follow instructions better. In this tutorial, will demonstrate how to perform SFT with Llama3-8b using NeMo 2.0.\n", + "\n", + "NeMo 2.0 introduces Python-based configurations, PyTorch Lightning’s modular abstractions, and NeMo-Run for scaling experiments across multiple GPUs. In this notebook, we will use NeMo-Run to streamline the configuration and execution of our experiments.\n", + "\n", + "## Data\n", + "Databricks-dolly-15k is an open-source dataset created by the collaborative efforts of Databricks employees. It consists of high-quality, human-generated prompt/response pairs specifically designed for instruction tuning LLMs. These pairs cover a diverse range of behaviors, from brainstorming and content generation to information extraction and summarization. \n", + "\n", + "For more information, refer to [databricks-dolly-15k | Hugging Face](https://huggingface.co/datasets/databricks/databricks-dolly-15k)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "## Step 1. Import the Hugging Face Checkpoint\n", + "We use the `llm.import_ckpt` API to download the specified model using the \"hf://\" URL format. It will then convert the model into NeMo 2.0 format. For all model supported in NeMo 2.0, refer to [Large Language Models](https://docs.nvidia.com/nemo-framework/user-guide/24.09/llms/index.html#large-language-models) section of NeMo Framework User Guide." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import nemo_run as run\n", + "from nemo import lightning as nl\n", + "from nemo.collections import llm\n", + "from megatron.core.optimizer import OptimizerConfig\n", + "import torch\n", + "import pytorch_lightning as pl\n", + "from pathlib import Path\n", + "from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed\n", + "\n", + "\n", + "# llm.import_ckpt is the nemo2 API for converting Hugging Face checkpoint to NeMo format\n", + "# example python usage:\n", + "# llm.import_ckpt(model=llm.llama3_8b.model(), source=\"hf://meta-llama/Meta-Llama-3-8B\")\n", + "#\n", + "# We use run.Partial to configure this function\n", + "def configure_checkpoint_conversion():\n", + " return run.Partial(\n", + " llm.import_ckpt,\n", + " model=llm.llama3_8b.model(),\n", + " source=\"hf://meta-llama/Meta-Llama-3-8B\",\n", + " overwrite=False,\n", + " )\n", + "\n", + "# configure your function\n", + "import_ckpt = configure_checkpoint_conversion()\n", + "# define your executor\n", + "local_executor = run.LocalExecutor()\n", + "\n", + "# run your experiment\n", + "run.run(import_ckpt, executor=local_executor)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 2. Prepare the Data and Customize the DataModule\n", + "\n", + "We will be using Databricks-dolly-15k for this notebook. NeMo 2.0 already provides a `DollyDataModule`. For all data modules that are included in NeMo 2.0, refer to the [data module directory](https://github.com/NVIDIA/NeMo/tree/main/nemo/collections/llm/gpt/data). Example usage:" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "def dolly() -> run.Config[pl.LightningDataModule]:\n", + " return run.Config(llm.DollyDataModule, seq_length=2048, micro_batch_size=1, global_batch_size=8, num_workers=0)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To use your own data, you will need to create a custom `DataModule`. This involves extending the base class `FineTuningDataModule` so that you have access to existing data handling logic, such as packed sequences. Here we walk you through the process step by step, using the already existing [`DollyDataModule`](https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/llm/gpt/data/dolly.py) as an example. \n", + "\n", + "### Subclass the FineTuningDataModule\n", + "You need to extend the `FineTuningDataModule` if you're fine-tuning NeMo models. This provides access to existing data handling logic, such as packed sequences. The `data_root` parameter is where you store your generated `train/validation/test.jsonl` in NeMo format. Below is how `DollyDataModule` does it:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from typing import List, Optional\n", + "from nemo.lightning.io.mixin import IOMixin\n", + "from nemo.collections.llm.gpt.data.fine_tuning import FineTuningDataModule\n", + "\n", + "class DollyDataModule(FineTuningDataModule, IOMixin):\n", + " def __init__(\n", + " self,\n", + " seq_length: int = 2048,\n", + " tokenizer: Optional[\"TokenizerSpec\"] = None,\n", + " micro_batch_size: int = 4,\n", + " global_batch_size: int = 8,\n", + " rampup_batch_size: Optional[List[int]] = None,\n", + " force_redownload: bool = False,\n", + " delete_raw: bool = True,\n", + " seed: int = 1234,\n", + " memmap_workers: int = 1,\n", + " num_workers: int = 8,\n", + " pin_memory: bool = True,\n", + " persistent_workers: bool = False,\n", + " pad_to_max_length: bool = False,\n", + " packed_sequence_size: int = -1,\n", + " ):\n", + " self.force_redownload = force_redownload\n", + " self.delete_raw = delete_raw\n", + "\n", + " super().__init__(\n", + " dataset_root=get_dataset_root(\"dolly\"),\n", + " seq_length=seq_length,\n", + " tokenizer=tokenizer,\n", + " micro_batch_size=micro_batch_size,\n", + " global_batch_size=global_batch_size,\n", + " rampup_batch_size=rampup_batch_size,\n", + " seed=seed,\n", + " memmap_workers=memmap_workers,\n", + " num_workers=num_workers,\n", + " pin_memory=pin_memory,\n", + " persistent_workers=persistent_workers,\n", + " pad_to_max_length=pad_to_max_length,\n", + " packed_sequence_size=packed_sequence_size,\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Override the `prepare_data` Method\n", + "\n", + "The `prepare_data` method is responsible for downloading and preprocessing data if needed. If the dataset is already downloaded, you can skip this step.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "def prepare_data(self) -> None:\n", + " # if train file is specified, no need to do anything\n", + " if not self.train_path.exists() or self.force_redownload:\n", + " dset = self._download_data()\n", + " self._preprocess_and_split_data(dset)\n", + " super().prepare_data()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Implement Data Download and Preprocessing Logic\n", + "\n", + "If your dataset requires downloading or preprocessing, implement this logic within the helper methods. Skip the download part if it's not needed." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "def _download_data(self):\n", + " logging.info(f\"Downloading {self.__class__.__name__}...\")\n", + " return load_dataset(\n", + " \"databricks/databricks-dolly-15k\",\n", + " cache_dir=str(self.dataset_root),\n", + " download_mode=\"force_redownload\" if self.force_redownload else None,\n", + " )\n", + "\n", + "def _preprocess_and_split_data(self, dset, train_ratio: float = 0.80, val_ratio: float = 0.15):\n", + " logging.info(f\"Preprocessing {self.__class__.__name__} to jsonl format and splitting...\")\n", + "\n", + " test_ratio = 1 - train_ratio - val_ratio\n", + " save_splits = {}\n", + " dataset = dset.get('train')\n", + " split_dataset = dataset.train_test_split(test_size=val_ratio + test_ratio, seed=self.seed)\n", + " split_dataset2 = split_dataset['test'].train_test_split(\n", + " test_size=test_ratio / (val_ratio + test_ratio), seed=self.seed\n", + " )\n", + " save_splits['training'] = split_dataset['train']\n", + " save_splits['validation'] = split_dataset2['train']\n", + " save_splits['test'] = split_dataset2['test']\n", + "\n", + " for split_name, dataset in save_splits.items():\n", + " output_file = self.dataset_root / f\"{split_name}.jsonl\"\n", + " with output_file.open(\"w\", encoding=\"utf-8\") as f:\n", + " for example in dataset:\n", + " context = example[\"context\"].strip()\n", + " if context != \"\":\n", + " # Randomize context and instruction order.\n", + " context_first = np.random.randint(0, 2) == 0\n", + " if context_first:\n", + " instruction = example[\"instruction\"].strip()\n", + " assert instruction != \"\"\n", + " _input = f\"{context}\\n\\n{instruction}\"\n", + " _output = example[\"response\"]\n", + " else:\n", + " instruction = example[\"instruction\"].strip()\n", + " assert instruction != \"\"\n", + " _input = f\"{instruction}\\n\\n{context}\"\n", + " _output = example[\"response\"]\n", + " else:\n", + " _input = example[\"instruction\"]\n", + " _output = example[\"response\"]\n", + "\n", + " f.write(json.dumps({\"input\": _input, \"output\": _output, \"category\": example[\"category\"]}) + \"\\n\")\n", + "\n", + " logging.info(f\"{split_name} split saved to {output_file}\")\n", + "\n", + " if self.delete_raw:\n", + " for p in self.dataset_root.iterdir():\n", + " if p.is_dir():\n", + " shutil.rmtree(p)\n", + " elif '.jsonl' not in str(p.name):\n", + " p.unlink()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The original example in Dolly dataset looks like:\n", + "```\n", + "{'instruction': 'Extract all the movies from this passage and the year they were released out. Write each movie as a separate sentence', 'context': \"The genre has existed since the early years of silent cinema, when Georges Melies' A Trip to the Moon (1902) employed trick photography effects. The next major example (first in feature length in the genre) was the film Metropolis (1927). From the 1930s to the 1950s, the genre consisted mainly of low-budget B movies. After Stanley Kubrick's landmark 2001: A Space Odyssey (1968), the science fiction film genre was taken more seriously. In the late 1970s, big-budget science fiction films filled with special effects became popular with audiences after the success of Star Wars (1977) and paved the way for the blockbuster hits of subsequent decades.\", 'response': 'A Trip to the Moon was released in 1902. Metropolis came out in 1927. 2001: A Space Odyssey was released in 1968. Star Wars came out in 1977.', 'category': 'information_extraction'}\n", + "```\n", + "After the preprocessing logic, the data examples are transformed into NeMo format, as below:\n", + "```\n", + "{'input': \"Extract all the movies from this passage and the year they were released out. Write each movie as a separate sentence\\n\\nThe genre has existed since the early years of silent cinema, when Georges Melies' A Trip to the Moon (1902) employed trick photography effects. The next major example (first in feature length in the genre) was the film Metropolis (1927). From the 1930s to the 1950s, the genre consisted mainly of low-budget B movies. After Stanley Kubrick's landmark 2001: A Space Odyssey (1968), the science fiction film genre was taken more seriously. In the late 1970s, big-budget science fiction films filled with special effects became popular with audiences after the success of Star Wars (1977) and paved the way for the blockbuster hits of subsequent decades.\", 'output': 'A Trip to the Moon was released in 1902. Metropolis came out in 1927. 2001: A Space Odyssey was released in 1968. Star Wars came out in 1977.', 'category': 'information_extraction'}\n", + "```\n", + "Each data example is saved as a json string as one line in the `train/validation/test.jsonl` file, under `data_root` directory you specified earlier." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 3.1: Configure SFT with the NeMo 2.0 API \n", + "\n", + "In this notebook we use NeMo 2.0 API to perform SFT. First we configure the following components for training. These components are similar between SFT and PEFT. SFT and PEFT both uses `llm.finetune` API. To switch from PEFT to SFT, you just need to remove the `peft` parameter.\n", + "\n", + "### Configure the Trainer\n", + "The NeMo 2.0 Trainer works similarly to the PyTorch Lightning trainer." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "def trainer() -> run.Config[nl.Trainer]:\n", + " strategy = run.Config(\n", + " nl.MegatronStrategy,\n", + " tensor_model_parallel_size=2\n", + " )\n", + " trainer = run.Config(\n", + " nl.Trainer,\n", + " devices=2,\n", + " max_steps=20,\n", + " accelerator=\"gpu\",\n", + " strategy=strategy,\n", + " plugins=bf16_mixed(),\n", + " log_every_n_steps=1,\n", + " limit_val_batches=2,\n", + " val_check_interval=2,\n", + " num_sanity_val_steps=0,\n", + " )\n", + " return trainer" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "### Configure the Logger\n", + "Configure your training steps, output directories and logging through `NeMoLogger`. In the following example, the experiment output will be saved at `./results/nemo2_sft`." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "def logger() -> run.Config[nl.NeMoLogger]:\n", + " ckpt = run.Config(\n", + " nl.ModelCheckpoint,\n", + " save_last=True,\n", + " every_n_train_steps=10,\n", + " monitor=\"reduced_train_loss\",\n", + " save_top_k=1,\n", + " save_on_train_epoch_end=True,\n", + " save_optim_on_train_end=True,\n", + " )\n", + "\n", + " return run.Config(\n", + " nl.NeMoLogger,\n", + " name=\"nemo2_sft\",\n", + " log_dir=\"./results\",\n", + " use_datetime_version=False,\n", + " ckpt=ckpt,\n", + " wandb=None\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "\n", + "### Configure the Optimizer\n", + "In the following example, we will be using the distributed adam optimizer and pass in the optimizer configuration through `OptimizerConfig`: " + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "def adam_with_cosine_annealing() -> run.Config[nl.OptimizerModule]:\n", + " opt_cfg = run.Config(\n", + " OptimizerConfig,\n", + " optimizer=\"adam\",\n", + " lr=5e-6,\n", + " adam_beta2=0.98,\n", + " use_distributed_optimizer=True,\n", + " clip_grad=1.0,\n", + " bf16=True,\n", + " )\n", + " return run.Config(\n", + " nl.MegatronOptimizerModule,\n", + " config=opt_cfg\n", + " )\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Configure the Base Model\n", + "We will perform SFT on top of Llama3-8B, so we create a `LlamaModel` to pass to the finetune API." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "def llama3_8b() -> run.Config[pl.LightningModule]:\n", + " return run.Config(llm.LlamaModel, config=run.Config(llm.Llama3Config8B))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Auto Resume\n", + "In NeMo 2.0, we can directly pass in the Llama3-8b Hugging Face ID to start SFT without manually converting it into the NeMo checkpoint format, as required in NeMo 1.0." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "def resume() -> run.Config[nl.AutoResume]:\n", + " return run.Config(\n", + " nl.AutoResume,\n", + " restore_config=run.Config(nl.RestoreConfig,\n", + " path=\"nemo://meta-llama/Meta-Llama-3-8B\"\n", + " ),\n", + " resume_if_exists=True,\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "### Configure NeMo 2.0 finetune API\n", + "Using all the components we created above, we can call the NeMo 2.0 finetune API. The python example usage is as below:\n", + "```\n", + "llm.finetune(\n", + " model=llama3_8b(),\n", + " data=dolly(),\n", + " trainer=trainer(),\n", + " log=logger(),\n", + " optim=adam_with_cosine_annealing(),\n", + " resume=resume(),\n", + ")\n", + "```\n", + "We configure the `llm.finetune` API as below:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def configure_finetuning_recipe():\n", + " return run.Partial(\n", + " llm.finetune,\n", + " model=llama3_8b(),\n", + " trainer=trainer(),\n", + " data=dolly(),\n", + " log=logger(),\n", + " optim=adam_with_cosine_annealing(),\n", + " resume=resume(),\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 3.2: Run SFT with NeMo 2.0 API and NeMo-Run\n", + "\n", + "We use `LocalExecutor` for executing our configured finetune function. For more details on the NeMo-Run executor, refer to [Execute NeMo Run](https://github.com/NVIDIA/NeMo-Run/blob/main/docs/source/guides/execution.md) of NeMo-Run Guides. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "def local_executor_torchrun(nodes: int = 1, devices: int = 2) -> run.LocalExecutor:\n", + " # Env vars for jobs are configured here\n", + " env_vars = {\n", + " \"TORCH_NCCL_AVOID_RECORD_STREAMS\": \"1\",\n", + " \"NCCL_NVLS_ENABLE\": \"0\",\n", + " \"NVTE_DP_AMAX_REDUCE_INTERVAL\": \"0\",\n", + " \"NVTE_ASYNC_AMAX_REDUCTION\": \"1\",\n", + " \"NVTE_FUSED_ATTN\": \"0\",\n", + " }\n", + "\n", + " executor = run.LocalExecutor(ntasks_per_node=devices, launcher=\"torchrun\", env_vars=env_vars)\n", + "\n", + " return executor\n", + "\n", + "if __name__ == '__main__':\n", + " run.run(configure_finetuning_recipe(), executor=local_executor_torchrun())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 4. Generate Results from Trained SFT Checkpoints\n", + "\n", + "We use the `llm.generate` API in NeMo 2.0 to generate results from the trained SFT checkpoint. Find your last saved checkpoint from your experiment dir: `results/nemo2_sft/checkpoints`. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "sft_ckpt_path=str(next((d for d in Path(\"./results/nemo2_sft/checkpoints/\").iterdir() if d.is_dir() and d.name.endswith(\"-last\")), None))\n", + "print(\"We will load SFT checkpoint from:\", sft_ckpt_path)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "When using `llm.generate` API, you can pass a data module such as dolly: `input_dataset=dolly()`. This will use the test set from the specified data module to generate predictions. In the following example, the generated predictions are saved to the `sft_predictions.txt` file. Note that while fine-tuning required a minimum of 2 GPUs with `tensor_model_parallel_size=2`, generating predictions only requires `tensor_model_parallel_size=1`. However, using multiple GPUs can speed up the inference." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from megatron.core.inference.common_inference_params import CommonInferenceParams\n", + "\n", + "\n", + "def trainer() -> run.Config[nl.Trainer]:\n", + " strategy = run.Config(\n", + " nl.MegatronStrategy,\n", + " tensor_model_parallel_size=1,\n", + " )\n", + " trainer = run.Config(\n", + " nl.Trainer,\n", + " accelerator=\"gpu\",\n", + " devices=1,\n", + " num_nodes=1,\n", + " strategy=strategy,\n", + " plugins=bf16_mixed(),\n", + " )\n", + " return trainer\n", + "\n", + "def configure_inference():\n", + " return run.Partial(\n", + " llm.generate,\n", + " path=str(sft_ckpt_path),\n", + " trainer=trainer(),\n", + " input_dataset=dolly(),\n", + " inference_params=CommonInferenceParams(num_tokens_to_generate=20, top_k=1),\n", + " output_path=\"sft_prediction.jsonl\",\n", + " )\n", + "\n", + "\n", + "def local_executor_torchrun(nodes: int = 1, devices: int = 1) -> run.LocalExecutor:\n", + " # Env vars for jobs are configured here\n", + " env_vars = {\n", + " \"TORCH_NCCL_AVOID_RECORD_STREAMS\": \"1\",\n", + " \"NCCL_NVLS_ENABLE\": \"0\",\n", + " \"NVTE_DP_AMAX_REDUCE_INTERVAL\": \"0\",\n", + " \"NVTE_ASYNC_AMAX_REDUCTION\": \"1\",\n", + " \"NVTE_FUSED_ATTN\": \"0\",\n", + " }\n", + "\n", + " executor = run.LocalExecutor(ntasks_per_node=devices, launcher=\"torchrun\", env_vars=env_vars)\n", + "\n", + " return executor\n", + "\n", + "if __name__ == '__main__':\n", + " run.run(configure_inference(), executor=local_executor_torchrun())\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "After the inference is complete, you will see results similar to the following:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "%%bash\n", + "head -n 3 sft_prediction.jsonl" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You should see output similar to the following:\n", + "```\n", + "{\"input\": \"What is best creator's platform\", \"category\": \"brainstorming\", \"label\": \"Youtube. Youtube should be best creator platform\", \"prediction\": \" for video content creators. YouTube is best creator's platform for video content creators.\"}\n", + "{\"input\": \"When was the last time the Raiders won the Super Bowl?\", \"category\": \"open_qa\", \"label\": \"The Raiders have won three Super Bowl championships (1977, 1981, and 1984), one American Football League (AFL) championship (1967), and four American Football Conference (AFC) titles. The most recent Super Bowl ring was won in 1984 against the Washington Redskins of the NFC.\", \"prediction\": \" 2003\"}\n", + "{\"input\": \"Muckle Water is a long, narrow fresh water loch on Ward Hill on Rousay, Orkney, Scotland. It is the biggest loch on the island and is popular for fishing. It can be reached by a track from the roadside. The Suso Burn on the north eastern shore drains the loch into the Sound of Rousay.\\n\\nWhere is Muckle Water?\", \"category\": \"closed_qa\", \"label\": \"Muckle water is located in Rousay, Orkney, Scotland.\", \"prediction\": \" Muckle Water is a long, narrow fresh water loch on Ward Hill on Rousay,\"}\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 5. Calculate Evaluation Metrics\n", + "\n", + "We can evaluate the model's predictions by calculating the Exact Match (EM) and F1 scores.\n", + "- Exact Match is a binary measure (0 or 1) checking if the model outputs match one of the\n", + "ground truth answer exactly.\n", + "- F1 score is the harmonic mean of precision and recall for the answer words.\n", + "\n", + "Below is a script that computes these metrics. The sample scores can be improved by training the model further and performing hyperparameter tuning. In this notebook, we only train for 20 steps." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "!python /opt/NeMo/scripts/metric_calculation/peft_metric_calc.py --pred_file sft_prediction.jsonl --label_field \"label\" --pred_field \"prediction\"" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/tutorials/llm/llama-3/pruning-distillation/01_data_preparation.ipynb b/tutorials/llm/llama-3/pruning-distillation/01_data_preparation.ipynb index 1f84dd2719e6..8548c0cfb1d0 100644 --- a/tutorials/llm/llama-3/pruning-distillation/01_data_preparation.ipynb +++ b/tutorials/llm/llama-3/pruning-distillation/01_data_preparation.ipynb @@ -9,7 +9,7 @@ "\n", "The dataset has to be preprocessed using the [preprocess_data_for_megatron.py](https://github.com/NVIDIA/NeMo/blob/main/scripts/nlp_language_modeling/preprocess_data_for_megatron.py) script included in the NeMo Framework. This step will also tokenize data using the `meta-llama/Meta-Llama-3.1-8B` tokenizer model to convert the data into a memory map format.\n", "\n", - "> `NOTE:` In the block of code below, pass the paths to your train, test and validation data files." + "> `NOTE:` In the block of code below, pass the paths to your train, test, and validation data files." ] }, { diff --git a/tutorials/llm/llama-3/pruning-distillation/02_teacher_finetuning.ipynb b/tutorials/llm/llama-3/pruning-distillation/02_teacher_finetuning.ipynb index 8d08793bbe9a..7d58ac4779aa 100644 --- a/tutorials/llm/llama-3/pruning-distillation/02_teacher_finetuning.ipynb +++ b/tutorials/llm/llama-3/pruning-distillation/02_teacher_finetuning.ipynb @@ -6,15 +6,15 @@ "metadata": {}, "source": [ "\n", - "### Step 2: Finetune the teacher on the dataset\n", + "### Step 2: Fine-tune the teacher on the dataset\n", "\n", - "NeMo framework includes a standard python script [megatron_gpt_pretraining.py](https://github.com/NVIDIA/NeMo/blob/main/examples/nlp/language_modeling/megatron_gpt_pretraining.py) for training a model. Once you have your model downloaded and the dataset ready, fine-tuning the teacher model with NeMo is essentially just running this script!\n", + "NeMo Framework includes a standard Python script, [megatron_gpt_pretraining.py](https://github.com/NVIDIA/NeMo/blob/main/examples/nlp/language_modeling/megatron_gpt_pretraining.py), for training a model. Once you have your model downloaded and the dataset ready, fine-tuning the teacher model with NeMo is essentially just running this script!\n", "\n", - "We finetune the unpruned model on our dataset to correct the distribution shift across the original dataset the model was trained on. Per the [blog](https://developer.nvidia.com/blog/how-to-prune-and-distill-llama-3-1-8b-to-an-nvidia-llama-3-1-minitron-4b-model/) and [tech report](https://arxiv.org/pdf/2408.11796), experiments showed that, without correcting for the distribution shift, the teacher provides suboptimal guidance on the dataset when being distilled.\n", + "We fine-tune the unpruned model on our dataset to correct the distribution shift from the original dataset the model was trained on. According to the [blog](https://developer.nvidia.com/blog/how-to-prune-and-distill-llama-3-1-8b-to-an-nvidia-llama-3-1-minitron-4b-model/) and [tech report](https://arxiv.org/pdf/2408.11796), experiments showed that without correcting for this distribution shift, the teacher provides suboptimal guidance on the dataset during distillation.\n", "\n", "For this demonstration, this training run is capped by `STEPS`, and validation is carried out every `VAL_INTERVAL` steps.\n", "\n", - "> `NOTE:` In the block of code below, pass the paths to your pre-processed train, test and validation data files as well as path to the teacher .nemo model." + "> `NOTE:` In the block of code below, pass the paths to your pre-processed train, test, and validation data files, as well as the path to the teacher .nemo model." ] }, { @@ -124,8 +124,8 @@ "id": "3040a993-8423-475f-8bc6-d1dd1ce16a83", "metadata": {}, "source": [ - "This will create a finetuned teacher model named `megatron_llama_ft.nemo` in `./distill_trainings/megatron_llama_ft/checkpoints/`. We'll use this later.\n", - "> `NOTE:`This script takes at least 20 minutes to run (depending on GPU) and will generate the finetuned teacher model." + "This will create a fine-tuned teacher model named `megatron_llama_ft.nemo` in `./distill_trainings/megatron_llama_ft/checkpoints/`. We'll use this later.\n", + "> `NOTE:`This script takes at least 20 minutes to run (depending on GPU) and will generate the fine-tuned teacher model." ] } ], diff --git a/tutorials/llm/llama-3/pruning-distillation/03_a_depth_pruning.ipynb b/tutorials/llm/llama-3/pruning-distillation/03_a_depth_pruning.ipynb index a195c2f3a405..d64f8c15bd00 100644 --- a/tutorials/llm/llama-3/pruning-distillation/03_a_depth_pruning.ipynb +++ b/tutorials/llm/llama-3/pruning-distillation/03_a_depth_pruning.ipynb @@ -5,8 +5,8 @@ "id": "8bc99d2f-9ac6-40c2-b072-12b6cb7b9aca", "metadata": {}, "source": [ - "### Step 3: Prune the finetuned-teacher model to create a student\n", - "In this step, we will explore two methods to prune the finetuned teacher model. Refer to the ``NOTE`` in the **_step-by-step instructions_** section of [introduction.ipynb](./introduction.ipynb) to decide which pruning techniques you would like to explore.\n", + "### Step 3: Prune the fine-tuned teacher model to create a student\n", + "In this step, we will explore two methods to prune the fine-tuned teacher model. Refer to the ``NOTE`` in the **_step-by-step instructions_** section of [introduction.ipynb](./introduction.ipynb) to decide which pruning techniques you would like to explore.\n", "\n", "In the first method, depth-pruning, we trim the layers of the model." ] @@ -21,7 +21,7 @@ "\n", "Per the [blog](https://developer.nvidia.com/blog/how-to-prune-and-distill-llama-3-1-8b-to-an-nvidia-llama-3-1-minitron-4b-model/) and [tech report](https://arxiv.org/pdf/2408.11796), removing contiguous layers from the second last block (layers 16 to 31 continuously) yields the best overall results. \n", "\n", - "> `NOTE:` In the block of code below, pass the paths to your finetuned teacher .nemo model." + "> `NOTE:` In the block of code below, pass the paths to your fine-tuned teacher .nemo model." ] }, { diff --git a/tutorials/llm/llama-3/pruning-distillation/03_b_width_pruning.ipynb b/tutorials/llm/llama-3/pruning-distillation/03_b_width_pruning.ipynb index 7d91d36cbb32..b4e323463078 100644 --- a/tutorials/llm/llama-3/pruning-distillation/03_b_width_pruning.ipynb +++ b/tutorials/llm/llama-3/pruning-distillation/03_b_width_pruning.ipynb @@ -5,8 +5,8 @@ "id": "8bc99d2f-9ac6-40c2-b072-12b6cb7b9aca", "metadata": {}, "source": [ - "### Step 3: Prune the finetuned-teacher model to create a student\n", - "In the second method, we will width-prune. In width-pruning, we trim the neurons, attention heads and embedding channels. \n", + "### Step 3: Prune the fine-tuned teacher model to create a student\n", + "In the second method, we will width-prune. In width-pruning, we trim the neurons, attention heads, and embedding channels.\n", "\n", "Refer to the ``NOTE`` in the **_step-by-step instructions_** section of [introduction.ipynb](./introduction.ipynb) to decide which pruning techniques you would like to explore." ] @@ -20,15 +20,15 @@ "source": [ "#### Step 3.b.: Using width-pruning\n", "To width-prune the model, we do the following:\n", - "- prune (trim) the MLP intermediate dimension from 14336 to 9216.\n", - "- prune the hidden size from 4096 to 3072.\n", - "- and retrain the attention headcount and number of layers\n", + "- Prune (trim) the MLP intermediate dimension from 14336 to 9216.\n", + "- Prune the hidden size from 4096 to 3072.\n", + "- Retrain the attention headcount and number of layers\n", "\n", - "For width-pruning we will use the [megatron_gpt_prune.py](https://github.com/NVIDIA/NeMo/blob/main/examples/nlp/language_modeling/megatron_gpt_prune.py) script in the NeMo Framework. To see the detailed list of parameters for width-pruning, you can view the [megatron_gpt_prune.yaml](https://github.com/NVIDIA/NeMo/blob/main/examples/nlp/language_modeling/conf/megatron_gpt_prune.yaml) file.\n", + "For width-pruning, we will use the [megatron_gpt_prune.py](https://github.com/NVIDIA/NeMo/blob/main/examples/nlp/language_modeling/megatron_gpt_prune.py) script in the NeMo Framework. To see the detailed list of parameters for width-pruning, you can view the [megatron_gpt_prune.yaml](https://github.com/NVIDIA/NeMo/blob/main/examples/nlp/language_modeling/conf/megatron_gpt_prune.yaml) file.\n", "\n", "We use the above parameters to get a competitive model for this demonstration. You can use other strategies or parameters from the [blog](https://developer.nvidia.com/blog/how-to-prune-and-distill-llama-3-1-8b-to-an-nvidia-llama-3-1-minitron-4b-model/) or the [tech report](https://arxiv.org/pdf/2408.11796) for your experiments. \n", "\n", - "> `NOTE:` In the block of code below, pass the paths to your finetuned teacher .nemo model.\n", + "> `NOTE:` In the block of code below, pass the paths to your fine-tuned teacher .nemo model.\n", "\n", "> `TIP:` You can increase the ``batch_size`` (upto 1024) to speed up the width-pruning script execution." ] diff --git a/tutorials/llm/llama-3/pruning-distillation/04_a_distilling_depth_pruned_student.ipynb b/tutorials/llm/llama-3/pruning-distillation/04_a_distilling_depth_pruned_student.ipynb index ccbe1cbf394b..488225837731 100644 --- a/tutorials/llm/llama-3/pruning-distillation/04_a_distilling_depth_pruned_student.ipynb +++ b/tutorials/llm/llama-3/pruning-distillation/04_a_distilling_depth_pruned_student.ipynb @@ -6,9 +6,9 @@ "metadata": {}, "source": [ "### Step 4: Distill knowledge from teacher into student\n", - "Distillation of a model with NeMo Framework is also possible using a python script: [megatron_gpt_distillation.py](https://github.com/NVIDIA/NeMo/blob/main/examples/nlp/language_modeling/megatron_gpt_distillation.py). In this notebook, we will explore distillation with the depth-pruned model as the `STUDENT` model. \n", + "Distillation of a model with NeMo Framework is also possible using a Python script: [megatron_gpt_distillation.py](https://github.com/NVIDIA/NeMo/blob/main/examples/nlp/language_modeling/megatron_gpt_distillation.py). In this notebook, we will explore distillation with the depth-pruned model as the `STUDENT` model.\n", "\n", - "For this demonstration, the `TEACHER` would be the finetuned teacher model `megatron_llama_ft.nemo` and the `STUDENT` model would be the pruned 4B model. This training run is capped by `STEPS`, and validation is carried out every `VAL_INTERVAL` steps." + "For this demonstration, the `TEACHER` would be the fine-tuned teacher model `megatron_llama_ft.nemo` and the `STUDENT` model would be the pruned 4B model. This training run is capped by `STEPS`, and validation is carried out every `VAL_INTERVAL` steps." ] }, { @@ -19,7 +19,7 @@ "#### Step 4.a.: Using depth-pruned student\n", "While distilling knowledge from the teacher to depth-pruned model, the `STUDENT` model would be `4b_depth_pruned_model.nemo` as produced by the [depth-pruning](./03_a_depth_pruning.ipynb) notebook. This training run is capped by `STEPS`, and validation is carried out every `VAL_INTERVAL` steps.\n", "\n", - "> `NOTE:` In the block of code below, pass the paths to your pre-processed train, test and validation data files as well as path to the teacher and student .nemo models." + "> `NOTE:` In the block of code below, pass the paths to your pre-processed train, test, and validation data files, as well as path to the teacher and student .nemo models." ] }, { diff --git a/tutorials/llm/llama-3/pruning-distillation/04_b_distilling_width_pruned_student.ipynb b/tutorials/llm/llama-3/pruning-distillation/04_b_distilling_width_pruned_student.ipynb index 48e81c96cdcf..95110dd19dd9 100644 --- a/tutorials/llm/llama-3/pruning-distillation/04_b_distilling_width_pruned_student.ipynb +++ b/tutorials/llm/llama-3/pruning-distillation/04_b_distilling_width_pruned_student.ipynb @@ -6,10 +6,10 @@ "metadata": {}, "source": [ "### Step 4: Distill knowledge from teacher into student\n", - "Distillation of a model with NeMo Framework is also possible using a python script: [megatron_gpt_distillation.py](https://github.com/NVIDIA/NeMo/blob/main/examples/nlp/language_modeling/megatron_gpt_distillation.py). \n", + "Distillation of a model with NeMo Framework is also possible using a Python script: [megatron_gpt_distillation.py](https://github.com/NVIDIA/NeMo/blob/main/examples/nlp/language_modeling/megatron_gpt_distillation.py). \n", "In this notebook, we will explore distillation with the width-pruned model as the `STUDENT` model.\n", "\n", - "For this demonstration, the `TEACHER` would be the finetuned teacher model `megatron_llama_ft.nemo` and the `STUDENT` model would be the pruned 4B model. This training run is capped by `STEPS`, and validation is carried out every `VAL_INTERVAL` steps." + "For this demonstration, the `TEACHER` would be the fine-tuned teacher model `megatron_llama_ft.nemo` and the `STUDENT` model would be the pruned 4B model. This training run is capped by `STEPS`, and validation is carried out every `VAL_INTERVAL` steps." ] }, { @@ -20,7 +20,7 @@ "#### Step 4.b.: Using width-pruned student\n", "While distilling knowledge from the teacher to width-pruned model, the `STUDENT` model would be `4b_width_pruned_model.nemo` as produced by the [width-pruning](./03_b_width_pruning.ipynb) notebook. This training run is capped by `STEPS`, and validation is carried out every `VAL_INTERVAL` steps.\n", "\n", - "> `NOTE:` In the block of code below, pass the paths to your pre-processed train, test and validation data files as well as path to the teacher and student .nemo models." + "> `NOTE:` In the block of code below, pass the paths to your pre-processed train, test, and validation data files, as well as path to the teacher and student .nemo models." ] }, { diff --git a/tutorials/llm/llama-3/pruning-distillation/05_display_results.ipynb b/tutorials/llm/llama-3/pruning-distillation/05_display_results.ipynb index 0264cc288957..dcb483c55ab6 100644 --- a/tutorials/llm/llama-3/pruning-distillation/05_display_results.ipynb +++ b/tutorials/llm/llama-3/pruning-distillation/05_display_results.ipynb @@ -8,7 +8,8 @@ "### Step 5: Display the validation loss\n", "\n", "Now that the results are in, let's visualize the validation loss of the two distilled models using the `tensorboard` library. \n", - "> `NOTE:` This notebook demonstrates the use of the teacher finetuning, pruning and the distillation script. These scripts should ideally be run on a multi-node cluster with a larger `GLOBAL_BATCH_SIZE` and `STEPS` to see improvement in the validation loss." + "\n", + "> `NOTE:` This notebook demonstrates the use of the teacher fine-tuning, pruning, and the distillation script. These scripts should ideally be run on a multi-node cluster with a larger `GLOBAL_BATCH_SIZE` and `STEPS` to see improvement in the validation loss." ] }, { @@ -16,8 +17,8 @@ "id": "b5822d62-8131-4046-8c22-0bf0fce81df7", "metadata": {}, "source": [ - "#### Validation Loss using depth-pruned model as student in distillation script\n", - "Here is an image of the validation loss over 30 steps of running the training step in the distillation script when we distill the knowledge from the finetuned teacher model to the depth-pruned student." + "#### Validation Loss Using Depth-Pruned Model as Student in Distillation Script\n", + "Here is an image of the validation loss over 30 steps of running the training step in the distillation script, where we distill the knowledge from the fine-tuned teacher model to the depth-pruned student." ] }, { @@ -35,7 +36,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 1, "id": "db6fcf26-8ae8-40e1-875a-0a10bf85be81", "metadata": { "tags": [] @@ -44,7 +45,7 @@ { "data": { "text/html": [ - "
Validation Loss over 30 Training Steps with Depth-Pruned model as Student
" + "
Validation Loss over 30 Training Steps with Depth-Pruned Model as Student
" ], "text/plain": [ "" @@ -68,7 +69,7 @@ ], "source": [ "from IPython.display import Image, display, HTML\n", - "title = \"Validation Loss over 30 Training Steps with Depth-Pruned model as Student\"\n", + "title = \"Validation Loss over 30 Training Steps with Depth-Pruned Model as Student\"\n", "display(HTML(f\"
{title}
\"))\n", "display(Image(url=\"https://github.com/NVIDIA/NeMo/releases/download/r2.0.0rc1/val_loss_depth_pruned_student_distillation.png\", width=400))" ] @@ -78,8 +79,8 @@ "id": "f10041ae-6533-47de-9f76-f97d4469c27a", "metadata": {}, "source": [ - "#### Validation Loss using width-pruned model as student in distillation script\n", - "Here is an image of the validation loss over 30 steps of running the training step in the distillation script when we distill the knowledge from the finetuned teacher model to the width-pruned student." + "#### Validation Loss Using Width-Pruned Model as Student in Distillation Script\n", + "Here is an image of the validation loss over 30 steps of running the training step in the distillation script, where we distill the knowledge from the fine-tuned teacher model to the width-pruned student." ] }, { @@ -97,7 +98,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 2, "id": "ecd79583-f662-40c6-a690-9f4bb847de4e", "metadata": { "tags": [] @@ -106,7 +107,7 @@ { "data": { "text/html": [ - "
Validation Loss over 30 Training Steps with Width-Pruned model as Student
" + "
Validation Loss over 30 Training Steps with Width-Pruned Model as Student
" ], "text/plain": [ "" @@ -130,18 +131,10 @@ ], "source": [ "from IPython.display import Image, display, HTML\n", - "title = \"Validation Loss over 30 Training Steps with Width-Pruned model as Student\"\n", + "title = \"Validation Loss over 30 Training Steps with Width-Pruned Model as Student\"\n", "display(HTML(f\"
{title}
\"))\n", "display(Image(url=\"https://github.com/NVIDIA/NeMo/releases/download/r2.0.0rc1/val_loss_width_pruned_student_distillation.png\", width=400))" ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7ab6ed6f-8bc3-4188-919f-7cee842635ed", - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { diff --git a/tutorials/llm/llama-3/pruning-distillation/README.rst b/tutorials/llm/llama-3/pruning-distillation/README.rst index 34febcffa366..88086bb37ea4 100644 --- a/tutorials/llm/llama-3/pruning-distillation/README.rst +++ b/tutorials/llm/llama-3/pruning-distillation/README.rst @@ -1,13 +1,13 @@ Llama 3.1 Pruning and Distillation with NeMo Framework ======================================================================================= -`Llama 3.1 `_ are open-source large language models by Meta that deliver state-of-the-art performance on popular industry benchmarks. They have been pretrained on over 15 trillion tokens, and support a 128K token context length. They are available in three sizes, 8B, 70B, and 405B, and each size has two variants—base pretrained and instruction tuned. +`Llama 3.1 `_ models, developed by Meta, are open-source large language models that deliver state-of-the-art performance on popular industry benchmarks. Pretrained on over 15 trillion tokens, they support a 128K token context length. These models are available in three sizes: 8B, 70B, and 405B. Each size offers two variants: base pretrained and instruction tuned. -`NVIDIA NeMo Framework `_ provides tools to perform teacher finetuning, pruning and distillation on Llama 3.1 to fit your use case. +`NVIDIA NeMo Framework `_ provides tools to perform teacher fine-tuning, pruning, and distillation on Llama 3.1 to fit your use case. `NVIDIA TensorRT Model Optimizer `_ is a library (referred to as **Model Optimizer**, or **ModelOpt**) comprising state-of-the-art model optimization techniques including `quantization `_, `sparsity `_, `distillation `_, and `pruning `_ to compress models. -`LLM Pruning and Distillation in Practice: The Minitron Approach `_ provides tools to perform teacher finetuning, pruning and distillation on Llama 3.1 as described in the `tech report `_. +`LLM Pruning and Distillation in Practice: The Minitron Approach `_ provides tools to perform teacher fine-tuning, pruning, and distillation on Llama 3.1 as described in the `tech report `_. `How to Prune and Distill Llama-3.1 8B to an NVIDIA Llama-3.1-Minitron 4B Model `_ provides practical and effective structured compression best practices for LLMs that combine depth, width, attention, and MLP pruning with knowledge distillation-based retraining. These strategies are presented in the `Compact Language Models via Pruning and Knowledge Distillation `_ paper. @@ -16,30 +16,33 @@ Llama 3.1 Pruning and Distillation with NeMo Framework Objectives ---------- -This tutorial shows how to perform depth-pruning, teacher finetuning and distillation on **Llama 3.1 8B** using the `WikiText-103-v1 `_ dataset with NeMo Framework. The `WikiText-103-v1 `_ language modeling dataset is a collection of over 100 million tokens extracted from the set of verified Good and Featured articles on Wikipedia. For this demonstration, we will perform teacher correction by running a light finetuning procedure on the ``Meta Llama 3.1 8B`` teacher model to generate a finetuned teacher model ``megatron_llama_ft.nemo`` needed for optimal distillation. This finetuned teacher model is then trimmed. There are two methods to prune a model: depth-pruning and width-pruning. We will be exploring both pruning techniques which will yield ``4b_depth_pruned_model.nemo`` and ``4b_width_pruned_model.nemo`` respectively. These models will serve as a starting point for distillation to create the final distilled 4B models. +This tutorial demonstrates how to perform depth-pruning, width-pruning, teacher fine-tuning, and distillation on **Llama 3.1 8B** using the `WikiText-103-v1 `_ dataset with the NeMo Framework. The `WikiText-103-v1 `_ language modeling dataset comprises over 100 million tokens extracted from verified Good and Featured articles on Wikipedia. + +For this demonstration, we will perform teacher correction by running a light fine-tuning procedure on the ``Meta LLama 3.1 8B`` teacher model to generate a fine-tuned teacher model, ``megatron_llama_ft.nemo``, needed for optimal distillation. This fine-tuned teacher model is then trimmed. There are two methods to prune a model: depth-pruning and width-pruning. We will explore both techniques, yielding ``4b_depth_pruned_model.nemo`` and ``4b_width_pruned_model.nemo``, respectively. These models will serve as starting points for distillation to create the final distilled 4B models. + We are using models utilizing the ``meta-llama/Meta-Llama-3.1-8B`` tokenizer for this demonstration. -``NOTE:`` A subset of functions is being demonstrated in the notebooks. Some features like Neural Architecture Search (NAS) are unavailable but will be supported in future releases. +``NOTE:`` A subset of functions is being demonstrated in the notebooks. Some features like Neural Architecture Search (NAS) are unavailable, but will be supported in future releases. Requirements ------------- * System Configuration - * Access to at least 8 NVIDIA GPU with an individual memory of at least 80GB, for example: 8 x H100-80GB or 8 x A100-80GB. + * Access to at least 8 NVIDIA GPUs, each with a memory of at least 80GB (e.g., 8 x H100-80GB or 8 x A100-80GB). * A Docker-enabled environment, with `NVIDIA Container Runtime `_ installed, which will make the container GPU-aware. -* `Authenticate with NVIDIA NGC `_, and download `NGC CLI Tool `_. You will use this tool to download the model and customize it with NeMo Framework. +* `Authenticate with NVIDIA NGC `_ and download `NGC CLI Tool `_. You will use this tool to download the model and customize it with NeMo Framework. * Get your Hugging Face `access token `_, which will be used to obtain the tokenizer required during training. -``NOTE:`` The default configuration in the notebook runs on 8 x 80GB NVIDIA GPUs but you can potentially reduce Tensor Parallel size ``(TENSOR_PARALLEL_SIZE)`` along with the Micro-Batchsize ``(MICRO_BATCH_SIZE)`` in the teacher finetuning and distillation scripts to accommodate lower resource availability. +``NOTE:`` The default configuration in the notebook runs on 8 x 80GB NVIDIA GPUs. However, you can potentially reduce the Tensor Parallel size ``(TENSOR_PARALLEL_SIZE)`` along with the Micro-Batchsize ``(MICRO_BATCH_SIZE)`` in the teacher fine-tuning and distillation scripts to accommodate lower resource availability. -Create a pruned and distilled model with NeMo Framework +Create a Pruned and Distilled Model with NeMo Framework ------------------------------------------------------------------------------ -For pruning and distilling the model, you will use the NeMo Framework which is available as a `docker container `_. +For pruning and distilling the model, you will use the NeMo Framework, which is available as a `Docker container `_. -``NOTE:`` These notebooks use `NVIDIA TensorRT Model Optimizer `_ under the hood for pruning and distillation. +``NOTE:`` These notebooks use the `NVIDIA TensorRT Model Optimizer `_ under the hood for pruning and distillation. 1. Download the `Llama 3.1 8B .nemo `_ from NVIDIA NGC using the `NGC CLI `_. Generate the ``NGC_API_KEY`` following these `instructions `_. The following command saves the ``.nemo`` format model in a folder named ``llama-3_1-8b-nemo_v1.0`` in the current directory. You can specify another path using the ``-d`` option in the CLI tool. @@ -75,7 +78,7 @@ For pruning and distilling the model, you will use the NeMo Framework which is a 4. Then, navigate to `this notebook <./introduction.ipynb>`_ to get started. -This directory contains a list of notebooks which will go over all the steps to create a distilled 4B model. +This directory contains a list of notebooks that cover all the steps to create a distilled 4B model. :: @@ -91,7 +94,7 @@ This directory contains a list of notebooks which will go over all the steps to Results ------------------------------------------------------------------------------ -``NOTE:`` This notebook demonstrates the use of the teacher finetuning, pruning and the distillation scripts. These scripts should ideally be run on a multi-node cluster with a larger ``GLOBAL_BATCH_SIZE`` and ``STEPS`` to see improvement in the validation loss. +``NOTE:`` This notebook demonstrates the use of the teacher fine-tuning, pruning, and the distillation scripts. These scripts should ideally be run on a multi-node cluster with a larger ``GLOBAL_BATCH_SIZE`` and ``STEPS`` to see improvement in the validation loss. Here are the validation loss plots over 30 steps of running the training step in the distillation script (at the end of the `notebook <./05_display_results.ipynb>`_). @@ -100,11 +103,11 @@ Here are the validation loss plots over 30 steps of running the training step in :alt: Diagram showing the validation loss over 30 steps of running the training step in the distillation script when using the depth-pruned model as the student :align: center - Figure 1: Validation Loss Plot when using the depth-pruned model as the student + Figure 1: Validation Loss Plot When Using the Depth-Pruned Model as the Student .. figure:: https://github.com/NVIDIA/NeMo/releases/download/r2.0.0rc1/val_loss_width_pruned_student_distillation.png :width: 400px :alt: Diagram showing the validation loss over 30 steps of running the training step in the distillation script when using the width-pruned model as the student :align: center - Figure 2: Validation Loss Plot when using the width-pruned model as the student \ No newline at end of file + Figure 2: Validation Loss Plot When Using the Width-Pruned Model as the Student \ No newline at end of file diff --git a/tutorials/llm/llama-3/pruning-distillation/introduction.ipynb b/tutorials/llm/llama-3/pruning-distillation/introduction.ipynb index 1a3efc9f5f1e..71a5a6cfb03c 100644 --- a/tutorials/llm/llama-3/pruning-distillation/introduction.ipynb +++ b/tutorials/llm/llama-3/pruning-distillation/introduction.ipynb @@ -7,7 +7,7 @@ "tags": [] }, "source": [ - "# Pruning and Distillation of Llama 3.1 model with NeMo Framework" + "# Efficient Model Reduction with Pruning and Distillation of Llama 3.1 Using NeMo Framework" ] }, { @@ -15,15 +15,15 @@ "id": "03fd1cf4-c67a-4b8d-a5e5-46531be0f991", "metadata": {}, "source": [ - "This demonstration showcases performing pruning and distillation on **Llama 3.1-8B** with the [WikiText-103-v1](https://huggingface.co/datasets/Salesforce/wikitext/viewer/wikitext-103-v1) dataset using NeMo Framework. The [WikiText-103-v1](https://huggingface.co/datasets/Salesforce/wikitext/viewer/wikitext-103-v1) language modeling dataset is a collection of over 100 million tokens extracted from the set of verified 'Good' and 'Featured' articles on Wikipedia. \n", + "This tutorial demonstrates how to perform depth-pruning, teacher fine-tuning, and distillation on **Llama 3.1-8B** using the [WikiText-103-v1](https://huggingface.co/datasets/Salesforce/wikitext/viewer/wikitext-103-v1) dataset with NeMo Framework. The [WikiText-103-v1](https://huggingface.co/datasets/Salesforce/wikitext/viewer/wikitext-103-v1) language modeling dataset comprises over 100 million tokens extracted from verified Good and Featured articles on Wikipedia.\n", "\n", - "For this demonstration, we will perform a light finetuning procedure on the `Meta Llama 3.1 8B` teacher model to generate a finetuned teacher model. This finetuned teacher model will then be trimmed. There are two methods to prune a model: depth-pruning and width-pruning. This workflow will showcase both methods which will yield `4b_depth_pruned_model.nemo` and `4b_width_pruned_model.nemo` respectively, that will serve as a starting point for distillation to the final 4B models. \n", + "For this demonstration, we will perform teacher correction by running a light fine-tuning procedure on the `Meta Llama 3.1 8B` teacher model to generate a fine-tuned teacher model, `megatron_llama_ft.nemo`, needed for optimal distillation. This fine-tuned teacher model is then trimmed. There are two methods to prune a model: depth-pruning and width-pruning. We will explore both techniques, yielding `4b_depth_pruned_model.nemo` and `4b_width_pruned_model.nemo`, respectively. These models will serve as starting points for distillation to create the final distilled 4B models.\n", "\n", "> We are using models utilizing the `meta-llama/Meta-Llama-3.1-8B` tokenizer for this demonstration.\n", "\n", "> `NOTE:` Ensure that you run this notebook inside the [NeMo Framework container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/nemo) which has all the required dependencies. \n", "\n", - "**Instructions are available in the associated tutorial README to download the model and the container.**" + "**Instructions for downloading the model and the container are available in the [README](./README.rst).**" ] }, { @@ -49,8 +49,8 @@ "source": [ "---\n", "## Prerequisites\n", - "Ensure you have the following -\n", - "1. **Get the teacher model**: Download the `Meta Llama 3.1 8B .nemo` model. You must follow the instructions in the associated README to download and mount the folder to the NeMo FW container." + "Ensure you meet the prerequisites listed in this section.\n", + "1. **Get the teacher model**: Download the `Meta Llama 3.1 8B .nemo` model. You must follow the instructions in the associated README to download and mount the folder to the NeMo Framework container." ] }, { @@ -149,12 +149,12 @@ }, "source": [ "---\n", - "## Step-by-step instructions\n", + "## Step-by-Step Instructions\n", "\n", "This workflow is structured into seven notebooks:\n", "1. [Prepare the dataset](./01_data_preparation.ipynb)\n", - "2. [Finetune the teacher on the dataset](./02_teacher_finetuning.ipynb)\n", - "3. Prune the finetuned-teacher model to create a student \n", + "2. [Fine-tune the teacher on the dataset](./02_teacher_finetuning.ipynb)\n", + "3. Prune the fine-tuned teacher model to create a student \n", " - 3.a. [Using depth-pruning](./03_a_depth_pruning.ipynb)\n", " - 3.b. [Using width-pruning](./03_b_width_pruning.ipynb)\n", "4. Distill knowledge from teacher into student\n", @@ -162,7 +162,7 @@ " - 4.b. [Using width-pruned student](./04_b_distilling_width_pruned_student.ipynb)\n", "5. [Display the validation loss](./05_display_results.ipynb)\n", "\n", - "> `NOTE:` We are exploring two methods to prune the finetuned teacher model: [depth-pruning](./03_a_depth_pruning.ipynb) and [width-pruning](./03_b_width_pruning.ipynb). Per the [tech report](https://arxiv.org/pdf/2408.11796), we can observe that width-pruning generally outperforms depth-pruning so users can choose to perform either [depth-pruning](./03_a_depth_pruning.ipynb) or [width-pruning](./03_b_width_pruning.ipynb) or both methods." + "> `NOTE:` We are exploring two methods to prune the fine-tuned teacher model: [depth-pruning](./03_a_depth_pruning.ipynb) and [width-pruning](./03_b_width_pruning.ipynb). Per the [tech report](https://arxiv.org/pdf/2408.11796), we can observe that width-pruning generally outperforms depth-pruning so users can choose to perform either [depth-pruning](./03_a_depth_pruning.ipynb) or [width-pruning](./03_b_width_pruning.ipynb) or both methods." ] } ], diff --git a/tutorials/nlp/ITN_with_Thutmose_Tagger.ipynb b/tutorials/nlp/ITN_with_Thutmose_Tagger.ipynb index 6204bf2516bb..b028b2d5c190 100644 --- a/tutorials/nlp/ITN_with_Thutmose_Tagger.ipynb +++ b/tutorials/nlp/ITN_with_Thutmose_Tagger.ipynb @@ -249,7 +249,7 @@ "\n", "import wget \n", "import torch\n", - "import pytorch_lightning as pl\n", + "import lightning.pytorch as pl\n", "from omegaconf import OmegaConf\n", "import pandas as pd" ] diff --git a/tutorials/nlp/Punctuation_and_Capitalization.ipynb b/tutorials/nlp/Punctuation_and_Capitalization.ipynb index f88c33fada34..cbdab3941b6f 100644 --- a/tutorials/nlp/Punctuation_and_Capitalization.ipynb +++ b/tutorials/nlp/Punctuation_and_Capitalization.ipynb @@ -72,7 +72,7 @@ "import os\n", "import wget \n", "import torch\n", - "import pytorch_lightning as pl\n", + "import lightning.pytorch as pl\n", "from omegaconf import OmegaConf" ] }, diff --git a/tutorials/nlp/Punctuation_and_Capitalization_Lexical_Audio.ipynb b/tutorials/nlp/Punctuation_and_Capitalization_Lexical_Audio.ipynb index 2afbb19c0e66..51d3a66c91fc 100644 --- a/tutorials/nlp/Punctuation_and_Capitalization_Lexical_Audio.ipynb +++ b/tutorials/nlp/Punctuation_and_Capitalization_Lexical_Audio.ipynb @@ -74,7 +74,7 @@ "import os\n", "import wget\n", "import torch\n", - "import pytorch_lightning as pl\n", + "import lightning.pytorch as pl\n", "from omegaconf import OmegaConf" ] }, diff --git a/tutorials/nlp/Relation_Extraction-BioMegatron.ipynb b/tutorials/nlp/Relation_Extraction-BioMegatron.ipynb index d6b1e98b428e..3c9e427e7e09 100644 --- a/tutorials/nlp/Relation_Extraction-BioMegatron.ipynb +++ b/tutorials/nlp/Relation_Extraction-BioMegatron.ipynb @@ -71,7 +71,7 @@ "import os\n", "import wget\n", "import torch\n", - "import pytorch_lightning as pl\n", + "import lightning.pytorch as pl\n", "from omegaconf import OmegaConf" ] }, diff --git a/tutorials/nlp/Text_Classification_Sentiment_Analysis.ipynb b/tutorials/nlp/Text_Classification_Sentiment_Analysis.ipynb index fdcff979ea46..0ed846881d02 100644 --- a/tutorials/nlp/Text_Classification_Sentiment_Analysis.ipynb +++ b/tutorials/nlp/Text_Classification_Sentiment_Analysis.ipynb @@ -58,7 +58,7 @@ "import os\n", "import wget \n", "import torch\n", - "import pytorch_lightning as pl\n", + "import lightning.pytorch as pl\n", "from omegaconf import OmegaConf" ] }, diff --git a/tutorials/nlp/Token_Classification-BioMegatron.ipynb b/tutorials/nlp/Token_Classification-BioMegatron.ipynb index 85cb769b28c0..a59eae67dde1 100644 --- a/tutorials/nlp/Token_Classification-BioMegatron.ipynb +++ b/tutorials/nlp/Token_Classification-BioMegatron.ipynb @@ -45,7 +45,7 @@ "import os\n", "import wget \n", "import torch\n", - "import pytorch_lightning as pl\n", + "import lightning.pytorch as pl\n", "from omegaconf import OmegaConf" ] }, diff --git a/tutorials/nlp/Token_Classification_Named_Entity_Recognition.ipynb b/tutorials/nlp/Token_Classification_Named_Entity_Recognition.ipynb index 3ab98f6c19fd..4c34c293dcca 100644 --- a/tutorials/nlp/Token_Classification_Named_Entity_Recognition.ipynb +++ b/tutorials/nlp/Token_Classification_Named_Entity_Recognition.ipynb @@ -94,7 +94,7 @@ "import os\n", "import wget \n", "import torch\n", - "import pytorch_lightning as pl\n", + "import lightning.pytorch as pl\n", "from omegaconf import OmegaConf" ], "execution_count": null, diff --git a/tutorials/nlp/Zero_Shot_Intent_Recognition.ipynb b/tutorials/nlp/Zero_Shot_Intent_Recognition.ipynb index 7f1baf536d87..b1eca63b8fd1 100644 --- a/tutorials/nlp/Zero_Shot_Intent_Recognition.ipynb +++ b/tutorials/nlp/Zero_Shot_Intent_Recognition.ipynb @@ -66,7 +66,7 @@ "from nemo.utils import logging\n", "from omegaconf import OmegaConf\n", "import pandas as pd\n", - "import pytorch_lightning as pl\n", + "import lightning.pytorch as pl\n", "import torch\n", "import wget " ] diff --git a/tutorials/nlp/lora.ipynb b/tutorials/nlp/lora.ipynb index c67fa6c2de15..0429dd7f053c 100644 --- a/tutorials/nlp/lora.ipynb +++ b/tutorials/nlp/lora.ipynb @@ -422,7 +422,7 @@ "source": [ "from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy\n", "import torch\n", - "import pytorch_lightning as pl\n", + "import lightning.pytorch as pl\n", "from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronTrainerBuilder\n", "\n", "# let's modify some trainer configs\n", diff --git a/tutorials/speaker_tasks/Speaker_Diarization_Training.ipynb b/tutorials/speaker_tasks/Speaker_Diarization_Training.ipynb index 7db905b6d225..c193e6600666 100644 --- a/tutorials/speaker_tasks/Speaker_Diarization_Training.ipynb +++ b/tutorials/speaker_tasks/Speaker_Diarization_Training.ipynb @@ -777,7 +777,7 @@ "metadata": {}, "outputs": [], "source": [ - "import pytorch_lightning as pl\n", + "import lightning.pytorch as pl\n", "from nemo.collections.asr.models import EncDecDiarLabelModel\n", "from nemo.utils.exp_manager import exp_manager\n", "\n", diff --git a/tutorials/speaker_tasks/Speaker_Identification_Verification.ipynb b/tutorials/speaker_tasks/Speaker_Identification_Verification.ipynb index 27a01b894eae..c4f7fbaca67e 100644 --- a/tutorials/speaker_tasks/Speaker_Identification_Verification.ipynb +++ b/tutorials/speaker_tasks/Speaker_Identification_Verification.ipynb @@ -438,7 +438,7 @@ "outputs": [], "source": [ "import torch\n", - "import pytorch_lightning as pl" + "import lightning.pytorch as pl" ] }, { diff --git a/tutorials/tools/DefinedCrowd_x_NeMo_ASR_Training_Tutorial.ipynb b/tutorials/tools/DefinedCrowd_x_NeMo_ASR_Training_Tutorial.ipynb index afd202f99d4a..8b0114690540 100644 --- a/tutorials/tools/DefinedCrowd_x_NeMo_ASR_Training_Tutorial.ipynb +++ b/tutorials/tools/DefinedCrowd_x_NeMo_ASR_Training_Tutorial.ipynb @@ -1636,7 +1636,7 @@ "outputId": "67209ee3-5161-40dc-a179-83d8219c3d71" }, "source": [ - "import pytorch_lightning as pl\n", + "import lightning.pytorch as pl\n", "from omegaconf import DictConfig\n", "import copy\n", "\n", diff --git a/tutorials/tts/Tacotron2_Training.ipynb b/tutorials/tts/Tacotron2_Training.ipynb index 79546bb79db9..edc814cf12ec 100644 --- a/tutorials/tts/Tacotron2_Training.ipynb +++ b/tutorials/tts/Tacotron2_Training.ipynb @@ -178,7 +178,7 @@ "Let's take a look at the tacotron2.py file\n", "\n", "```python\n", - "import pytorch_lightning as pl\n", + "import lightning.pytorch as pl\n", "\n", "from nemo.collections.common.callbacks import LogEpochTimeCallback\n", "from nemo.collections.tts.models import Tacotron2Model\n", diff --git a/tutorials/tts/Vits_Training.ipynb b/tutorials/tts/Vits_Training.ipynb index 9d3919e8dc6a..060c6bda43bb 100644 --- a/tutorials/tts/Vits_Training.ipynb +++ b/tutorials/tts/Vits_Training.ipynb @@ -191,7 +191,7 @@ "Let's take a look at the vits.py file\n", "\n", "```python\n", - "import pytorch_lightning as pl\n", + "import lightning.pytorch as pl\n", "\n", "from nemo.collections.tts.models.vits import VitsModel\n", "from nemo.core.config import hydra_runner\n",