diff --git a/.github/workflows/_test_template.yml b/.github/workflows/_test_template.yml index 3a1f69243c39b..1e184a8d41603 100644 --- a/.github/workflows/_test_template.yml +++ b/.github/workflows/_test_template.yml @@ -60,7 +60,16 @@ jobs: ARG=("--runtime=nvidia --gpus all") fi - docker run --rm -d --name nemo_container_${{ github.run_id }} ${ARG[@]} --shm-size=64g --env TRANSFORMERS_OFFLINE=0 --env HYDRA_FULL_ERROR=1 --volume /mnt/datadrive/TestData:/home/TestData nemoci.azurecr.io/nemo_container:${{ github.run_id }} bash -c "sleep $(( ${{ inputs.TIMEOUT }} * 60 + 60 ))" + docker run \ + --rm \ + -d \ + --name nemo_container_${{ github.run_id }} ${ARG[@]} \ + --shm-size=64g \ + --env TRANSFORMERS_OFFLINE=0 \ + --env HYDRA_FULL_ERROR=1 \ + --env HF_HOME=/home/TestData/HF_HOME \ + --volume /mnt/datadrive/TestData:/home/TestData nemoci.azurecr.io/nemo_container:${{ github.run_id }} \ + bash -c "sleep $(( ${{ inputs.TIMEOUT }} * 60 + 60 ))" - id: main name: Run main script @@ -95,4 +104,4 @@ jobs: if: always() run: | docker container stop nemo_container_${{ github.run_id }} || true - docker container rm nemo_container_${{ github.run_id }} || true \ No newline at end of file + docker container rm nemo_container_${{ github.run_id }} || true diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 03474251f995e..81db8e1160d91 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -20,10 +20,15 @@ on: description: Ref (SHA or branch name) to release required: true type: string + dry-run: + description: Do not publish a wheel and GitHub release. + required: true + default: true + type: boolean jobs: release: - uses: NVIDIA/NeMo-FW-CI-templates/.github/workflows/_release_library.yml@v0.12.3 + uses: NVIDIA/NeMo-FW-CI-templates/.github/workflows/_release_library.yml@v0.15.0 with: release-ref: ${{ inputs.release-ref }} image-name: nemo_container @@ -35,8 +40,10 @@ jobs: python-package: nemo container-workdir: /workspace library-name: Neural Modules + dry-run: ${{ inputs.dry-run }} secrets: TWINE_USERNAME: ${{ secrets.TWINE_USERNAME }} TWINE_PASSWORD: ${{ secrets.TWINE_PASSWORD }} SLACK_RELEASE_ENDPOINT: ${{ secrets.SLACK_RELEASE_ENDPOINT }} PAT: ${{ secrets.PAT }} + SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK }} diff --git a/examples/llm/peft/hf.py b/examples/llm/peft/hf.py index 5b24c22ab79d4..357dc5a7bd172 100644 --- a/examples/llm/peft/hf.py +++ b/examples/llm/peft/hf.py @@ -76,11 +76,11 @@ def formatting_prompts_func(examples): # See: https://github.com/Lightning-AI/pytorch-lightning/blob/8ad3e29816a63d8ce5c00ac104b14729a4176f4f/src/lightning/pytorch/plugins/precision/fsdp.py#L81 grad_clip = None use_dist_samp = False - tokenizer = llm.HfAutoModelForCausalLM.configure_tokenizer(args.model) + tokenizer = llm.HFAutoModelForCausalLM.configure_tokenizer(args.model) llm.api.finetune( - model=llm.HfAutoModelForCausalLM(args.model), - data=llm.HfDatasetDataModule( + model=llm.HFAutoModelForCausalLM(args.model), + data=llm.HFDatasetDataModule( mk_hf_dataset(tokenizer.tokenizer), pad_token_id=tokenizer.tokenizer.eos_token_id ), trainer=nl.Trainer( diff --git a/examples/llm/sft/hf.py b/examples/llm/sft/hf.py index 1d282312b1305..ce79e136a1c2e 100755 --- a/examples/llm/sft/hf.py +++ b/examples/llm/sft/hf.py @@ -84,7 +84,7 @@ def squad(tokenizer) -> pl.LightningDataModule: from nemo.lightning.pytorch.accelerate.transformer_engine import te_accelerate - model = llm.HfAutoModelForCausalLM(model_name=args.model, model_accelerator=model_accelerator) + model = 
llm.HFAutoModelForCausalLM(model_name=args.model, model_accelerator=model_accelerator) tokenizer = model.tokenizer llm.api.finetune( diff --git a/examples/speaker_tasks/diarization/conf/neural_diarizer/sortformer_diarizer_hybrid_loss_4spk-v1.yaml b/examples/speaker_tasks/diarization/conf/neural_diarizer/sortformer_diarizer_hybrid_loss_4spk-v1.yaml new file mode 100644 index 0000000000000..66cfc5fd1b61a --- /dev/null +++ b/examples/speaker_tasks/diarization/conf/neural_diarizer/sortformer_diarizer_hybrid_loss_4spk-v1.yaml @@ -0,0 +1,213 @@ +# Sortformer Diarizer is an end-to-end speaker diarization model that is solely based on Transformer-encoder type of architecture. +# Model name convention for Sortformer Diarizer: sortformer_diarizer__-.yaml +# (Example) `sortformer_diarizer_hybrid_loss_4spk-v1.yaml`. +# Sortformer Diarizer model checkpoint (.ckpt) and NeMo file (.nemo) contain Fast Conformer Encoder model (NEST Encoder) and the pre-trained NEST model is loaded along with the Transformer Encoder layers. +# Example: a manifest line for training +# {"audio_filepath": "/path/to/audio01.wav", "offset": 390.83, "duration": 90.00, "text": "-", "num_speakers": 2, "rttm_filepath": "/path/to/audio01.rttm"} +name: "SortFormerDiarizer" +num_workers: 18 +batch_size: 8 + +model: + sample_rate: 16000 + pil_weight: 0.5 # Weight for Permutation Invariant Loss (PIL) used in training the Sortformer diarizer model + ats_weight: 0.5 # Weight for Arrival Time Sort (ATS) loss in training the Sortformer diarizer model + max_num_of_spks: 4 # Maximum number of speakers per model; currently set to 4 + + model_defaults: + fc_d_model: 512 # Hidden dimension size of the Fast-conformer Encoder + tf_d_model: 192 # Hidden dimension size of the Transformer Encoder + + train_ds: + manifest_filepath: ??? + sample_rate: ${model.sample_rate} + num_spks: ${model.max_num_of_spks} + session_len_sec: 90 # Maximum session length in seconds + soft_label_thres: 0.5 # Threshold for binarizing target values; higher values make the model more conservative in predicting speaker activity. + soft_targets: False # If True, use continuous values as target values when calculating cross-entropy loss + labels: null + batch_size: ${batch_size} + shuffle: True + num_workers: ${num_workers} + validation_mode: False + # lhotse config + use_lhotse: False + use_bucketing: True + num_buckets: 10 + bucket_duration_bins: [10, 20, 30, 40, 50, 60, 70, 80, 90] + pin_memory: True + min_duration: 10 + max_duration: 90 + batch_duration: 400 + quadratic_duration: 1200 + bucket_buffer_size: 20000 + shuffle_buffer_size: 10000 + window_stride: ${model.preprocessor.window_stride} + subsampling_factor: ${model.encoder.subsampling_factor} + + validation_ds: + manifest_filepath: ??? + is_tarred: False + tarred_audio_filepaths: null + sample_rate: ${model.sample_rate} + num_spks: ${model.max_num_of_spks} + session_len_sec: 90 # Maximum session length in seconds + soft_label_thres: 0.5 # A threshold value for setting up the binarized labels. The higher the more conservative the model becomes. 
+ soft_targets: False + labels: null + batch_size: ${batch_size} + shuffle: False + num_workers: ${num_workers} + validation_mode: True + # lhotse config + use_lhotse: False + use_bucketing: False + drop_last: False + pin_memory: True + window_stride: ${model.preprocessor.window_stride} + subsampling_factor: ${model.encoder.subsampling_factor} + + test_ds: + manifest_filepath: null + is_tarred: False + tarred_audio_filepaths: null + sample_rate: 16000 + num_spks: ${model.max_num_of_spks} + session_len_sec: 90 # Maximum session length in seconds + soft_label_thres: 0.5 + soft_targets: False + labels: null + batch_size: ${batch_size} + shuffle: False + seq_eval_mode: True + num_workers: ${num_workers} + validation_mode: True + # lhotse config + use_lhotse: False + use_bucketing: False + drop_last: False + pin_memory: True + window_stride: ${model.preprocessor.window_stride} + subsampling_factor: ${model.encoder.subsampling_factor} + + preprocessor: + _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor + normalize: "per_feature" + window_size: 0.025 + sample_rate: ${model.sample_rate} + window_stride: 0.01 + window: "hann" + features: 80 + n_fft: 512 + frame_splicing: 1 + dither: 0.00001 + + sortformer_modules: + _target_: nemo.collections.asr.modules.sortformer_modules.SortformerModules + num_spks: ${model.max_num_of_spks} # Number of speakers per model. This is currently fixed at 4. + dropout_rate: 0.5 # Dropout rate + fc_d_model: ${model.model_defaults.fc_d_model} + tf_d_model: ${model.model_defaults.tf_d_model} # Hidden layer size for linear layers in Sortformer Diarizer module + + encoder: + _target_: nemo.collections.asr.modules.ConformerEncoder + feat_in: ${model.preprocessor.features} + feat_out: -1 + n_layers: 18 + d_model: ${model.model_defaults.fc_d_model} + + # Sub-sampling parameters + subsampling: dw_striding # vggnet, striding, stacking or stacking_norm, dw_striding + subsampling_factor: 8 # must be power of 2 for striding and vggnet + subsampling_conv_channels: 256 # set to -1 to make it equal to the d_model + causal_downsampling: false + # Feed forward module's params + ff_expansion_factor: 4 + # Multi-headed Attention Module's params + self_attention_model: rel_pos # rel_pos or abs_pos + n_heads: 8 # may need to be lower for smaller d_models + # [left, right] specifies the number of steps to be seen from left and right of each step in self-attention + att_context_size: [-1, -1] # -1 means unlimited context + att_context_style: regular # regular or chunked_limited + xscaling: true # scales up the input embeddings by sqrt(d_model) + untie_biases: true # unties the biases of the TransformerXL layers + pos_emb_max_len: 5000 + # Convolution module's params + conv_kernel_size: 9 + conv_norm_type: 'batch_norm' # batch_norm or layer_norm or groupnormN (N specifies the number of groups) + conv_context_size: null + # Regularization + dropout: 0.1 # The dropout used in most of the Conformer Modules + dropout_pre_encoder: 0.1 # The dropout used before the encoder + dropout_emb: 0.0 # The dropout used for embeddings + dropout_att: 0.1 # The dropout for multi-headed attention modules + # Set to non-zero to enable stochastic depth + stochastic_depth_drop_prob: 0.0 + stochastic_depth_mode: linear # linear or uniform + stochastic_depth_start_layer: 1 + + transformer_encoder: + _target_: nemo.collections.asr.modules.transformer.transformer_encoders.TransformerEncoder + num_layers: 18 + hidden_size: ${model.model_defaults.tf_d_model} # Needs to be multiple of 
num_attention_heads + inner_size: 768 + num_attention_heads: 8 + attn_score_dropout: 0.5 + attn_layer_dropout: 0.5 + ffn_dropout: 0.5 + hidden_act: relu + pre_ln: False + pre_ln_final_layer_norm: True + + loss: + _target_: nemo.collections.asr.losses.bce_loss.BCELoss + weight: null # Weight for binary cross-entropy loss. Either `null` or list type input. (e.g. [0.5,0.5]) + reduction: mean + + lr: 0.0001 + optim: + name: adamw + lr: ${model.lr} + # optimizer arguments + betas: [0.9, 0.98] + weight_decay: 1e-3 + + sched: + name: InverseSquareRootAnnealing + warmup_steps: 2500 + warmup_ratio: null + min_lr: 1e-06 + +trainer: + devices: 1 # number of gpus (devices) + accelerator: gpu + max_epochs: 800 + max_steps: -1 # computed at runtime if not set + num_nodes: 1 + strategy: ddp_find_unused_parameters_true # Could be "ddp" + accumulate_grad_batches: 1 + deterministic: True + enable_checkpointing: False + logger: False + log_every_n_steps: 1 # Interval of logging. + val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations + +exp_manager: + use_datetime_version: False + exp_dir: null + name: ${name} + resume_if_exists: True + resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. + resume_ignore_no_checkpoint: True + create_tensorboard_logger: True + create_checkpoint_callback: True + create_wandb_logger: False + checkpoint_callback_params: + monitor: "val_f1_acc" + mode: "max" + save_top_k: 9 + every_n_epochs: 1 + wandb_logger_kwargs: + resume: True + name: null + project: null \ No newline at end of file diff --git a/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_4spk-v1_callhome-part1.yaml b/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_4spk-v1_callhome-part1.yaml new file mode 100644 index 0000000000000..9b7a9701c4f25 --- /dev/null +++ b/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_4spk-v1_callhome-part1.yaml @@ -0,0 +1,13 @@ +# Postprocessing parameters for timestamp outputs from speaker diarization models. +# This speaker diarization postprocessing scheme is inspired by the postprocessing procedure in the following paper: +# Medennikov, Ivan, et al. "Target-Speaker Voice Activity Detection: a Novel Approach for Multi-Speaker Diarization in a Dinner Party Scenario." (2020). +# These parameters were optimized with hybrid-loss trained Sortformer model introduced in https://arxiv.org/pdf/2409.06656. +# These parameters were optimized on CallHome Dataset from the NIST SRE 2000 Disc8, especially from the part1 (callhome1) specified in: Kaldi, “Kaldi x-vector recipe v2,” https://github.com/kaldi-asr/kaldi/blob/master/egs/callhome_diarization/v2/run.sh +# Trial 24682 finished with value: 0.10257785779242055 and parameters: {'onset': 0.53, 'offset': 0.49, 'pad_onset': 0.23, 'pad_offset': 0.01, 'min_duration_on': 0.42, 'min_duration_off': 0.34}. Best is trial 24682 with value: 0.10257785779242055. 
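+#
+# Illustrative usage (paths below are placeholders, not files in this repo): these values are consumed by
+# examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py through the `postprocessing_yaml`
+# option, and post-processing is only applied when `bypass_postprocessing` is set to false, e.g.:
+#   python e2e_diarize_speech.py model_path=/path/to/diar_sortformer_4spk_v1.nemo \
+#       dataset_manifest=/path/to/callhome_part1_manifest.json \
+#       postprocessing_yaml=sortformer_diar_4spk-v1_callhome-part1.yaml \
+#       bypass_postprocessing=false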
+parameters: + onset: 0.53 # Onset threshold for detecting the beginning and end of a speech + offset: 0.49 # Offset threshold for detecting the end of a speech + pad_onset: 0.23 # Adding durations before each speech segment + pad_offset: 0.01 # Adding durations after each speech segment + min_duration_on: 0.42 # Threshold for small non-speech deletion + min_duration_off: 0.34 # Threshold for short speech segment deletion \ No newline at end of file diff --git a/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_4spk-v1_dihard3-dev.yaml b/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_4spk-v1_dihard3-dev.yaml new file mode 100644 index 0000000000000..ebf994c10f2ea --- /dev/null +++ b/examples/speaker_tasks/diarization/conf/post_processing/sortformer_diar_4spk-v1_dihard3-dev.yaml @@ -0,0 +1,13 @@ +# Postprocessing parameters for timestamp outputs from speaker diarization models. +# This speaker diarization postprocessing scheme is inspired by the postprocessing procedure in the following paper: +# Medennikov, Ivan, et al. "Target-Speaker Voice Activity Detection: a Novel Approach for Multi-Speaker Diarization in a Dinner Party Scenario." (2020). +# These parameters were optimized with hybrid-loss trained Sortformer model introduced in https://arxiv.org/pdf/2409.06656. +# These parameters were optimized on the development split of DIHARD3 dataset (See https://arxiv.org/pdf/2012.01477). +# Trial 732 finished with value: 0.12171946949255649 and parameters: {'onset': 0.64, 'offset': 0.74, 'pad_onset': 0.06, 'pad_offset': 0.0, 'min_duration_on': 0.1, 'min_duration_off': 0.15}. Best is trial 732 with value: 0.12171946949255649. +parameters: + onset: 0.64 # Onset threshold for detecting the beginning and end of a speech + offset: 0.74 # Offset threshold for detecting the end of a speech + pad_onset: 0.06 # Adding durations before each speech segment + pad_offset: 0.0 # Adding durations after each speech segment + min_duration_on: 0.1 # Threshold for small non-speech deletion + min_duration_off: 0.15 # Threshold for short speech segment deletion \ No newline at end of file diff --git a/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py b/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py new file mode 100644 index 0000000000000..1767a16cbe02b --- /dev/null +++ b/examples/speaker_tasks/diarization/neural_diarizer/e2e_diarize_speech.py @@ -0,0 +1,443 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This script provides an inference and evaluation script for end-to-end speaker diarization models. +The performance of the diarization model is measured using the Diarization Error Rate (DER). +If you want to evaluate its performance, the manifest JSON file should contain the corresponding RTTM +(Rich Transcription Time Marked) file. 
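+An example manifest entry for evaluation may look like the following (paths and durations are illustrative):
+{"audio_filepath": "/path/to/audio01.wav", "offset": 0.0, "duration": 30.0, "text": "-", "num_speakers": 2, "rttm_filepath": "/path/to/audio01.rttm"}
+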
+Please refer to the NeMo Library Documentation for more details on data preparation for diarization inference: +https://docs.nvidia.com/nemo-framework/user-guide/latest/nemotoolkit +/asr/speaker_diarization/datasets.html#data-preparation-for-inference + +Usage for diarization inference: + +The end-to-end speaker diarization model can be specified by either "model_path" or "pretrained_name". +Data for diarization is fed through the "dataset_manifest". +By default, post-processing is bypassed, and only binarization is performed. +If you want to reproduce DER scores reported on NeMo model cards, you need to apply post-processing steps. +Use batch_size = 1 to have the longest inference window and the highest possible accuracy. + +python $BASEPATH/neural_diarizer/e2e_diarize_speech.py \ + model_path=/path/to/diar_sortformer_4spk_v1.nemo \ + batch_size=1 \ + dataset_manifest=/path/to/diarization_manifest.json + +""" +import logging +import os +import tempfile +from dataclasses import dataclass, is_dataclass +from typing import Dict, List, Optional, Union + +import lightning.pytorch as pl +import optuna +import torch +import yaml +from omegaconf import OmegaConf +from pytorch_lightning import seed_everything +from tqdm import tqdm + +from nemo.collections.asr.metrics.der import score_labels +from nemo.collections.asr.models import SortformerEncLabelModel +from nemo.collections.asr.parts.utils.speaker_utils import audio_rttm_map, timestamps_to_pyannote_object +from nemo.collections.asr.parts.utils.vad_utils import ts_vad_post_processing +from nemo.core.config import hydra_runner + +seed_everything(42) +torch.backends.cudnn.deterministic = True + + +@dataclass +class PostProcessingParams: + """ + Postprocessing parameters for end-to-end speaker diarization models. + These parameters can significantly affect DER performance depending on the evaluation style and the dataset. + It is recommended to tune these parameters based on the evaluation style and the dataset + to achieve the desired DER performance. + """ + + onset: float = 0.5 # Onset threshold for detecting the beginning and end of a speech + offset: float = 0.5 # Offset threshold for detecting the end of a speech + pad_onset: float = 0.0 # Adding durations before each speech segment + pad_offset: float = 0.0 # Adding durations after each speech segment + min_duration_on: float = 0.0 # Threshold for small non-speech deletion + min_duration_off: float = 0.0 # Threshold for short speech segment deletion + + +@dataclass +class DiarizationConfig: + """Diarization configuration parameters for inference.""" + + model_path: Optional[str] = None # Path to a .nemo file + pretrained_name: Optional[str] = None # Name of a pretrained model + audio_dir: Optional[str] = None # Path to a directory which contains audio files + dataset_manifest: Optional[str] = None # Path to dataset's JSON manifest + + postprocessing_yaml: Optional[str] = None # Path to a yaml file for postprocessing configurations + no_der: bool = False + out_rttm_dir: Optional[str] = None + + # General configs + session_len_sec: float = -1 # End-to-end diarization session length in seconds + batch_size: int = 1 + num_workers: int = 0 + random_seed: Optional[int] = None # seed number going to be used in seed_everything() + bypass_postprocessing: bool = True # If True, postprocessing will be bypassed + + # Eval Settings: (0.25, False) should be default setting for sortformer eval. 
+ collar: float = 0.25 # Collar in seconds for DER calculation + ignore_overlap: bool = False # If True, DER will be calculated only for non-overlapping segments + + # If `cuda` is a negative number, inference will be on CPU only. + cuda: Optional[int] = None + matmul_precision: str = "highest" # Literal["highest", "high", "medium"] + + # Optuna Config + launch_pp_optim: bool = False # If True, launch optimization process for postprocessing parameters + optuna_study_name: str = "optim_postprocessing" + optuna_temp_dir: str = "/tmp/optuna" + optuna_storage: str = f"sqlite:///{optuna_study_name}.db" + optuna_log_file: str = f"{optuna_study_name}.log" + optuna_n_trials: int = 100000 + + +def load_postprocessing_from_yaml(postprocessing_yaml: PostProcessingParams = None) -> PostProcessingParams: + """ + Load postprocessing parameters from a YAML file. + + Args: + postprocessing_yaml (str): + Path to a YAML file for postprocessing configurations. + + Returns: + postprocessing_params (dataclass): + Postprocessing parameters loaded from the YAML file. + """ + # Add PostProcessingParams as a field + postprocessing_params = OmegaConf.structured(PostProcessingParams()) + if postprocessing_yaml is None: + logging.info( + f"No postprocessing YAML file has been provided. Default postprocessing configurations will be applied." + ) + else: + # Load postprocessing params from the provided YAML file + with open(postprocessing_yaml, 'r') as file: + yaml_params = yaml.safe_load(file)['parameters'] + # Update the postprocessing_params with the loaded values + logging.info(f"Postprocessing YAML file '{postprocessing_yaml}' has been loaded.") + for key, value in yaml_params.items(): + if hasattr(postprocessing_params, key): + setattr(postprocessing_params, key, value) + return postprocessing_params + + +def optuna_suggest_params(postprocessing_cfg: PostProcessingParams, trial: optuna.Trial) -> PostProcessingParams: + """ + Suggests hyperparameters for postprocessing using Optuna. + See the following link for `trial` instance in Optuna framework. + https://optuna.readthedocs.io/en/stable/reference/generated/optuna.trial.Trial.html#optuna.trial.Trial + + Args: + postprocessing_cfg (PostProcessingParams): The current postprocessing configuration. + trial (optuna.Trial): The Optuna trial object used to suggest hyperparameters. + + Returns: + PostProcessingParams: The updated postprocessing configuration with suggested hyperparameters. + """ + postprocessing_cfg.onset = trial.suggest_float("onset", 0.4, 0.8, step=0.01) + postprocessing_cfg.offset = trial.suggest_float("offset", 0.4, 0.9, step=0.01) + postprocessing_cfg.pad_onset = trial.suggest_float("pad_onset", 0.1, 0.5, step=0.01) + postprocessing_cfg.pad_offset = trial.suggest_float("pad_offset", 0.0, 0.2, step=0.01) + postprocessing_cfg.min_duration_on = trial.suggest_float("min_duration_on", 0.0, 0.75, step=0.01) + postprocessing_cfg.min_duration_off = trial.suggest_float("min_duration_off", 0.0, 0.75, step=0.01) + return postprocessing_cfg + + +def get_tensor_path(cfg: DiarizationConfig) -> str: + """ + Constructs the file path for saving or loading prediction tensors based on the configuration. + + Args: + cfg (DiarizationConfig): The configuration object containing model and dataset details. + + Returns: + str: The constructed file path for the prediction tensor. 
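+
+    Example (illustrative paths):
+        With model_path='/models/diar_sortformer_4spk_v1.nemo' and dataset_manifest='/data/callhome_part1.json',
+        the returned path would be '/models/pred_tensors/__diar_sortformer_4spk_v1__callhome_part1.pt'.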
+ """ + tensor_filename = os.path.basename(cfg.dataset_manifest).replace("manifest.", "").replace(".json", "") + model_base_path = os.path.dirname(cfg.model_path) + model_id = os.path.basename(cfg.model_path).replace(".ckpt", "").replace(".nemo", "") + bpath = f"{model_base_path}/pred_tensors" + if not os.path.exists(bpath): + os.makedirs(bpath) + tensor_path = f"{bpath}/__{model_id}__{tensor_filename}.pt" + return tensor_path + + +def diarization_objective( + trial: optuna.Trial, + postprocessing_cfg: PostProcessingParams, + temp_out_dir: str, + infer_audio_rttm_dict: Dict[str, Dict[str, str]], + diar_model_preds_total_list: List[torch.Tensor], + collar: float = 0.25, + ignore_overlap: bool = False, +) -> float: + """ + Objective function for Optuna hyperparameter optimization in speaker diarization. + + This function evaluates the diarization performance using a set of postprocessing parameters + suggested by Optuna. It converts prediction matrices to time-stamp segments, scores the + diarization results, and returns the Diarization Error Rate (DER) as the optimization metric. + + Args: + trial (optuna.Trial): The Optuna trial object used to suggest hyperparameters. + postprocessing_cfg (PostProcessingParams): The current postprocessing configuration. + temp_out_dir (str): Temporary directory for storing intermediate outputs. + infer_audio_rttm_dict (Dict[str, Dict[str, str]]): Dictionary containing audio file paths, + offsets, durations, and RTTM file paths. + diar_model_preds_total_list (List[torch.Tensor]): List of prediction matrices containing + sigmoid values for each speaker. + Dimension: [(1, num_frames, num_speakers), ..., (1, num_frames, num_speakers)] + collar (float, optional): Collar in seconds for DER calculation. Defaults to 0.25. + ignore_overlap (bool, optional): If True, DER will be calculated only for non-overlapping segments. + Defaults to False. + + Returns: + float: The Diarization Error Rate (DER) for the given set of postprocessing parameters. + """ + with tempfile.TemporaryDirectory(dir=temp_out_dir, prefix="Diar_PostProcessing_") as local_temp_out_dir: + if trial is not None: + postprocessing_cfg = optuna_suggest_params(postprocessing_cfg, trial) + all_hyps, all_refs, all_uems = convert_pred_mat_to_segments( + audio_rttm_map_dict=infer_audio_rttm_dict, + postprocessing_cfg=postprocessing_cfg, + batch_preds_list=diar_model_preds_total_list, + unit_10ms_frame_count=8, + bypass_postprocessing=False, + ) + metric, mapping_dict, itemized_errors = score_labels( + AUDIO_RTTM_MAP=infer_audio_rttm_dict, + all_reference=all_refs, + all_hypothesis=all_hyps, + all_uem=all_uems, + collar=collar, + ignore_overlap=ignore_overlap, + ) + der = abs(metric) + return der + + +def run_optuna_hyperparam_search( + cfg: DiarizationConfig, # type: DiarizationConfig + postprocessing_cfg: PostProcessingParams, + infer_audio_rttm_dict: Dict[str, Dict[str, str]], + preds_list: List[torch.Tensor], + temp_out_dir: str, +): + """ + Run Optuna hyperparameter optimization for speaker diarization. + + Args: + cfg (DiarizationConfig): The configuration object containing model and dataset details. + postprocessing_cfg (PostProcessingParams): The current postprocessing configuration. + infer_audio_rttm_dict (dict): dictionary of audio file path, offset, duration and RTTM filepath. + preds_list (List[torch.Tensor]): list of prediction matrices containing sigmoid values for each speaker. 
+ Dimension: [(1, num_frames, num_speakers), ..., (1, num_frames, num_speakers)] + temp_out_dir (str): temporary directory for storing intermediate outputs. + """ + worker_function = lambda trial: diarization_objective( + trial=trial, + postprocessing_cfg=postprocessing_cfg, + temp_out_dir=temp_out_dir, + infer_audio_rttm_dict=infer_audio_rttm_dict, + diar_model_preds_total_list=preds_list, + collar=cfg.collar, + ) + study = optuna.create_study( + direction="minimize", study_name=cfg.optuna_study_name, storage=cfg.optuna_storage, load_if_exists=True + ) + logger = logging.getLogger() + logger.setLevel(logging.INFO) # Setup the root logger. + if cfg.optuna_log_file is not None: + logger.addHandler(logging.FileHandler(cfg.optuna_log_file, mode="a")) + logger.addHandler(logging.StreamHandler()) + optuna.logging.enable_propagation() # Propagate logs to the root logger. + study.optimize(worker_function, n_trials=cfg.optuna_n_trials) + + +def convert_pred_mat_to_segments( + audio_rttm_map_dict: Dict[str, Dict[str, str]], + postprocessing_cfg, + batch_preds_list: List[torch.Tensor], + unit_10ms_frame_count: int = 8, + bypass_postprocessing: bool = False, + out_rttm_dir: str | None = None, +): + """ + Convert prediction matrix to time-stamp segments. + + Args: + audio_rttm_map_dict (dict): dictionary of audio file path, offset, duration and RTTM filepath. + batch_preds_list (List[torch.Tensor]): list of prediction matrices containing sigmoid values for each speaker. + Dimension: [(1, num_frames, num_speakers), ..., (1, num_frames, num_speakers)] + unit_10ms_frame_count (int, optional): number of 10ms segments in a frame. Defaults to 8. + bypass_postprocessing (bool, optional): if True, postprocessing will be bypassed. Defaults to False. + + Returns: + all_hypothesis (list): list of pyannote objects for each audio file. + all_reference (list): list of pyannote objects for each audio file. + all_uems (list): list of pyannote objects for each audio file. 
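+
+    Note:
+        Assuming the default Sortformer front-end settings in this config set (window_stride=0.01 and
+        encoder subsampling_factor=8), one diarization frame spans 8 x 10 ms = 80 ms, which matches the
+        default unit_10ms_frame_count of 8.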
+ """ + batch_pred_ts_segs, all_hypothesis, all_reference, all_uems = [], [], [], [] + cfg_vad_params = OmegaConf.structured(postprocessing_cfg) + pp_message = "Bypass PP, Running Binarization" if bypass_postprocessing else "Running post-processing" + for sample_idx, (uniq_id, audio_rttm_values) in tqdm( + enumerate(audio_rttm_map_dict.items()), total=len(audio_rttm_map_dict), desc=pp_message + ): + spk_ts = [] + offset, duration = audio_rttm_values['offset'], audio_rttm_values['duration'] + speaker_assign_mat = batch_preds_list[sample_idx].squeeze(dim=0) + speaker_timestamps = [[] for _ in range(speaker_assign_mat.shape[-1])] + for spk_id in range(speaker_assign_mat.shape[-1]): + ts_mat = ts_vad_post_processing( + speaker_assign_mat[:, spk_id], + cfg_vad_params=cfg_vad_params, + unit_10ms_frame_count=unit_10ms_frame_count, + bypass_postprocessing=bypass_postprocessing, + ) + ts_mat = ts_mat + offset + ts_mat = torch.clamp(ts_mat, min=offset, max=(offset + duration)) + ts_seg_list = ts_mat.tolist() + speaker_timestamps[spk_id].extend(ts_seg_list) + spk_ts.append(ts_seg_list) + all_hypothesis, all_reference, all_uems = timestamps_to_pyannote_object( + speaker_timestamps, + uniq_id, + audio_rttm_values, + all_hypothesis, + all_reference, + all_uems, + out_rttm_dir, + ) + batch_pred_ts_segs.append(spk_ts) + return all_hypothesis, all_reference, all_uems + + +@hydra_runner(config_name="DiarizationConfig", schema=DiarizationConfig) +def main(cfg: DiarizationConfig) -> Union[DiarizationConfig]: + """Main function for end-to-end speaker diarization inference.""" + for key in cfg: + cfg[key] = None if cfg[key] == 'None' else cfg[key] + + if is_dataclass(cfg): + cfg = OmegaConf.structured(cfg) + + if cfg.random_seed: + pl.seed_everything(cfg.random_seed) + + if cfg.model_path is None and cfg.pretrained_name is None: + raise ValueError("Both cfg.model_path and cfg.pretrained_name cannot be None!") + if cfg.audio_dir is None and cfg.dataset_manifest is None: + raise ValueError("Both cfg.audio_dir and cfg.dataset_manifest cannot be None!") + + # setup GPU + torch.set_float32_matmul_precision(cfg.matmul_precision) + if cfg.cuda is None: + if torch.cuda.is_available(): + device = [0] # use 0th CUDA device + accelerator = 'gpu' + map_location = torch.device('cuda:0') + else: + device = 1 + accelerator = 'cpu' + map_location = torch.device('cpu') + else: + device = [cfg.cuda] + accelerator = 'gpu' + map_location = torch.device(f'cuda:{cfg.cuda}') + + if cfg.model_path.endswith(".ckpt"): + diar_model = SortformerEncLabelModel.load_from_checkpoint( + checkpoint_path=cfg.model_path, map_location=map_location, strict=False + ) + elif cfg.model_path.endswith(".nemo"): + diar_model = SortformerEncLabelModel.restore_from(restore_path=cfg.model_path, map_location=map_location) + else: + raise ValueError("cfg.model_path must end with.ckpt or.nemo!") + + diar_model._cfg.test_ds.session_len_sec = cfg.session_len_sec + trainer = pl.Trainer(devices=device, accelerator=accelerator) + diar_model.set_trainer(trainer) + + diar_model = diar_model.eval() + diar_model._cfg.test_ds.manifest_filepath = cfg.dataset_manifest + infer_audio_rttm_dict = audio_rttm_map(cfg.dataset_manifest) + diar_model._cfg.test_ds.batch_size = cfg.batch_size + + # Model setup for inference + diar_model._cfg.test_ds.num_workers = cfg.num_workers + diar_model.setup_test_data(test_data_config=diar_model._cfg.test_ds) + + postprocessing_cfg = load_postprocessing_from_yaml(cfg.postprocessing_yaml) + tensor_path = get_tensor_path(cfg) + + if 
os.path.exists(tensor_path): + logging.info( + f"A saved prediction tensor has been found. Loading the saved prediction tensors from {tensor_path}..." + ) + diar_model_preds_total_list = torch.load(tensor_path) + else: + logging.info(f"No saved prediction tensors found. Running inference on the dataset...") + diar_model.test_batch() + diar_model_preds_total_list = diar_model.preds_total_list + torch.save(diar_model.preds_total_list, tensor_path) + + if cfg.launch_pp_optim: + # Launch a hyperparameter optimization process if launch_pp_optim is True + run_optuna_hyperparam_search( + cfg=cfg, + postprocessing_cfg=postprocessing_cfg, + infer_audio_rttm_dict=infer_audio_rttm_dict, + preds_list=diar_model_preds_total_list, + temp_out_dir=cfg.optuna_temp_dir, + ) + + # Evaluation + if not cfg.no_der: + if cfg.out_rttm_dir is not None and not os.path.exists(cfg.out_rttm_dir): + os.mkdir(cfg.out_rttm_dir) + all_hyps, all_refs, all_uems = convert_pred_mat_to_segments( + infer_audio_rttm_dict, + postprocessing_cfg=postprocessing_cfg, + batch_preds_list=diar_model_preds_total_list, + unit_10ms_frame_count=8, + bypass_postprocessing=cfg.bypass_postprocessing, + out_rttm_dir=cfg.out_rttm_dir, + ) + logging.info(f"Evaluating the model on the {len(diar_model_preds_total_list)} audio segments...") + score_labels( + AUDIO_RTTM_MAP=infer_audio_rttm_dict, + all_reference=all_refs, + all_hypothesis=all_hyps, + all_uem=all_uems, + collar=cfg.collar, + ignore_overlap=cfg.ignore_overlap, + ) + logging.info(f"PostProcessingParams: {postprocessing_cfg}") + + +if __name__ == '__main__': + main() diff --git a/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py b/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py new file mode 100644 index 0000000000000..ab6e418b10729 --- /dev/null +++ b/examples/speaker_tasks/diarization/neural_diarizer/sortformer_diar_train.py @@ -0,0 +1,55 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import lightning.pytorch as pl +from omegaconf import OmegaConf +from pytorch_lightning import seed_everything + +from nemo.collections.asr.models import SortformerEncLabelModel +from nemo.core.config import hydra_runner +from nemo.utils import logging +from nemo.utils.exp_manager import exp_manager + +""" +Example training session (single node training) + +python ./sortformer_diar_train.py --config-path='../conf/neural_diarizer' \ + --config-name='sortformer_diarizer_hybrid_loss_4spk-v1.yaml' \ + trainer.devices=1 \ + model.train_ds.manifest_filepath="" \ + model.validation_ds.manifest_filepath="" \ + exp_manager.name='sample_train' \ + exp_manager.exp_dir='./sortformer_diar_train' +""" + +seed_everything(42) + + +@hydra_runner(config_path="../conf/neural_diarizer", config_name="sortformer_diarizer_hybrid_loss_4spk-v1.yaml") +def main(cfg): + """Main function for training the sortformer diarizer model.""" + logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}') + trainer = pl.Trainer(**cfg.trainer) + exp_manager(trainer, cfg.get("exp_manager", None)) + sortformer_model = SortformerEncLabelModel(cfg=cfg.model, trainer=trainer) + sortformer_model.maybe_init_from_pretrained_checkpoint(cfg) + trainer.fit(sortformer_model) + + if hasattr(cfg.model, 'test_ds') and cfg.model.test_ds.manifest_filepath is not None: + if sortformer_model.prepare_test(trainer): + trainer.test(sortformer_model) + + +if __name__ == '__main__': + main() diff --git a/nemo/collections/asr/data/audio_to_diar_label.py b/nemo/collections/asr/data/audio_to_diar_label.py index a1cb6d0f1bdcb..0824c9c6ab513 100644 --- a/nemo/collections/asr/data/audio_to_diar_label.py +++ b/nemo/collections/asr/data/audio_to_diar_label.py @@ -15,15 +15,20 @@ import os from collections import OrderedDict from statistics import mode -from typing import Dict, Optional +from typing import Dict, List, Optional, Tuple +import numpy as np import torch from nemo.collections.asr.parts.utils.offline_clustering import get_argmin_mat -from nemo.collections.asr.parts.utils.speaker_utils import convert_rttm_line, prepare_split_data -from nemo.collections.common.parts.preprocessing.collections import DiarizationSpeechLabel +from nemo.collections.asr.parts.utils.speaker_utils import convert_rttm_line, get_subsegments, prepare_split_data +from nemo.collections.common.parts.preprocessing.collections import ( + DiarizationSpeechLabel, + EndtoEndDiarizationSpeechLabel, +) from nemo.core.classes import Dataset from nemo.core.neural_types import AudioSignal, EncodedRepresentation, LengthsType, NeuralType, ProbsType +from nemo.utils import logging def get_scale_mapping_list(uniq_timestamps): @@ -62,7 +67,7 @@ def get_scale_mapping_list(uniq_timestamps): return scale_mapping_argmat -def extract_seg_info_from_rttm(uniq_id, rttm_lines, mapping_dict=None, target_spks=None): +def extract_seg_info_from_rttm(rttm_lines, mapping_dict=None, target_spks=None): """ Get RTTM lines containing speaker labels, start time and end time. target_spks contains two targeted speaker indices for creating groundtruth label files. Only speakers in target_spks variable will be @@ -76,7 +81,8 @@ def extract_seg_info_from_rttm(uniq_id, rttm_lines, mapping_dict=None, target_sp mapping_dict (dict): Mapping between the estimated speakers and the speakers in the ground-truth annotation. `mapping_dict` variable is only provided when the inference mode is running in sequence-eval mode. 
-            Sequence eval mode uses the mapping between the estimated speakers and the speakers in ground-truth annotation.
+            Sequence eval mode uses the mapping between the estimated speakers and the speakers
+            in ground-truth annotation.
 
     Returns:
         rttm_tup (tuple):
             Tuple containing lists of start time, end time and speaker labels.
@@ -108,12 +114,14 @@ def assign_frame_level_spk_vector(rttm_timestamps, round_digits, frame_per_sec,
     Args:
         rttm_timestamps (list):
             List containing start and end time for each speaker segment label.
-            stt_list, end_list and speaker_list are contained.
+            `stt_list`, `end_list` and `speaker_list` are contained.
         frame_per_sec (int):
-            Number of feature frames per second. This quantity is determined by window_stride variable in preprocessing module.
+            Number of feature frames per second. This quantity is determined by
+            `window_stride` variable in preprocessing module.
         target_spks (tuple):
-            Speaker indices that are generated from combinations. If there are only one or two speakers,
-            only a single target_spks variable is generated.
+            Speaker indices that are generated from combinations.
+            If there are only one or two speakers,
+            only a single `target_spks` variable is generated.
 
     Returns:
         fr_level_target (torch.tensor):
@@ -124,7 +132,7 @@ def assign_frame_level_spk_vector(rttm_timestamps, round_digits, frame_per_sec,
         return None
     else:
         sorted_speakers = sorted(list(set(speaker_list)))
-        total_fr_len = int(max(end_list) * (10 ** round_digits))
+        total_fr_len = int(max(end_list) * (10**round_digits))
         spk_num = max(len(sorted_speakers), min_spks)
         speaker_mapping_dict = {rttm_key: x_int for x_int, rttm_key in enumerate(sorted_speakers)}
         fr_level_target = torch.zeros(total_fr_len, spk_num)
@@ -139,6 +147,140 @@ def assign_frame_level_spk_vector(rttm_timestamps, round_digits, frame_per_sec,
     return fr_level_target
 
 
+def get_subsegments_to_timestamps(
+    subsegments: List[Tuple[float, float]], feat_per_sec: int = 100, max_end_ts: float = None, decimals=2
+):
+    """
+    Convert subsegment timestamps to scale timestamps by multiplying with the feature rate (`feat_per_sec`)
+    and rounding. A segment consists of many subsegments, and subsegments are equivalent to `frames`
+    in end-to-end speaker diarization models.
+
+    Args:
+        subsegments (List[Tuple[float, float]]):
+            A list of tuples where each tuple contains the start and end times of a subsegment
+            (frames in end-to-end models).
+            >>> subsegments = [[t0_start, t0_duration], [t1_start, t1_duration],..., [tN_start, tN_duration]]
+        feat_per_sec (int, optional):
+            The number of feature frames per second. Defaults to 100.
+        max_end_ts (float, optional):
+            The maximum end timestamp to clip the results. If None, no clipping is applied. Defaults to None.
+        decimals (int, optional):
+            The number of decimal places to round the timestamps. Defaults to 2.
+
+    Example:
+        Segments starting from 0.0 and ending at 69.2 seconds.
+        If hop-length is 0.08 and the subsegment (frame) length is 0.16 seconds,
+        there are 864 = (69.2 - 0.16)/0.08 + 1 subsegments (frames in end-to-end models) in this segment.
+            >>> subsegments = [[0.0, 0.16], [0.08, 0.16], ..., [69.04, 0.16], [69.12, 0.08]]
+
+    Returns:
+        ts (torch.tensor):
+            A tensor containing the scaled and rounded timestamps for each subsegment.
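+
+            For instance (an illustrative case with feat_per_sec=100), subsegments=[[0.0, 0.16], [0.08, 0.16]]
+            is converted to start/end frame indices [[0, 16], [8, 24]].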
+ """ + seg_ts = (torch.tensor(subsegments) * feat_per_sec).float() + ts_round = torch.round(seg_ts, decimals=decimals) + ts = ts_round.long() + ts[:, 1] = ts[:, 0] + ts[:, 1] + if max_end_ts is not None: + ts = np.clip(ts, 0, int(max_end_ts * feat_per_sec)) + return ts + + +def extract_frame_info_from_rttm(offset, duration, rttm_lines, round_digits=3): + """ + Extracts RTTM lines containing speaker labels, start time, and end time for a given audio segment. + + Args: + uniq_id (str): Unique identifier for the audio file and corresponding RTTM file. + offset (float): The starting time offset for the segment of interest. + duration (float): The duration of the segment of interest. + rttm_lines (list): List of RTTM lines in string format. + round_digits (int, optional): Number of decimal places to round the start and end times. Defaults to 3. + + Returns: + rttm_mat (tuple): A tuple containing lists of start times, end times, and speaker labels. + sess_to_global_spkids (dict): A mapping from session-specific speaker indices to global speaker identifiers. + """ + rttm_stt, rttm_end = offset, offset + duration + stt_list, end_list, speaker_list, speaker_set = [], [], [], [] + sess_to_global_spkids = dict() + + for rttm_line in rttm_lines: + start, end, speaker = convert_rttm_line(rttm_line) + + # Skip invalid RTTM lines where the start time is greater than the end time. + if start > end: + continue + + # Check if the RTTM segment overlaps with the specified segment of interest. + if (end > rttm_stt and start < rttm_end) or (start < rttm_end and end > rttm_stt): + # Adjust the start and end times to fit within the segment of interest. + start, end = max(start, rttm_stt), min(end, rttm_end) + else: + continue + + # Round the start and end times to the specified number of decimal places. + end_list.append(round(end, round_digits)) + stt_list.append(round(start, round_digits)) + + # Assign a unique index to each speaker and maintain a mapping. + if speaker not in speaker_set: + speaker_set.append(speaker) + speaker_list.append(speaker_set.index(speaker)) + sess_to_global_spkids.update({speaker_set.index(speaker): speaker}) + + rttm_mat = (stt_list, end_list, speaker_list) + return rttm_mat, sess_to_global_spkids + + +def get_frame_targets_from_rttm( + rttm_timestamps: list, + offset: float, + duration: float, + round_digits: int, + feat_per_sec: int, + max_spks: int, +): + """ + Create a multi-dimensional vector sequence containing speaker timestamp information in RTTM. + The unit-length is the frame shift length of the acoustic feature. The feature-level annotations + `feat_level_target` will later be converted to base-segment level diarization label. + + Args: + rttm_timestamps (list): + List containing start and end time for each speaker segment label. + stt_list, end_list and speaker_list are contained. + feat_per_sec (int): + Number of feature frames per second. + This quantity is determined by window_stride variable in preprocessing module. + target_spks (tuple): + Speaker indices that are generated from combinations. If there are only one or two speakers, + only a single target_spks variable is generated. + + Returns: + feat_level_target (torch.tensor): + Tensor containing label for each feature level frame. 
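+
+            For example (illustrative values), with duration=10.0, feat_per_sec=100 and max_spks=4,
+            the returned tensor has shape (1000, 4), where ones mark the frames in which each speaker is active.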
+ """ + stt_list, end_list, speaker_list = rttm_timestamps + sorted_speakers = sorted(list(set(speaker_list))) + total_fr_len = int(duration * feat_per_sec) + if len(sorted_speakers) > max_spks: + logging.warning( + f"Number of speakers in RTTM file {len(sorted_speakers)} exceeds the maximum number of speakers: " + f"{max_spks}! Only {max_spks} first speakers remain, and this will affect frame metrics!" + ) + feat_level_target = torch.zeros(total_fr_len, max_spks) + for count, (stt, end, spk_rttm_key) in enumerate(zip(stt_list, end_list, speaker_list)): + if end < offset or stt > offset + duration: + continue + stt, end = max(offset, stt), min(offset + duration, end) + spk = spk_rttm_key + if spk < max_spks: + stt_fr, end_fr = int((stt - offset) * feat_per_sec), int((end - offset) * feat_per_sec) + feat_level_target[stt_fr:end_fr, spk] = 1 + return feat_level_target + + class _AudioMSDDTrainDataset(Dataset): """ Dataset class that loads a json file containing paths to audio files, @@ -214,7 +356,7 @@ def __init__( self.multiscale_args_dict = multiscale_args_dict self.emb_dir = emb_dir self.round_digits = 2 - self.decim = 10 ** self.round_digits + self.decim = 10**self.round_digits self.soft_label_thres = soft_label_thres self.pairwise_infer = pairwise_infer self.max_spks = 2 @@ -224,7 +366,10 @@ def __init__( self.global_rank = global_rank self.manifest_filepath = manifest_filepath self.multiscale_timestamp_dict = prepare_split_data( - self.manifest_filepath, self.emb_dir, self.multiscale_args_dict, self.global_rank, + self.manifest_filepath, + self.emb_dir, + self.multiscale_args_dict, + self.global_rank, ) def __len__(self): @@ -241,7 +386,7 @@ def assign_labels_to_longer_segs(self, uniq_id, base_scale_clus_label): Unique sample ID for training. base_scale_clus_label (torch.tensor): Tensor variable containing the speaker labels for the base-scale segments. - + Returns: per_scale_clus_label (torch.tensor): Tensor variable containing the speaker labels for each segment in each scale. @@ -270,15 +415,17 @@ def assign_labels_to_longer_segs(self, uniq_id, base_scale_clus_label): def get_diar_target_labels(self, uniq_id, sample, fr_level_target): """ - Convert frame-level diarization target variable into segment-level target variable. Since the granularity is reduced - from frame level (10ms) to segment level (100ms~500ms), we need a threshold value, `soft_label_thres`, which determines - the label of each segment based on the overlap between a segment range (start and end time) and the frame-level target variable. + Convert frame-level diarization target variable into segment-level target variable. + Since the granularity is reduced from frame level (10ms) to segment level (100ms~500ms), + we need a threshold value, `soft_label_thres`, which determines the label of each segment + based on the overlap between a segment range (start and end time) and the frame-level target variable. Args: uniq_id (str): Unique file ID that refers to an input audio file and corresponding RTTM (Annotation) file. sample: - `DiarizationSpeechLabel` instance containing sample information such as audio filepath and RTTM filepath. + `DiarizationSpeechLabel` instance containing sample information such as + audio filepath and RTTM filepath. fr_level_target (torch.tensor): Tensor containing label for each feature-level frame. @@ -286,13 +433,14 @@ def get_diar_target_labels(self, uniq_id, sample, fr_level_target): seg_target (torch.tensor): Tensor containing binary speaker labels for base-scale segments. 
base_clus_label (torch.tensor): - Representative speaker label for each segment. This variable only has one speaker label for each base-scale segment. + Representative speaker label for each segment. This variable only has one speaker label + for each base-scale segment. -1 means that there is no corresponding speaker in the target_spks tuple. """ seg_target_list, base_clus_label = [], [] self.scale_n = len(self.multiscale_timestamp_dict[uniq_id]['scale_dict']) subseg_time_stamp_list = self.multiscale_timestamp_dict[uniq_id]["scale_dict"][self.scale_n - 1]["time_stamps"] - for (seg_stt, seg_end) in subseg_time_stamp_list: + for seg_stt, seg_end in subseg_time_stamp_list: seg_stt_fr, seg_end_fr = int(seg_stt * self.frame_per_sec), int(seg_end * self.frame_per_sec) soft_label_vec_sess = torch.sum(fr_level_target[seg_stt_fr:seg_end_fr, :], axis=0) / ( seg_end_fr - seg_stt_fr @@ -321,7 +469,8 @@ def parse_rttm_for_ms_targets(self, sample): Args: sample: - `DiarizationSpeechLabel` instance containing sample information such as audio filepath and RTTM filepath. + `DiarizationSpeechLabel` instance containing sample information such as + audio filepath and RTTM filepath. target_spks (tuple): Speaker indices that are generated from combinations. If there are only one or two speakers, only a single target_spks tuple is generated. @@ -336,9 +485,10 @@ def parse_rttm_for_ms_targets(self, sample): multiscale embeddings to form an input matrix for the MSDD model. """ - rttm_lines = open(sample.rttm_file).readlines() + with open(sample.rttm_file, 'r') as file: + rttm_lines = file.readlines() uniq_id = self.get_uniq_id_with_range(sample) - rttm_timestamps = extract_seg_info_from_rttm(uniq_id, rttm_lines) + rttm_timestamps = extract_seg_info_from_rttm(rttm_lines) fr_level_target = assign_frame_level_spk_vector( rttm_timestamps, self.round_digits, self.frame_per_sec, target_spks=sample.target_spks ) @@ -370,14 +520,14 @@ def get_uniq_id_with_range(self, sample, deci=3): def get_ms_seg_timestamps(self, sample): """ - Get start and end time of segments in each scale. + Get start and end time of each diarization frame. Args: sample: `DiarizationSpeechLabel` instance from preprocessing.collections Returns: ms_seg_timestamps (torch.tensor): - Tensor containing Multiscale segment timestamps. + Tensor containing timestamps for each frame. ms_seg_counts (torch.tensor): Number of segments for each scale. This information is used for reshaping embedding batch during forward propagation. @@ -441,7 +591,8 @@ class _AudioMSDDInferDataset(Dataset): emb_dict (dict): Dictionary containing cluster-average embeddings and speaker mapping information. emb_seq (dict): - Dictionary containing multiscale speaker embedding sequence, scale mapping and corresponding segment timestamps. + Dictionary containing multiscale speaker embedding sequence, + scale mapping and corresponding segment timestamps. clus_label_dict (dict): Subsegment-level (from base-scale) speaker labels from clustering results. 
soft_label_thres (float): @@ -496,7 +647,7 @@ def __init__( self.emb_seq = emb_seq self.clus_label_dict = clus_label_dict self.round_digits = 2 - self.decim = 10 ** self.round_digits + self.decim = 10**self.round_digits self.frame_per_sec = int(1 / window_stride) self.soft_label_thres = soft_label_thres self.pairwise_infer = pairwise_infer @@ -529,20 +680,20 @@ def parse_rttm_multiscale(self, sample): rttm_lines = open(sample.rttm_file).readlines() uniq_id = os.path.splitext(os.path.basename(sample.rttm_file))[0] mapping_dict = self.emb_dict[max(self.emb_dict.keys())][uniq_id]['mapping'] - rttm_timestamps = extract_seg_info_from_rttm(uniq_id, rttm_lines, mapping_dict, sample.target_spks) + rttm_timestamps = extract_seg_info_from_rttm(rttm_lines, mapping_dict, sample.target_spks) fr_level_target = assign_frame_level_spk_vector( rttm_timestamps, self.round_digits, self.frame_per_sec, sample.target_spks ) seg_target = self.get_diar_target_labels_from_fr_target(uniq_id, fr_level_target) return seg_target - def get_diar_target_labels_from_fr_target(self, uniq_id, fr_level_target): + def get_diar_target_labels_from_fr_target(self, uniq_id: str, fr_level_target: torch.Tensor) -> torch.Tensor: """ Generate base-scale level binary diarization label from frame-level target matrix. For the given frame-level speaker target matrix fr_level_target, we count the number of frames that belong to each speaker and calculate - ratios for each speaker into the `soft_label_vec` variable. Finally, `soft_label_vec` variable is compared with `soft_label_thres` - to determine whether a label vector should contain 0 or 1 for each speaker bin. Note that seg_target variable has - dimension of (number of base-scale segments x 2) dimension. + ratios for each speaker into the `soft_label_vec` variable. Finally, `soft_label_vec` variable is compared + with `soft_label_thres` to determine whether a label vector should contain 0 or 1 for each speaker bin. + Note that seg_target variable has dimension of (number of base-scale segments x 2) dimension. Example of seg_target: [[0., 1.], [0., 1.], [1., 1.], [1., 0.], [1., 0.], ..., [0., 1.]] @@ -562,7 +713,7 @@ def get_diar_target_labels_from_fr_target(self, uniq_id, fr_level_target): return None else: seg_target_list = [] - for (seg_stt, seg_end, label_int) in self.clus_label_dict[uniq_id]: + for seg_stt, seg_end, label_int in self.clus_label_dict[uniq_id]: seg_stt_fr, seg_end_fr = int(seg_stt * self.frame_per_sec), int(seg_end * self.frame_per_sec) soft_label_vec = torch.sum(fr_level_target[seg_stt_fr:seg_end_fr, :], axis=0) / ( seg_end_fr - seg_stt_fr @@ -588,7 +739,8 @@ def __getitem__(self, index): if avg_embs.shape[2] > self.max_spks: raise ValueError( - f" avg_embs.shape[2] {avg_embs.shape[2]} should be less than or equal to self.max_num_speakers {self.max_spks}" + f" avg_embs.shape[2] {avg_embs.shape[2]} should be less than or equal to " + f"self.max_num_speakers {self.max_spks}" ) feats = [] @@ -682,7 +834,8 @@ def _msdd_train_collate_fn(self, batch): def _msdd_infer_collate_fn(self, batch): """ - Collate batch of feats (speaker embeddings), feature lengths, target label sequences and cluster-average embeddings. + Collate batch of feats (speaker embeddings), feature lengths, target label sequences + and cluster-average embeddings. 
Args: batch (tuple): @@ -784,6 +937,7 @@ def __init__( ) def msdd_train_collate_fn(self, batch): + """Collate batch of audio features, feature lengths, target label sequences for training.""" return _msdd_train_collate_fn(self, batch) @@ -805,11 +959,13 @@ class AudioToSpeechMSDDInferDataset(_AudioMSDDInferDataset): emb_dict (dict): Dictionary containing cluster-average embeddings and speaker mapping information. emb_seq (dict): - Dictionary containing multiscale speaker embedding sequence, scale mapping and corresponding segment timestamps. + Dictionary containing multiscale speaker embedding sequence, scale mapping + and corresponding segment timestamps. clus_label_dict (dict): Subsegment-level (from base-scale) speaker labels from clustering results. soft_label_thres (float): - Threshold that determines speaker labels of segments depending on the overlap with groundtruth speaker timestamps. + Threshold that determines speaker labels of segments depending on the overlap + with groundtruth speaker timestamps. featurizer: Featurizer instance for generating features from raw waveform. use_single_scale_clus (bool): @@ -817,11 +973,12 @@ class AudioToSpeechMSDDInferDataset(_AudioMSDDInferDataset): seq_eval_mode (bool): If True, F1 score will be calculated for each speaker pair during inference mode. window_stride (float): - Window stride for acoustic feature. This value is used for calculating the numbers of feature-level frames. + Window stride for acoustic feature. This value is used for calculating the numbers of + feature-level frames. pairwise_infer (bool): - If True, this Dataset class operates in inference mode. In inference mode, a set of speakers in the input audio - is split into multiple pairs of speakers and speaker tuples (e.g. 3 speakers: [(0,1), (1,2), (0,2)]) and then - fed into the MSDD to merge the individual results. + If True, this Dataset class operates in inference mode. In inference mode, a set of speakers + in the input audio is split into multiple pairs of speakers and speaker tuples + (e.g. 3 speakers: [(0,1), (1,2), (0,2)]) and then fed into the MSDD to merge the individual results. """ def __init__( @@ -850,4 +1007,366 @@ def __init__( ) def msdd_infer_collate_fn(self, batch): + """Collate batch of audio features, feature lengths, target label sequences for inference.""" return _msdd_infer_collate_fn(self, batch) + + +class _AudioToSpeechE2ESpkDiarDataset(Dataset): + """ + Dataset class that loads a json file containing paths to audio files, + RTTM files and number of speakers. This Dataset class is designed for + training or fine-tuning speaker embedding extractor and diarization decoder + at the same time. + + Example: + {"audio_filepath": "/path/to/audio_0.wav", "num_speakers": 2, + "rttm_filepath": "/path/to/diar_label_0.rttm} + ... + {"audio_filepath": "/path/to/audio_n.wav", "num_speakers": 2, + "rttm_filepath": "/path/to/diar_label_n.rttm} + + Args: + manifest_filepath (str): + Path to input manifest json files. + multiargs_dict (dict): + Dictionary containing the parameters for multiscale segmentation and clustering. + soft_label_thres (float): + Threshold that determines the label of each segment based on RTTM file information. + featurizer: + Featurizer instance for generating audio_signal from the raw waveform. + window_stride (float): + Window stride for acoustic feature. This value is used for calculating the numbers of feature-level frames. 
+ """ + + @property + def output_types(self) -> Optional[Dict[str, NeuralType]]: + """Returns definitions of module output ports.""" + output_types = { + "audio_signal": NeuralType(('B', 'T'), AudioSignal()), + "audio_length": NeuralType(('B'), LengthsType()), + "targets": NeuralType(('B', 'T', 'C'), ProbsType()), + "target_len": NeuralType(('B'), LengthsType()), + } + + return output_types + + def __init__( + self, + *, + manifest_filepath: str, + soft_label_thres: float, + session_len_sec: float, + num_spks: int, + featurizer, + window_stride: float, + min_subsegment_duration: float = 0.03, + global_rank: int = 0, + dtype=torch.float16, + round_digits: int = 2, + soft_targets: bool = False, + subsampling_factor: int = 8, + ): + super().__init__() + self.collection = EndtoEndDiarizationSpeechLabel( + manifests_files=manifest_filepath.split(','), + round_digits=round_digits, + ) + self.featurizer = featurizer + self.round_digits = round_digits + self.feat_per_sec = int(1 / window_stride) + self.diar_frame_length = round(subsampling_factor * window_stride, round_digits) + self.session_len_sec = session_len_sec + self.soft_label_thres = soft_label_thres + self.max_spks = num_spks + self.min_subsegment_duration = min_subsegment_duration + self.dtype = dtype + self.use_asr_style_frame_count = True + self.soft_targets = soft_targets + self.round_digits = 2 + self.floor_decimal = 10**self.round_digits + + def __len__(self): + return len(self.collection) + + def get_uniq_id_with_range(self, sample, deci=3): + """ + Generate unique training sample ID from unique file ID, offset and duration. The start-end time added + unique ID is required for identifying the sample since multiple short audio samples are generated from a single + audio file. The start time and end time of the audio stream uses millisecond units if `deci=3`. + + Args: + sample: + `DiarizationSpeechLabel` instance from collections. + + Returns: + uniq_id (str): + Unique sample ID which includes start and end time of the audio stream. + Example: abc1001_3122_6458 + """ + bare_uniq_id = os.path.splitext(os.path.basename(sample.rttm_file))[0] + offset = str(int(round(sample.offset, deci) * pow(10, deci))) + endtime = str(int(round(sample.offset + sample.duration, deci) * pow(10, deci))) + uniq_id = f"{bare_uniq_id}_{offset}_{endtime}" + return uniq_id + + def parse_rttm_for_targets_and_lens(self, rttm_file, offset, duration, target_len): + """ + Generate target tensor variable by extracting groundtruth diarization labels from an RTTM file. + This function converts (start, end, speaker_id) format into base-scale (the finest scale) segment level + diarization label in a matrix form. 
+ + Example of seg_target: + [[0., 1.], [0., 1.], [1., 1.], [1., 0.], [1., 0.], ..., [0., 1.]] + """ + with open(rttm_file, 'r') as f: + rttm_lines = f.readlines() + + rttm_timestamps, sess_to_global_spkids = extract_frame_info_from_rttm(offset, duration, rttm_lines) + + fr_level_target = get_frame_targets_from_rttm( + rttm_timestamps=rttm_timestamps, + offset=offset, + duration=duration, + round_digits=self.round_digits, + feat_per_sec=self.feat_per_sec, + max_spks=self.max_spks, + ) + + soft_target_seg = self.get_soft_targets_seg(feat_level_target=fr_level_target, target_len=target_len) + if self.soft_targets: + step_target = soft_target_seg + else: + step_target = (soft_target_seg >= self.soft_label_thres).float() + return step_target + + def get_soft_targets_seg(self, feat_level_target, target_len): + """ + Generate the final targets for the actual diarization step. + Here, frame level means step level which is also referred to as segments. + We follow the original paper and refer to the step level as "frames". + + Args: + feat_level_target (torch.tensor): + Tensor variable containing hard-labels of speaker activity in each feature-level segment. + target_len (torch.tensor): + Numbers of ms segments + + Returns: + soft_target_seg (torch.tensor): + Tensor variable containing soft-labels of speaker activity in each step-level segment. + """ + num_seg = torch.max(target_len) + targets = torch.zeros(num_seg, self.max_spks) + stride = int(self.feat_per_sec * self.diar_frame_length) + for index in range(num_seg): + if index == 0: + seg_stt_feat = 0 + else: + seg_stt_feat = stride * index - 1 - int(stride / 2) + if index == num_seg - 1: + seg_end_feat = feat_level_target.shape[0] + else: + seg_end_feat = stride * index - 1 + int(stride / 2) + targets[index] = torch.mean(feat_level_target[seg_stt_feat : seg_end_feat + 1, :], axis=0) + return targets + + def get_segment_timestamps( + self, + duration: float, + offset: float = 0, + sample_rate: int = 16000, + ): + """ + Get start and end time of segments in each scale. + + Args: + sample: + `DiarizationSpeechLabel` instance from preprocessing.collections + Returns: + segment_timestamps (torch.tensor): + Tensor containing Multiscale segment timestamps. + target_len (torch.tensor): + Number of segments for each scale. This information is used for reshaping embedding batch + during forward propagation. 
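The stride-centered averaging in `get_soft_targets_seg` can be sketched with toy tensors as follows (shapes chosen to match the 10 ms / 8x defaults; this mirrors the indexing shown above rather than replacing it):

import torch

feat_per_sec, diar_frame_length, max_spks = 100, 0.08, 4
stride = int(feat_per_sec * diar_frame_length)            # 8 feature frames per diarization step

feat_level_target = torch.rand(80, max_spks)              # toy frame-level activity (80 frames)
num_seg = feat_level_target.shape[0] // stride            # 10 step-level segments
targets = torch.zeros(num_seg, max_spks)

for index in range(num_seg):
    # Average over a window of roughly one stride centered on each step boundary.
    seg_stt_feat = 0 if index == 0 else stride * index - 1 - stride // 2
    seg_end_feat = feat_level_target.shape[0] if index == num_seg - 1 else stride * index - 1 + stride // 2
    targets[index] = torch.mean(feat_level_target[seg_stt_feat : seg_end_feat + 1, :], axis=0)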
+ """ + subsegments = get_subsegments( + offset=offset, + window=round(self.diar_frame_length * 2, self.round_digits), + shift=self.diar_frame_length, + duration=duration, + min_subsegment_duration=self.min_subsegment_duration, + use_asr_style_frame_count=self.use_asr_style_frame_count, + sample_rate=sample_rate, + feat_per_sec=self.feat_per_sec, + ) + if self.use_asr_style_frame_count: + effective_dur = ( + np.ceil((1 + duration * sample_rate) / int(sample_rate / self.feat_per_sec)).astype(int) + / self.feat_per_sec + ) + else: + effective_dur = duration + ts_tensor = get_subsegments_to_timestamps( + subsegments, self.feat_per_sec, decimals=2, max_end_ts=(offset + effective_dur) + ) + target_len = torch.tensor([ts_tensor.shape[0]]) + return target_len + + def __getitem__(self, index): + sample = self.collection[index] + if sample.offset is None: + sample.offset = 0 + offset = sample.offset + if self.session_len_sec < 0: + session_len_sec = sample.duration + else: + session_len_sec = min(sample.duration, self.session_len_sec) + + audio_signal = self.featurizer.process(sample.audio_file, offset=offset, duration=session_len_sec) + + # We should resolve the length mis-match from the round-off errors between these two variables: + # `session_len_sec` and `audio_signal.shape[0]` + session_len_sec = ( + np.floor(audio_signal.shape[0] / self.featurizer.sample_rate * self.floor_decimal) / self.floor_decimal + ) + audio_signal = audio_signal[: round(self.featurizer.sample_rate * session_len_sec)] + + audio_signal_length = torch.tensor(audio_signal.shape[0]).long() + audio_signal, audio_signal_length = audio_signal.to('cpu'), audio_signal_length.to('cpu') + target_len = self.get_segment_timestamps(duration=session_len_sec, sample_rate=self.featurizer.sample_rate) + targets = self.parse_rttm_for_targets_and_lens( + rttm_file=sample.rttm_file, offset=offset, duration=session_len_sec, target_len=target_len + ) + return audio_signal, audio_signal_length, targets, target_len + + +def _eesd_train_collate_fn(self, batch): + """ + Collate a batch of variables needed for training the end-to-end speaker diarization (EESD) model + from raw waveforms to diarization labels. The following variables are included in the training/validation batch: + + Args: + batch (tuple): + A tuple containing the variables for diarization training. + + Returns: + audio_signal (torch.Tensor): + A tensor containing the raw waveform samples (time series) loaded from the `audio_filepath` + in the input manifest file. + feature_length (torch.Tensor): + A tensor containing the lengths of the raw waveform samples. + targets (torch.Tensor): + Groundtruth speaker labels for the given input embedding sequence. + target_lens (torch.Tensor): + A tensor containing the number of segments for each sample in the batch, necessary for + reshaping inputs to the EESD model. 
+ """ + packed_batch = list(zip(*batch)) + audio_signal, feature_length, targets, target_len = packed_batch + audio_signal_list, feature_length_list = [], [] + target_len_list, targets_list = [], [] + + max_raw_feat_len = max([x.shape[0] for x in audio_signal]) + max_target_len = max([x.shape[0] for x in targets]) + if max([len(feat.shape) for feat in audio_signal]) > 1: + max_ch = max([feat.shape[1] for feat in audio_signal]) + else: + max_ch = 1 + for feat, feat_len, tgt, segment_ct in batch: + seq_len = tgt.shape[0] + if len(feat.shape) > 1: + pad_feat = (0, 0, 0, max_raw_feat_len - feat.shape[0]) + else: + pad_feat = (0, max_raw_feat_len - feat.shape[0]) + if feat.shape[0] < feat_len: + feat_len_pad = feat_len - feat.shape[0] + feat = torch.nn.functional.pad(feat, (0, feat_len_pad)) + pad_tgt = (0, 0, 0, max_target_len - seq_len) + padded_feat = torch.nn.functional.pad(feat, pad_feat) + padded_tgt = torch.nn.functional.pad(tgt, pad_tgt) + if max_ch > 1 and padded_feat.shape[1] < max_ch: + feat_ch_pad = max_ch - padded_feat.shape[1] + padded_feat = torch.nn.functional.pad(padded_feat, (0, feat_ch_pad)) + audio_signal_list.append(padded_feat) + feature_length_list.append(feat_len.clone().detach()) + target_len_list.append(segment_ct.clone().detach()) + targets_list.append(padded_tgt) + audio_signal = torch.stack(audio_signal_list) + feature_length = torch.stack(feature_length_list) + target_lens = torch.stack(target_len_list).squeeze(1) + targets = torch.stack(targets_list) + return audio_signal, feature_length, targets, target_lens + + +class AudioToSpeechE2ESpkDiarDataset(_AudioToSpeechE2ESpkDiarDataset): + """ + Dataset class for loading a JSON file containing paths to audio files, + RTTM (Rich Transcription Time Marked) files, and the number of speakers. + This class is designed for training or fine-tuning a speaker embedding + extractor and diarization decoder simultaneously. + + The JSON manifest file should have entries in the following format: + + Example: + { + "audio_filepath": "/path/to/audio_0.wav", + "num_speakers": 2, + "rttm_filepath": "/path/to/diar_label_0.rttm" + } + ... + { + "audio_filepath": "/path/to/audio_n.wav", + "num_speakers": 2, + "rttm_filepath": "/path/to/diar_label_n.rttm" + } + + Args: + manifest_filepath (str): + Path to the input manifest JSON file containing paths to audio and RTTM files. + soft_label_thres (float): + Threshold for assigning soft labels to segments based on RTTM file information. + session_len_sec (float): + Duration of each session (in seconds) for training or fine-tuning. + num_spks (int): + Number of speakers in the audio files. + featurizer: + Instance of a featurizer for generating features from the raw waveform. + window_stride (float): + Window stride (in seconds) for extracting acoustic features, used to calculate + the number of feature frames. + global_rank (int): + Global rank of the current process (used for distributed training). + soft_targets (bool): + Whether or not to use soft targets during training. + + Methods: + eesd_train_collate_fn(batch): + Collates a batch of data for end-to-end speaker diarization training. 
+ """ + + def __init__( + self, + *, + manifest_filepath: str, + soft_label_thres: float, + session_len_sec: float, + num_spks: int, + featurizer, + window_stride, + global_rank: int, + soft_targets: bool, + ): + super().__init__( + manifest_filepath=manifest_filepath, + soft_label_thres=soft_label_thres, + session_len_sec=session_len_sec, + num_spks=num_spks, + featurizer=featurizer, + window_stride=window_stride, + global_rank=global_rank, + soft_targets=soft_targets, + ) + + def eesd_train_collate_fn(self, batch): + """Collate a batch of data for end-to-end speaker diarization training.""" + return _eesd_train_collate_fn(self, batch) diff --git a/nemo/collections/asr/data/audio_to_diar_label_lhotse.py b/nemo/collections/asr/data/audio_to_diar_label_lhotse.py new file mode 100644 index 0000000000000..927e3887de78f --- /dev/null +++ b/nemo/collections/asr/data/audio_to_diar_label_lhotse.py @@ -0,0 +1,82 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, Optional, Tuple + +import torch.utils.data +from lhotse.dataset import AudioSamples +from lhotse.dataset.collation import collate_matrices + +from nemo.collections.asr.parts.utils.asr_multispeaker_utils import ( + get_hidden_length_from_sample_length, + speaker_to_target, +) +from nemo.core.neural_types import AudioSignal, LabelsType, LengthsType, NeuralType + + +class LhotseAudioToSpeechE2ESpkDiarDataset(torch.utils.data.Dataset): + """ + This dataset is a Lhotse version of diarization dataset in audio_to_diar_label.py. + Unlike native NeMo datasets, Lhotse dataset defines only the mapping from + a CutSet (meta-data) to a mini-batch with PyTorch tensors. + Specifically, it performs tokenization, I/O, augmentation, and feature extraction (if any). + Managing data, sampling, de-duplication across workers/nodes etc. is all handled + by Lhotse samplers instead. 
+ """ + + @property + def output_types(self) -> Optional[Dict[str, NeuralType]]: + """Define the output types of the dataset.""" + return { + 'audio_signal': NeuralType(('B', 'T'), AudioSignal()), + 'a_sig_length': NeuralType(tuple('B'), LengthsType()), + 'targets': NeuralType(('B', 'T', 'N'), LabelsType()), + 'target_length': NeuralType(tuple('B'), LengthsType()), + 'sample_id': NeuralType(tuple('B'), LengthsType(), optional=True), + } + + def __init__(self, cfg): + super().__init__() + self.load_audio = AudioSamples(fault_tolerant=True) + self.cfg = cfg + self.num_speakers = self.cfg.get('num_speakers', 4) + self.num_sample_per_mel_frame = int( + self.cfg.get('window_stride', 0.01) * self.cfg.get('sample_rate', 16000) + ) # 160 samples for every 1ms by default + self.num_mel_frame_per_target_frame = int(self.cfg.get('subsampling_factor', 8)) + self.spk_tar_all_zero = self.cfg.get('spk_tar_all_zero', False) + + def __getitem__(self, cuts) -> Tuple[torch.Tensor, ...]: + audio, audio_lens, cuts = self.load_audio(cuts) + speaker_activities = [] + for cut in cuts: + speaker_activity = speaker_to_target( + a_cut=cut, + num_speakers=self.num_speakers, + num_sample_per_mel_frame=self.num_sample_per_mel_frame, + num_mel_frame_per_asr_frame=self.num_mel_frame_per_target_frame, + spk_tar_all_zero=self.spk_tar_all_zero, + boundary_segments=True, + ) + speaker_activities.append(speaker_activity) + targets = collate_matrices(speaker_activities).to(audio.dtype) + target_lens_list = [] + for audio_len in audio_lens: + target_fr_len = get_hidden_length_from_sample_length( + audio_len, self.num_sample_per_mel_frame, self.num_mel_frame_per_target_frame + ) + target_lens_list.append([target_fr_len]) + target_lens = torch.tensor(target_lens_list) + + return audio, audio_lens, targets, target_lens diff --git a/nemo/collections/asr/data/audio_to_text_dataset.py b/nemo/collections/asr/data/audio_to_text_dataset.py index f91710de3cb3b..3e1301dd4d538 100644 --- a/nemo/collections/asr/data/audio_to_text_dataset.py +++ b/nemo/collections/asr/data/audio_to_text_dataset.py @@ -871,7 +871,6 @@ def write_on_batch_end( item["audio_filepath"] = sample.recording.sources[0].source else: item["audio_filepath"] = sample.id - item["audio_filepath"] = sample.recording.sources[0].source item["offset"] = sample.start item["duration"] = sample.duration item["text"] = sample.supervisions[0].text or '' diff --git a/nemo/collections/asr/losses/__init__.py b/nemo/collections/asr/losses/__init__.py index 756a071178d72..f88bd49d1f7b3 100644 --- a/nemo/collections/asr/losses/__init__.py +++ b/nemo/collections/asr/losses/__init__.py @@ -13,6 +13,7 @@ # limitations under the License. from nemo.collections.asr.losses.angularloss import AngularSoftmaxLoss +from nemo.collections.asr.losses.bce_loss import BCELoss from nemo.collections.asr.losses.ctc import CTCLoss from nemo.collections.asr.losses.lattice_losses import LatticeLoss from nemo.collections.asr.losses.ssl_losses.contrastive import ContrastiveLoss diff --git a/nemo/collections/asr/losses/bce_loss.py b/nemo/collections/asr/losses/bce_loss.py index 30e31b8610ec6..36a7a0166f266 100644 --- a/nemo/collections/asr/losses/bce_loss.py +++ b/nemo/collections/asr/losses/bce_loss.py @@ -28,12 +28,11 @@ class BCELoss(Loss, Typing): @property def input_types(self): - """Input types definitions for AnguarLoss. 
- """ + """Input types definitions for AnguarLoss.""" return { "probs": NeuralType(('B', 'T', 'C'), ProbsType()), 'labels': NeuralType(('B', 'T', 'C'), LabelsType()), - "signal_lengths": NeuralType(tuple('B'), LengthsType()), + "target_lens": NeuralType(('B'), LengthsType()), } @property @@ -43,31 +42,94 @@ def output_types(self): """ return {"loss": NeuralType(elements_type=LossType())} - def __init__(self, reduction='sum', alpha=1.0, weight=torch.tensor([0.5, 0.5])): + def __init__( + self, + reduction: str = 'mean', + alpha: float = 1.0, + weight: torch.Tensor = torch.tensor([0.1, 0.9]), + sorted_preds: bool = False, + sorted_loss: bool = False, + class_normalization: bool = False, + ): + """ + A custom loss function that supports class normalization, + weighted binary cross-entropy, and optional sorting. + + Args: + reduction (str): Specifies the reduction to apply to the output, + options are 'mean', 'sum', or 'none'. Default is 'mean'. + alpha (float): Scaling factor for loss (unused in this implementation). Default is 1.0. + weight (torch.Tensor): Class weights for the binary cross-entropy loss. Default is [0.1, 0.9]. + sorted_preds (bool): If True, assumes predictions are sorted. Default is False. + sorted_loss (bool): If True, sorts the loss before reduction. Default is False. + class_normalization (bool): If True, uses 'none' reduction for per-class loss. Default is False. + """ super().__init__() - self.reduction = reduction + self.class_normalization = class_normalization + if class_normalization: + self.reduction = 'none' + else: + self.reduction = 'mean' self.loss_weight = weight - self.loss_f = torch.nn.BCELoss(weight=self.loss_weight, reduction=self.reduction) + self.loss_f = torch.nn.BCELoss(reduction=self.reduction) + self.sorted_preds = sorted_preds + self.sorted_loss = sorted_loss + self.eps = 1e-6 @typecheck() - def forward(self, probs, labels, signal_lengths): + def forward(self, probs, labels, target_lens): """ - Calculate binary cross entropy loss based on probs, labels and signal_lengths variables. + Calculate binary cross entropy loss based on probs, labels and target_lens variables. Args: probs (torch.tensor) Predicted probability value which ranges from 0 to 1. Sigmoid output is expected. labels (torch.tensor) Groundtruth label for the predicted samples. - signal_lengths (torch.tensor): + target_lens (torch.tensor): The actual length of the sequence without zero-padding. Returns: loss (NeuralType) Binary cross entropy loss value. 
""" - probs_list = [probs[k, : signal_lengths[k], :] for k in range(probs.shape[0])] - targets_list = [labels[k, : signal_lengths[k], :] for k in range(labels.shape[0])] + probs_list = [probs[k, : target_lens[k], :] for k in range(probs.shape[0])] + targets_list = [labels[k, : target_lens[k], :] for k in range(labels.shape[0])] probs = torch.cat(probs_list, dim=0) labels = torch.cat(targets_list, dim=0) - return self.loss_f(probs, labels) + norm_weight = torch.zeros_like(labels).detach().clone() + loss = torch.tensor(0.0).to(labels.device) + + if self.class_normalization in ['class', 'class_binary', 'binary']: + if self.class_normalization in ['class', 'class_binary']: + # Normalize loss by number of classes + norm_weight = 1 / (labels.sum(dim=0) + self.eps) + norm_weight_norm = norm_weight / norm_weight.sum() + norm_weight_norm = torch.clamp(norm_weight_norm, min=0.05, max=1.0) + norm_weight_norm = norm_weight_norm / norm_weight_norm.max() + norm_weight = norm_weight_norm[None, :].expand_as(labels).detach().clone() + else: + norm_weight = torch.ones_like(labels).detach().clone() + + if self.class_normalization in ['binary', 'class_binary']: + binary_weight = torch.ones_like(labels).detach().clone() + one_weight = (labels.sum() / (labels.shape[0] * labels.shape[1])).to(labels.device) + binary_weight[labels == 0] = one_weight + binary_weight[labels == 1] = 1 - one_weight + else: + binary_weight = torch.ones_like(labels).detach().clone() + + elif self.class_normalization == 'none' or not self.class_normalization: + binary_weight = torch.ones_like(labels).detach().clone() + norm_weight = torch.ones_like(labels).detach().clone() + + if self.reduction == 'sum': + loss = self.loss_f(probs, labels) + elif self.reduction == 'mean': + loss = self.loss_f(probs, labels).mean() + elif self.reduction == 'none': + if self.class_normalization in ['class', 'class_binary', 'binary']: + loss = (binary_weight * norm_weight * self.loss_f(probs, labels)).sum() + else: + loss = self.loss_f(probs, labels) + return loss diff --git a/nemo/collections/asr/metrics/der.py b/nemo/collections/asr/metrics/der.py index fc5cded970d07..c8dec24eaaca0 100644 --- a/nemo/collections/asr/metrics/der.py +++ b/nemo/collections/asr/metrics/der.py @@ -36,12 +36,12 @@ def get_partial_ref_labels(pred_labels: List[str], ref_labels: List[str]) -> List[str]: """ - For evaluation of online diarization performance, generate partial reference labels + For evaluation of online diarization performance, generate partial reference labels from the last prediction time. Args: pred_labels (list[str]): list of partial prediction labels - ref_labels (list[str]): list of full reference labels + ref_labels (list[str]): list of full reference labels Returns: ref_labels_out (list[str]): list of partial reference labels @@ -84,8 +84,8 @@ def get_online_DER_stats( For evaluation of online diarization performance, add cumulative, average, and maximum DER/CER. 
Args: - DER (float): Diarization Error Rate from the start to the current point - CER (float): Confusion Error Rate from the start to the current point + DER (float): Diarization Error Rate from the start to the current point + CER (float): Confusion Error Rate from the start to the current point FA (float): False Alarm from the start to the current point MISS (float): Miss rate from the start to the current point diar_eval_count (int): Number of evaluation sessions @@ -123,30 +123,45 @@ def uem_timeline_from_file(uem_file, uniq_name=''): lines = f.readlines() for line in lines: line = line.strip() - speaker_id, channel, start_time, end_time = line.split() + _, _, start_time, end_time = line.split() timeline.add(Segment(float(start_time), float(end_time))) return timeline def score_labels( - AUDIO_RTTM_MAP, all_reference, all_hypothesis, collar=0.25, ignore_overlap=True, verbose: bool = True + AUDIO_RTTM_MAP, + all_reference: list, + all_hypothesis: list, + all_uem: List[List[float]] = None, + collar: float = 0.25, + ignore_overlap: bool = True, + verbose: bool = True, ) -> Optional[Tuple[DiarizationErrorRate, Dict]]: """ Calculate DER, CER, FA and MISS rate from hypotheses and references. Hypothesis results are coming from Pyannote-formatted speaker diarization results and References are coming from Pyannote-formatted RTTM data. - Args: AUDIO_RTTM_MAP (dict): Dictionary containing information provided from manifestpath all_reference (list[uniq_name,Annotation]): reference annotations for score calculation all_hypothesis (list[uniq_name,Annotation]): hypothesis annotations for score calculation - verbose (bool): Warns if RTTM file is not found. + all_uem (list[list[float]]): List of UEM segments for each audio file. If UEM file is not provided, + it will be read from manifestpath + collar (float): Length of collar (in seconds) for diarization error rate calculation + ignore_overlap (bool): If True, overlapping segments in reference and hypothesis will be ignored + verbose (bool): If True, warning messages will be printed Returns: - metric (pyannote.DiarizationErrorRate): Pyannote Diarization Error Rate metric object. This object contains detailed scores of each audiofile. + metric (pyannote.DiarizationErrorRate): Pyannote Diarization Error Rate metric object. + This object contains detailed scores of each audiofile. mapping (dict): Mapping dict containing the mapping speaker label for each audio input + itemized_errors (tuple): Tuple containing (DER, CER, FA, MISS) for each audio file. + - DER: Diarization Error Rate, which is sum of all three errors, CER + FA + MISS. 
+ - CER: Confusion Error Rate, which is sum of all errors + - FA: False Alarm Rate, which is the number of false alarm segments + - MISS: Missed Detection Rate, which is the number of missed detection segments < Caveat > Unlike md-eval.pl, "no score" collar in pyannote.metrics is the maximum length of @@ -157,33 +172,51 @@ def score_labels( if len(all_reference) == len(all_hypothesis): metric = DiarizationErrorRate(collar=2 * collar, skip_overlap=ignore_overlap) - mapping_dict = {} - for (reference, hypothesis) in zip(all_reference, all_hypothesis): + mapping_dict, correct_spk_count = {}, 0 + for idx, (reference, hypothesis) in enumerate(zip(all_reference, all_hypothesis)): ref_key, ref_labels = reference _, hyp_labels = hypothesis - uem = AUDIO_RTTM_MAP[ref_key].get('uem_filepath', None) - if uem is not None: - uem = uem_timeline_from_file(uem_file=uem, uniq_name=ref_key) - metric(ref_labels, hyp_labels, uem=uem, detailed=True) + if len(ref_labels.labels()) == len(hyp_labels.labels()): + correct_spk_count += 1 + if verbose and len(ref_labels.labels()) != len(hyp_labels.labels()): + logging.info( + f"Wrong Spk. Count with uniq_id:...{ref_key[-10:]}, " + f"Ref: {len(ref_labels.labels())}, Hyp: {len(hyp_labels.labels())}" + ) + uem_obj = None + if all_uem is not None: + metric(ref_labels, hyp_labels, uem=all_uem[idx], detailed=True) + elif AUDIO_RTTM_MAP[ref_key].get('uem_filepath', None) is not None: + uem_file = AUDIO_RTTM_MAP[ref_key].get('uem_filepath', None) + uem_obj = uem_timeline_from_file(uem_file=uem_file, uniq_name=ref_key) + metric(ref_labels, hyp_labels, uem=uem_obj, detailed=True) + else: + metric(ref_labels, hyp_labels, detailed=True) mapping_dict[ref_key] = metric.optimal_mapping(ref_labels, hyp_labels) + spk_count_acc = correct_spk_count / len(all_reference) DER = abs(metric) + if metric['total'] == 0: + raise ValueError("Total evaluation time is 0. Abort.") CER = metric['confusion'] / metric['total'] FA = metric['false alarm'] / metric['total'] MISS = metric['missed detection'] / metric['total'] + itemized_errors = (DER, CER, FA, MISS) + if verbose: + logging.info(f"\n{metric.report()}") logging.info( - "Cumulative Results for collar {} sec and ignore_overlap {}: \n FA: {:.4f}\t MISS {:.4f}\t \ - Diarization ER: {:.4f}\t, Confusion ER:{:.4f}".format( - collar, ignore_overlap, FA, MISS, DER, CER - ) + f"Cumulative Results for collar {collar} sec and ignore_overlap {ignore_overlap}: \n" + f"| FA: {FA:.4f} | MISS: {MISS:.4f} | CER: {CER:.4f} | DER: {DER:.4f} | " + f"Spk. Count Acc. {spk_count_acc:.4f}\n" ) return metric, mapping_dict, itemized_errors elif verbose: logging.warning( - "Check if each ground truth RTTMs were present in the provided manifest file. Skipping calculation of Diariazation Error Rate" + "Check if each ground truth RTTMs were present in the provided manifest file. 
" + "Skipping calculation of Diariazation Error Rate" ) return None @@ -365,7 +398,7 @@ def calculate_session_cpWER( # Calculate WER for each speaker in hypothesis with reference # There are (number of hyp speakers) x (number of ref speakers) combinations lsa_wer_list = [] - for (spk_hyp_trans, spk_ref_trans) in all_pairs: + for spk_hyp_trans, spk_ref_trans in all_pairs: spk_wer = word_error_rate(hypotheses=[spk_hyp_trans], references=[spk_ref_trans]) lsa_wer_list.append(spk_wer) @@ -419,7 +452,7 @@ def concat_perm_word_error_rate( f"{len(spk_hypotheses)} and {len(spk_references)} correspondingly" ) cpWER_values, hyps_spk, refs_spk = [], [], [] - for (spk_hypothesis, spk_reference) in zip(spk_hypotheses, spk_references): + for spk_hypothesis, spk_reference in zip(spk_hypotheses, spk_references): cpWER, min_hypothesis, concat_reference = calculate_session_cpWER(spk_hypothesis, spk_reference) cpWER_values.append(cpWER) hyps_spk.append(min_hypothesis) diff --git a/nemo/collections/asr/metrics/multi_binary_acc.py b/nemo/collections/asr/metrics/multi_binary_acc.py index 8cc21c53ad821..7b2b9148a74e5 100644 --- a/nemo/collections/asr/metrics/multi_binary_acc.py +++ b/nemo/collections/asr/metrics/multi_binary_acc.py @@ -73,13 +73,26 @@ def on_validation_epoch_end(self): def __init__(self, dist_sync_on_step=False): super().__init__(dist_sync_on_step=dist_sync_on_step) - self.total_correct_counts = 0 - self.total_sample_counts = 0 - self.true_positive_count = 0 - self.false_positive_count = 0 - self.false_negative_count = 0 + self.add_state("total_correct_counts", default=torch.tensor(0), dist_reduce_fx='sum', persistent=False) + self.add_state("total_sample_counts", default=torch.tensor(0), dist_reduce_fx='sum', persistent=False) + self.add_state("true_positive_count", default=torch.tensor(0), dist_reduce_fx='sum', persistent=False) + self.add_state("false_positive_count", default=torch.tensor(0), dist_reduce_fx='sum', persistent=False) + self.add_state("false_negative_count", default=torch.tensor(0), dist_reduce_fx='sum', persistent=False) + self.add_state("positive_count", default=torch.tensor(0), dist_reduce_fx='sum', persistent=False) + self.eps = 1e-6 + + def update( + self, preds: torch.Tensor, targets: torch.Tensor, signal_lengths: torch.Tensor, cumulative=False + ) -> torch.Tensor: + """ + Update the metric with the given predictions, targets, and signal lengths to the metric instance. - def update(self, preds: torch.Tensor, targets: torch.Tensor, signal_lengths: torch.Tensor) -> torch.Tensor: + Args: + preds (torch.Tensor): Predicted values. + targets (torch.Tensor): Target values. + signal_lengths (torch.Tensor): Length of each sequence in the batch input. + cumulative (bool): Whether to accumulate the values over time. 
+ """ with torch.no_grad(): preds_list = [preds[k, : signal_lengths[k], :] for k in range(preds.shape[0])] targets_list = [targets[k, : signal_lengths[k], :] for k in range(targets.shape[0])] @@ -91,22 +104,35 @@ def update(self, preds: torch.Tensor, targets: torch.Tensor, signal_lengths: tor self.positive = self.preds.round().bool() == 1 self.negative = self.preds.round().bool() == 0 - self.positive_count = torch.sum(self.preds.round().bool() == True) - self.true_positive_count += torch.sum(torch.logical_and(self.true, self.positive)) - self.false_positive_count += torch.sum(torch.logical_and(self.false, self.positive)) - self.false_negative_count += torch.sum(torch.logical_and(self.false, self.negative)) - - self.total_correct_counts += torch.sum(self.preds.round().bool() == self.targets.round().bool()) - self.total_sample_counts += torch.prod(torch.tensor(self.targets.shape)) + if cumulative: + self.positive_count += torch.sum(self.preds.round().bool() == True) + self.true_positive_count += torch.sum(torch.logical_and(self.true, self.positive)) + self.false_positive_count += torch.sum(torch.logical_and(self.false, self.positive)) + self.false_negative_count += torch.sum(torch.logical_and(self.false, self.negative)) + self.total_correct_counts += torch.sum(self.preds.round().bool() == self.targets.round().bool()) + self.total_sample_counts += torch.prod(torch.tensor(self.targets.shape)) + else: + self.positive_count = torch.sum(self.preds.round().bool() == True) + self.true_positive_count = torch.sum(torch.logical_and(self.true, self.positive)) + self.false_positive_count = torch.sum(torch.logical_and(self.false, self.positive)) + self.false_negative_count = torch.sum(torch.logical_and(self.false, self.negative)) + self.total_correct_counts = torch.sum(self.preds.round().bool() == self.targets.round().bool()) + self.total_sample_counts = torch.prod(torch.tensor(self.targets.shape)) def compute(self): """ Compute F1 score from the accumulated values. Return -1 if the F1 score is NaN. + + Returns: + f1_score (torch.Tensor): F1 score calculated from the accumulated values. + precision (torch.Tensor): Precision calculated from the accumulated values. + recall (torch.Tensor): Recall calculated from the accumulated values. """ - self.precision = self.true_positive_count / (self.true_positive_count + self.false_positive_count) - self.recall = self.true_positive_count / (self.true_positive_count + self.false_negative_count) - self.f1_score = 2 * self.precision * self.recall / (self.precision + self.recall) - if torch.isnan(self.f1_score): + precision = self.true_positive_count / (self.true_positive_count + self.false_positive_count + self.eps) + recall = self.true_positive_count / (self.true_positive_count + self.false_negative_count + self.eps) + f1_score = (2 * precision * recall / (precision + recall + self.eps)).detach().clone() + + if torch.isnan(f1_score): logging.warn("self.f1_score contains NaN value. 
Returning -1 instead of NaN value.") - self.f1_score = -1 - return self.f1_score + f1_score = -1 + return f1_score.float(), precision.float(), recall.float() diff --git a/nemo/collections/asr/models/__init__.py b/nemo/collections/asr/models/__init__.py index e4a1342b9c361..34dead15b33d5 100644 --- a/nemo/collections/asr/models/__init__.py +++ b/nemo/collections/asr/models/__init__.py @@ -35,6 +35,7 @@ from nemo.collections.asr.models.rnnt_bpe_models import EncDecRNNTBPEModel from nemo.collections.asr.models.rnnt_models import EncDecRNNTModel from nemo.collections.asr.models.slu_models import SLUIntentSlotBPEModel +from nemo.collections.asr.models.sortformer_diar_models import SortformerEncLabelModel from nemo.collections.asr.models.ssl_models import ( EncDecDenoiseMaskedTokenPredModel, EncDecMaskedTokenPredModel, diff --git a/nemo/collections/asr/models/msdd_models.py b/nemo/collections/asr/models/msdd_models.py index d30411f01bcc4..5e90f7d62d786 100644 --- a/nemo/collections/asr/models/msdd_models.py +++ b/nemo/collections/asr/models/msdd_models.py @@ -70,6 +70,7 @@ @contextmanager def autocast(enabled=None): + """auto-casting context manager""" yield @@ -78,8 +79,8 @@ def autocast(enabled=None): class EncDecDiarLabelModel(ModelPT, ExportableEncDecModel): """ - Encoder decoder class for multiscale diarization decoder (MSDD). Model class creates training, validation methods for setting - up data performing model forward pass. + Encoder decoder class for multiscale diarization decoder (MSDD). Model class creates training, + validation methods for setting up data performing model forward pass. This model class expects config dict for: * preprocessor @@ -99,15 +100,18 @@ def list_available_models(cls) -> List[PretrainedModelInfo]: model = PretrainedModelInfo( pretrained_model_name="diar_msdd_telephonic", - location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/diar_msdd_telephonic/versions/1.0.1/files/diar_msdd_telephonic.nemo", - description="For details about this model, please visit https://ngc.nvidia.com/catalog/models/nvidia:nemo:diar_msdd_telephonic", + location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/" + "diar_msdd_telephonic/versions/1.0.1/files/diar_msdd_telephonic.nemo", + description="For details about this model, please visit " + "https://ngc.nvidia.com/catalog/models/nvidia:nemo:diar_msdd_telephonic", ) result.append(model) return result def __init__(self, cfg: DictConfig, trainer: Trainer = None): """ - Initialize an MSDD model and the specified speaker embedding model. In this init function, training and validation datasets are prepared. + Initialize an MSDD model and the specified speaker embedding model. In this init function, + training and validation datasets are prepared. """ self._trainer = trainer if trainer else None self.cfg_msdd_model = cfg @@ -173,9 +177,9 @@ def _init_segmentation_info(self): def _init_speaker_model(self): """ - Initialize speaker embedding model with model name or path passed through config. Note that speaker embedding model is loaded to - `self.msdd` to enable multi-gpu and multi-node training. In addition, speaker embedding model is also saved with msdd model when - `.ckpt` files are saved. + Initialize speaker embedding model with model name or path passed through config. Note that + speaker embedding model is loaded to `self.msdd` to enable multi-gpu and multi-node training. + In addition, speaker embedding model is also saved with msdd model when `.ckpt` files are saved. 
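Stepping back to the `MultiBinaryAccuracy.compute` change above, the smoothed scores reduce to simple ratios over the accumulated counts (toy counts; `eps` only guards against division by zero):

import torch

eps = 1e-6
tp, fp, fn = torch.tensor(420.0), torch.tensor(35.0), torch.tensor(45.0)

precision = tp / (tp + fp + eps)                                   # ~0.923
recall = tp / (tp + fn + eps)                                      # ~0.903
f1_score = 2 * precision * recall / (precision + recall + eps)     # ~0.913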
""" model_path = self.cfg_msdd_model.diarizer.speaker_embeddings.model_path self._diarizer_params = self.cfg_msdd_model.diarizer @@ -341,15 +345,17 @@ def get_ms_emb_seq( Merged embeddings without zero-padding in the batch. See `ms_seg_counts` for details. Shape: (Total number of segments in the batch, emb_dim) scale_mapping (Tensor): - The element at the m-th row and the n-th column of the scale mapping matrix indicates the (m+1)-th scale - segment index which has the closest center distance with (n+1)-th segment in the base scale. + The element at the m-th row and the n-th column of the scale mapping matrix indicates + the (m+1)-th scale segment index which has the closest center distance with (n+1)-th segment + in the base scale. + Example: scale_mapping_argmat[2][101] = 85 - In the above example, it means that 86-th segment in the 3rd scale (python index is 2) is mapped with - 102-th segment in the base scale. Thus, the longer segments bound to have more repeating numbers since - multiple base scale segments (since the base scale has the shortest length) fall into the range of the - longer segments. At the same time, each row contains N numbers of indices where N is number of - segments in the base-scale (i.e., the finest scale). + In the above example, it means that 86-th segment in the 3rd scale (python index is 2) + is mapped with 102-th segment in the base scale. Thus, the longer segments bound to have more + repeating numbers since multiple base scale segments (since the base scale has the shortest length) + fall into the range of the longer segments. At the same time, each row contains N numbers of + indices where N is number of segments in the base-scale (i.e., the finest scale). Shape: (batch_size, scale_n, self.diar_window_length) ms_seg_counts (Tensor): Cumulative sum of the number of segments in each scale. This information is needed to reconstruct @@ -366,8 +372,8 @@ def get_ms_emb_seq( Returns: ms_emb_seq (Tensor): - Multi-scale embedding sequence that is mapped, matched and repeated. The longer scales are less repeated, - while shorter scales are more frequently repeated following the scale mapping tensor. + Multi-scale embedding sequence that is mapped, matched and repeated. The longer scales are + less repeated, while shorter scales are more frequently repeated following the scale mapping tensor. """ scale_n, batch_size = scale_mapping[0].shape[0], scale_mapping.shape[0] split_emb_tup = torch.split(embs, ms_seg_counts.view(-1).tolist(), dim=0) @@ -388,19 +394,20 @@ def get_cluster_avg_embs_model( self, embs: torch.Tensor, clus_label_index: torch.Tensor, ms_seg_counts: torch.Tensor, scale_mapping ) -> torch.Tensor: """ - Calculate the cluster-average speaker embedding based on the ground-truth speaker labels (i.e., cluster labels). + Calculate the cluster-average speaker embedding based on the ground-truth speaker labels + (i.e., cluster labels). Args: embs (Tensor): Merged embeddings without zero-padding in the batch. See `ms_seg_counts` for details. Shape: (Total number of segments in the batch, emb_dim) clus_label_index (Tensor): - Merged ground-truth cluster labels from all scales with zero-padding. Each scale's index can be - retrieved by using segment index in `ms_seg_counts`. + Merged ground-truth cluster labels from all scales with zero-padding. Each scale's + index can be retrieved by using segment index in `ms_seg_counts`. 
Shape: (batch_size, maximum total segment count among the samples in the batch) ms_seg_counts (Tensor): - Cumulative sum of the number of segments in each scale. This information is needed to reconstruct - multi-scale input tensors during forward propagating. + Cumulative sum of the number of segments in each scale. This information is needed + to reconstruct multi-scale input tensors during forward propagating. Example: `batch_size=3, scale_n=6, emb_dim=192` .. code:: python @@ -420,8 +427,9 @@ def get_cluster_avg_embs_model( Returns: ms_avg_embs (Tensor): - Multi-scale cluster-average speaker embedding vectors. These embedding vectors are used as reference for - each speaker to predict the speaker label for the given multi-scale embedding sequences. + Multi-scale cluster-average speaker embedding vectors. These embedding vectors are used + as reference for each speaker to predict the speaker label for the given multi-scale + embedding sequences. Shape: (batch_size, scale_n, emb_dim, self.num_spks_per_model) """ scale_n, batch_size = scale_mapping[0].shape[0], scale_mapping.shape[0] @@ -534,7 +542,8 @@ def get_ms_mel_feat( def forward_infer(self, input_signal, input_signal_length, emb_vectors, targets): """ - Wrapper function for inference case. + Wrapper function for inference case. This `forward_infer` is only used during inference, where `forward` + is used for training and validation. """ preds, scale_weights = self.msdd( ms_emb_seq=input_signal, length=input_signal_length, ms_avg_embs=emb_vectors, targets=targets @@ -545,6 +554,7 @@ def forward_infer(self, input_signal, input_signal_length, emb_vectors, targets) def forward( self, features, feature_length, ms_seg_timestamps, ms_seg_counts, clus_label_index, scale_mapping, targets ): + """Function to compute forward pass for training/validation.""" processed_signal, processed_signal_len = self.msdd._speaker_model.preprocessor( input_signal=features, length=feature_length ) @@ -577,6 +587,7 @@ def forward( return preds, scale_weights def training_step(self, batch: list, batch_idx: int): + """Function to compute training step.""" features, feature_length, ms_seg_timestamps, ms_seg_counts, clus_label_index, scale_mapping, targets = batch sequence_lengths = torch.tensor([x[-1] for x in ms_seg_counts.detach()]) preds, _ = self.forward( @@ -588,10 +599,11 @@ def training_step(self, batch: list, batch_idx: int): scale_mapping=scale_mapping, targets=targets, ) - loss = self.loss(probs=preds, labels=targets, signal_lengths=sequence_lengths) + # loss = self.loss(probs=preds, labels=targets, signal_lengths=sequence_lengths) + loss = self.loss(probs=preds, labels=targets, target_lens=sequence_lengths) self._accuracy_train(preds, targets, sequence_lengths) torch.cuda.empty_cache() - f1_acc = self._accuracy_train.compute() + f1_acc, _, _ = self._accuracy_train.compute() self.log('loss', loss, sync_dist=True) self.log('learning_rate', self._optimizer.param_groups[0]['lr'], sync_dist=True) self.log('train_f1_acc', f1_acc, sync_dist=True) @@ -599,6 +611,7 @@ def training_step(self, batch: list, batch_idx: int): return {'loss': loss} def validation_step(self, batch: list, batch_idx: int, dataloader_idx: int = 0): + """Function to compute validation step.""" features, feature_length, ms_seg_timestamps, ms_seg_counts, clus_label_index, scale_mapping, targets = batch sequence_lengths = torch.tensor([x[-1] for x in ms_seg_counts]) preds, _ = self.forward( @@ -610,9 +623,10 @@ def validation_step(self, batch: list, batch_idx: int, dataloader_idx: int = 0): 
scale_mapping=scale_mapping, targets=targets, ) - loss = self.loss(probs=preds, labels=targets, signal_lengths=sequence_lengths) + # loss = self.loss(probs=preds, labels=targets, signal_lengths=sequence_lengths) + loss = self.loss(probs=preds, labels=targets, target_lens=sequence_lengths) self._accuracy_valid(preds, targets, sequence_lengths) - f1_acc = self._accuracy_valid.compute() + f1_acc, _, _ = self._accuracy_valid.compute() self.log('val_loss', loss, sync_dist=True) self.log('val_f1_acc', f1_acc, sync_dist=True) return { @@ -622,7 +636,7 @@ def validation_step(self, batch: list, batch_idx: int, dataloader_idx: int = 0): def multi_validation_epoch_end(self, outputs: list, dataloader_idx: int = 0): val_loss_mean = torch.stack([x['val_loss'] for x in outputs]).mean() - f1_acc = self._accuracy_valid.compute() + f1_acc, _, _ = self._accuracy_valid.compute() self._accuracy_valid.reset() self.log('val_loss', val_loss_mean, sync_dist=True) @@ -634,7 +648,7 @@ def multi_validation_epoch_end(self, outputs: list, dataloader_idx: int = 0): def multi_test_epoch_end(self, outputs: List[Dict[str, torch.Tensor]], dataloader_idx: int = 0): test_loss_mean = torch.stack([x['test_loss'] for x in outputs]).mean() - f1_acc = self._accuracy_test.compute() + f1_acc, _, _ = self._accuracy_test.compute() self._accuracy_test.reset() self.log('test_f1_acc', f1_acc, sync_dist=True) return { @@ -648,9 +662,10 @@ def compute_accuracies(self): Returns: f1_score (float): F1 score of the estimated diarized speaker label sequences. - simple_acc (float): Accuracy of predicted speaker labels: (total # of correct labels)/(total # of sigmoid values) + simple_acc (float): Accuracy of predicted speaker labels: + (total # of correct labels)/(total # of sigmoid values) """ - f1_score = self._accuracy_test.compute() + f1_score, _, _ = self._accuracy_test.compute() num_correct = torch.sum(self._accuracy_test.true.bool()) total_count = torch.prod(torch.tensor(self._accuracy_test.targets.shape)) simple_acc = num_correct / total_count @@ -659,7 +674,9 @@ def compute_accuracies(self): class ClusterEmbedding(torch.nn.Module): """ - This class is built for calculating cluster-average embeddings, segmentation and load/save of the estimated cluster labels. + This class is built for calculating cluster-average embeddings, segmentation and load/save of + the estimated cluster labels. + The methods in this class is used for the inference of MSDD models. Args: @@ -708,10 +725,10 @@ def prepare_cluster_embs_infer(self): def assign_labels_to_longer_segs(self, base_clus_label_dict: Dict, session_scale_mapping_dict: Dict): """ - In multi-scale speaker diarization system, clustering result is solely based on the base-scale (the shortest scale). - To calculate cluster-average speaker embeddings for each scale that are longer than the base-scale, this function assigns - clustering results for the base-scale to the longer scales by measuring the distance between subsegment timestamps in the - base-scale and non-base-scales. + In multi-scale speaker diarization system, clustering result is solely based on the base-scale + (the shortest scale). To calculate cluster-average speaker embeddings for each scale that are longer + than the base-scale, this function assigns clustering results for the base-scale to the longer scales + by measuring the distance between subsegment timestamps in the base-scale and non-base-scales. 
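A standalone sketch of the nearest-center assignment just described, with made-up timestamps (the real method iterates per-session dictionaries keyed by `uniq_id`):

import torch

# Base-scale (shortest) subsegment centers and their cluster labels from clustering.
base_centers = torch.tensor([0.25, 0.75, 1.25, 1.75, 2.25])
base_labels = torch.tensor([0, 0, 1, 1, 0])

# Subsegment centers of a longer scale.
long_scale_centers = torch.tensor([0.5, 1.5, 2.5])

# Each longer-scale segment inherits the label of its closest base-scale center.
dist = torch.cdist(long_scale_centers[:, None], base_centers[:, None])   # (3, 5) pairwise distances
long_scale_labels = base_labels[dist.argmin(dim=1)]                      # tensor([0, 1, 0])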
Args: base_clus_label_dict (dict): @@ -754,7 +771,8 @@ def get_base_clus_label_dict(self, clus_labels: List[str], emb_scale_seq_dict: D Dictionary containing multiscale embedding input sequences. Returns: base_clus_label_dict (dict): - Dictionary containing start and end of base scale segments and its cluster label. Indexed by `uniq_id`. + Dictionary containing start and end of base scale segments and its cluster label. + Indexed by `uniq_id`. emb_dim (int): Embedding dimension in integer. """ @@ -771,17 +789,18 @@ def get_cluster_avg_embs( self, emb_scale_seq_dict: Dict, clus_labels: List, speaker_mapping_dict: Dict, session_scale_mapping_dict: Dict ): """ - MSDD requires cluster-average speaker embedding vectors for each scale. This function calculates an average embedding vector for each cluster (speaker) - and each scale. + MSDD requires cluster-average speaker embedding vectors for each scale. This function calculates + an average embedding vector for each cluster (speaker) and each scale. Args: emb_scale_seq_dict (dict): Dictionary containing embedding sequence for each scale. Keys are scale index in integer. clus_labels (list): - Clustering results from clustering diarizer including all the sessions provided in input manifest files. + Clustering results from clustering diarizer including all the sessions provided + in input manifest files. speaker_mapping_dict (dict): - Speaker mapping dictionary in case RTTM files are provided. This is mapping between integer based speaker index and - speaker ID tokens in RTTM files. + Speaker mapping dictionary in case RTTM files are provided. This is mapping between + integer based speaker index and speaker ID tokens in RTTM files. Example: {'en_0638': {'speaker_0': 'en_0638_A', 'speaker_1': 'en_0638_B'}, 'en_4065': {'speaker_0': 'en_4065_B', 'speaker_1': 'en_4065_A'}, ...,} @@ -793,7 +812,8 @@ def get_cluster_avg_embs( Dictionary containing speaker mapping information and cluster-average speaker embedding vector. Each session-level dictionary is indexed by scale index in integer. output_clus_label_dict (dict): - Subegmentation timestamps in float type and Clustering result in integer type. Indexed by `uniq_id` keys. + Subegmentation timestamps in float type and Clustering result in integer type. + Indexed by `uniq_id` keys. """ self.scale_n = len(emb_scale_seq_dict.keys()) emb_sess_avg_dict = { @@ -830,9 +850,10 @@ def get_cluster_avg_embs( def run_clustering_diarizer(self, manifest_filepath: str, emb_dir: str): """ - If no pre-existing data is provided, run clustering diarizer from scratch. This will create scale-wise speaker embedding - sequence, cluster-average embeddings, scale mapping and base scale clustering labels. Note that speaker embedding `state_dict` - is loaded from the `state_dict` in the provided MSDD checkpoint. + If no pre-existing data is provided, run clustering diarizer from scratch. This will create + scale-wise speaker embedding sequence, cluster-average embeddings, scale mapping and base scale + clustering labels. Note that speaker embedding `state_dict` is loaded from the `state_dict` + in the provided MSDD checkpoint. Args: manifest_filepath (str): @@ -846,7 +867,8 @@ def run_clustering_diarizer(self, manifest_filepath: str, emb_dir: str): emb_scale_seq_dict (dict): Dictionary containing embedding tensors which are indexed by scale numbers. base_clus_label_dict (dict): - Dictionary containing clustering results. Clustering results are cluster labels for the base scale segments. 
+ Dictionary containing clustering results. Clustering results are cluster labels + for the base scale segments. """ self.cfg_diar_infer.diarizer.manifest_filepath = manifest_filepath self.cfg_diar_infer.diarizer.out_dir = emb_dir @@ -974,9 +996,9 @@ def load_emb_scale_seq_dict(self, out_dir): class NeuralDiarizer(LightningModule): """ - Class for inference based on multiscale diarization decoder (MSDD). MSDD requires initializing clustering results from - clustering diarizer. Overlap-aware diarizer requires separate RTTM generation and evaluation modules to check the effect of - overlap detection in speaker diarization. + Class for inference based on multiscale diarization decoder (MSDD). MSDD requires initializing + clustering results from clustering diarizer. Overlap-aware diarizer requires separate RTTM + generation and evaluation modules to check the effect of overlap detection in speaker diarization. """ def __init__(self, cfg: Union[DictConfig, NeuralDiarizerInferenceConfig]): @@ -1029,7 +1051,8 @@ def save_to(self, save_path: str): You can use "restore_from" method to fully restore instance from .nemo file. .nemo file is an archive (tar.gz) with the following: - model_config.yaml - model configuration in .yaml format. You can deserialize this into cfg argument for model's constructor + model_config.yaml - model configuration in .yaml format. + You can deserialize this into cfg argument for model's constructor model_wights.chpt - model checkpoint Args: @@ -1053,8 +1076,8 @@ def save_to(self, save_path: str): def extract_standalone_speaker_model(self, prefix: str = 'msdd._speaker_model.') -> EncDecSpeakerLabelModel: """ - MSDD model file contains speaker embedding model and MSDD model. This function extracts standalone speaker model and save it to - `self.spk_emb_state_dict` to be loaded separately for clustering diarizer. + MSDD model file contains speaker embedding model and MSDD model. This function extracts standalone + speaker model and save it to `self.spk_emb_state_dict` to be loaded separately for clustering diarizer. Args: ext (str): @@ -1104,20 +1127,22 @@ def _init_msdd_model(self, cfg: Union[DictConfig, NeuralDiarizerInferenceConfig] def get_pred_mat(self, data_list: List[Union[Tuple[int], List[torch.Tensor]]]) -> torch.Tensor: """ - This module puts together the pairwise, two-speaker, predicted results to form a finalized matrix that has dimension of - `(total_len, n_est_spks)`. The pairwise results are evenutally averaged. For example, in 4 speaker case (speaker 1, 2, 3, 4), - the sum of the pairwise results (1, 2), (1, 3), (1, 4) are then divided by 3 to take average of the sigmoid values. + This module puts together the pairwise, two-speaker, predicted results to form a finalized matrix + that has dimension of `(total_len, n_est_spks)`. The pairwise results are evenutally averaged. + For example, in 4 speaker case (speaker 1, 2, 3, 4), the sum of the pairwise results + (1, 2), (1, 3), (1, 4) are then divided by 3 to take average of the sigmoid values. Args: data_list (list): - List containing data points from `test_data_collection` variable. `data_list` has sublists `data` as follows: - data[0]: `target_spks` tuple - Examples: (0, 1, 2) - data[1]: Tensor containing estimaged sigmoid values. - [[0.0264, 0.9995], - [0.0112, 1.0000], - ..., - [1.0000, 0.0512]] + List containing data points from `test_data_collection` variable. 
`data_list` + has sublists `data` as follows: + data[0]: `target_spks` tuple + Examples: (0, 1, 2) + data[1]: Tensor containing estimaged sigmoid values. + [[0.0264, 0.9995], + [0.0112, 1.0000], + ..., + [1.0000, 0.0512]] Returns: sum_pred (Tensor): @@ -1152,7 +1177,8 @@ def get_integrated_preds_list( uniq_id_list (list): List containing `uniq_id` values. test_data_collection (collections.DiarizationLabelEntity): - Class instance that is containing session information such as targeted speaker indices, audio filepaths and RTTM filepaths. + Class instance that is containing session information such as targeted speaker indices, + audio filepaths and RTTM filepaths. preds_list (list): List containing tensors filled with sigmoid values. @@ -1177,9 +1203,11 @@ def get_emb_clus_infer(self, cluster_embeddings): @torch.no_grad() def diarize(self) -> Optional[List[Optional[List[Tuple[DiarizationErrorRate, Dict]]]]]: """ - Launch diarization pipeline which starts from VAD (or a oracle VAD stamp generation), initialization clustering and multiscale diarization decoder (MSDD). - Note that the result of MSDD can include multiple speakers at the same time. Therefore, RTTM output of MSDD needs to be based on `make_rttm_with_overlap()` - function that can generate overlapping timestamps. `self.run_overlap_aware_eval()` function performs DER evaluation. + Launch diarization pipeline which starts from VAD (or a oracle VAD stamp generation), + initialization clustering and multiscale diarization decoder (MSDD). Note that the result of MSDD + can include multiple speakers at the same time. Therefore, RTTM output of MSDD needs to be based on + `make_rttm_with_overlap()` function that can generate overlapping timestamps. + `self.run_overlap_aware_eval()` function performs DER evaluation. """ self.clustering_embedding.prepare_cluster_embs_infer() self.msdd_model.pairwise_infer = True @@ -1192,10 +1220,11 @@ def get_range_average( self, signals: torch.Tensor, emb_vectors: torch.Tensor, diar_window_index: int, test_data_collection: List[Any] ) -> Tuple[torch.Tensor, torch.Tensor, int]: """ - This function is only used when `split_infer=True`. This module calculates cluster-average embeddings for the given short range. - The range length is set by `self.diar_window_length`, and each cluster-average is only calculated for the specified range. - Note that if the specified range does not contain some speakers (e.g. the range contains speaker 1, 3) compared to the global speaker sets - (e.g. speaker 1, 2, 3, 4) then the missing speakers (e.g. speakers 2, 4) are assigned with zero-filled cluster-average speaker embedding. + This function is only used when `split_infer=True`. This module calculates cluster-average embeddings + for the given short range. The range length is set by `self.diar_window_length`, and each cluster-average + is only calculated for the specified range. Note that if the specified range does not contain some speakers + (e.g. the range contains speaker 1, 3) compared to the global speaker sets (e.g. speaker 1, 2, 3, 4) then + the missing speakers (e.g. speakers 2, 4) are assigned with zero-filled cluster-average speaker embedding. Args: signals (Tensor): @@ -1207,7 +1236,8 @@ def get_range_average( diar_window_index (int): Index of split diarization wondows. test_data_collection (collections.DiarizationLabelEntity) - Class instance that is containing session information such as targeted speaker indices, audio filepath and RTTM filepath. 
+ Class instance that is containing session information such as targeted speaker indices, + audio filepath and RTTM filepath. Returns: return emb_vectors_split (Tensor): @@ -1237,7 +1267,8 @@ def get_range_average( ) target_clus_label_bool = target_clus_label_tensor == test_data_collection.target_spks[spk_idx] - # There are cases where there is no corresponding speaker in split range, so any(target_clus_label_bool) could be False. + # There are cases where there is no corresponding speaker in split range, + # so any(target_clus_label_bool) could be False. if any(target_clus_label_bool): emb_vectors_split[:, :, spk_idx] = torch.mean(emb_seq[target_clus_label_bool], dim=0) @@ -1263,14 +1294,17 @@ def get_range_clus_avg_emb( self, test_batch: List[torch.Tensor], _test_data_collection: List[Any], device: torch.device('cpu') ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ - This function is only used when `get_range_average` function is called. This module calculates cluster-average embeddings for - the given short range. The range length is set by `self.diar_window_length`, and each cluster-average is only calculated for the specified range. + This function is only used when `get_range_average` function is called. This module calculates + cluster-average embeddings for the given short range. The range length is set by `self.diar_window_length`, + and each cluster-average is only calculated for the specified range. Args: test_batch: (list) - List containing embedding sequences, length of embedding sequences, ground truth labels (if exists) and initializing embedding vectors. + List containing embedding sequences, length of embedding sequences, ground truth labels + (if exists) and initializing embedding vectors. test_data_collection: (list) - List containing test-set dataloader contents. test_data_collection includes wav file path, RTTM file path, clustered speaker indices. + List containing test-set dataloader contents. test_data_collection includes wav file path, + RTTM file path, clustered speaker indices. Returns: sess_emb_vectors (Tensor): @@ -1305,16 +1339,18 @@ def diar_infer( self, test_batch: List[torch.Tensor], test_data_collection: List[Any] ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ - Launch forward_infer() function by feeding the session-wise embedding sequences to get pairwise speaker prediction values. - If split_infer is True, the input audio clips are broken into short sequences then cluster average embeddings are calculated - for inference. Split-infer might result in an improved results if calculating clustering average on the shorter tim-espan can - help speaker assignment. + Launch forward_infer() function by feeding the session-wise embedding sequences to get pairwise + speaker prediction values. If split_infer is True, the input audio clips are broken into short + sequences then cluster average embeddings are calculated for inference. Split-infer might result in + an improved results if calculating clustering average on the shorter tim-espan can help speaker assignment. Args: test_batch: (list) - List containing embedding sequences, length of embedding sequences, ground truth labels (if exists) and initializing embedding vectors. + List containing embedding sequences, length of embedding sequences, ground truth labels (if exists) + and initializing embedding vectors. test_data_collection: (list) - List containing test-set dataloader contents. test_data_collection includes wav file path, RTTM file path, clustered speaker indices. 
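As a standalone sketch of the zero-filling behaviour described for `get_range_average` above: a cluster-average embedding is computed only from frames whose cluster label matches the target speaker, and speakers absent from the split window keep an all-zero average. Tensor names and shapes below are illustrative assumptions, not the module's actual variables.

```python
import torch


def window_cluster_averages(emb_seq: torch.Tensor, clus_labels: torch.Tensor, num_spks: int) -> torch.Tensor:
    """Per-speaker mean embedding over one split window; absent speakers stay zero-filled.

    emb_seq:     (win_len, emb_dim) embeddings inside the window
    clus_labels: (win_len,) clustering label per embedding
    returns:     (emb_dim, num_spks) cluster-average embeddings
    """
    emb_dim = emb_seq.shape[1]
    avg_embs = torch.zeros(emb_dim, num_spks)
    for spk in range(num_spks):
        spk_mask = clus_labels == spk
        if spk_mask.any():  # a speaker may not appear in this window at all
            avg_embs[:, spk] = emb_seq[spk_mask].mean(dim=0)
    return avg_embs


# Toy usage: global speaker set {0, 1, 2, 3}, but only speakers 0 and 2 occur in this window.
emb_seq = torch.randn(6, 8)
clus_labels = torch.tensor([0, 0, 2, 2, 2, 0])
avg = window_cluster_averages(emb_seq, clus_labels, num_spks=4)
print(avg[:, 1].abs().sum(), avg[:, 3].abs().sum())  # both zero: speakers 1 and 3 are zero-filled
```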
+ List containing test-set dataloader contents. test_data_collection includes wav file path, + RTTM file path, clustered speaker indices. Returns: preds (Tensor): @@ -1353,8 +1389,9 @@ def diar_infer( @torch.no_grad() def run_pairwise_diarization(self) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]: """ - Setup the parameters needed for batch inference and run batch inference. Note that each sample is pairwise speaker input. - The pairwise inference results are reconstructed to make session-wise prediction results. + Setup the parameters needed for batch inference and run batch inference. Note that each sample is + pairwise speaker input. The pairwise inference results are reconstructed to make session-wise + prediction results. Returns: integrated_preds_list: (list) @@ -1405,7 +1442,8 @@ def run_overlap_aware_eval( - If threshold is 0.0, all speakers are considered active at any time step. """ logging.info( - f" [Threshold: {threshold:.4f}] [use_clus_as_main={self.use_clus_as_main}] [diar_window={self.diar_window_length}]" + f" [Threshold: {threshold:.4f}] [use_clus_as_main={self.use_clus_as_main}] " + f"[diar_window={self.diar_window_length}]" ) outputs = [] manifest_filepath = self.msdd_model.cfg.test_ds.manifest_filepath diff --git a/nemo/collections/asr/models/sortformer_diar_models.py b/nemo/collections/asr/models/sortformer_diar_models.py new file mode 100644 index 0000000000000..f6b0eab4c8950 --- /dev/null +++ b/nemo/collections/asr/models/sortformer_diar_models.py @@ -0,0 +1,579 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools +import random +from collections import OrderedDict +from typing import Dict, List, Optional, Union + +import torch +from hydra.utils import instantiate +from omegaconf import DictConfig +from pytorch_lightning import Trainer +from tqdm import tqdm + +from nemo.collections.asr.data.audio_to_diar_label import AudioToSpeechE2ESpkDiarDataset +from nemo.collections.asr.data.audio_to_diar_label_lhotse import LhotseAudioToSpeechE2ESpkDiarDataset +from nemo.collections.asr.metrics.multi_binary_acc import MultiBinaryAccuracy +from nemo.collections.asr.models.asr_model import ExportableEncDecModel +from nemo.collections.asr.parts.preprocessing.features import WaveformFeaturizer +from nemo.collections.asr.parts.preprocessing.perturb import process_augmentations +from nemo.collections.asr.parts.utils.asr_multispeaker_utils import get_ats_targets, get_pil_targets +from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config +from nemo.core.classes import ModelPT +from nemo.core.classes.common import PretrainedModelInfo +from nemo.core.neural_types import AudioSignal, LengthsType, NeuralType +from nemo.core.neural_types.elements import ProbsType +from nemo.utils import logging + +__all__ = ['SortformerEncLabelModel'] + + +class SortformerEncLabelModel(ModelPT, ExportableEncDecModel): + """ + Encoder class for Sortformer diarization model. 
+ Model class creates training, validation methods for setting up data performing model forward pass. + + This model class expects config dict for: + * preprocessor + * Transformer Encoder + * FastConformer Encoder + * Sortformer Modules + """ + + @classmethod + def list_available_models(cls) -> List[PretrainedModelInfo]: + """ + This method returns a list of pre-trained model which can be instantiated directly from NVIDIA's NGC cloud. + + Returns: + List of available pre-trained models. + """ + result = [] + return result + + def __init__(self, cfg: DictConfig, trainer: Trainer = None): + """ + Initialize an Sortformer Diarizer model and a pretrained NEST encoder. + In this init function, training and validation datasets are prepared. + """ + random.seed(42) + self._trainer = trainer if trainer else None + self._cfg = cfg + + if self._trainer: + self.world_size = trainer.num_nodes * trainer.num_devices + else: + self.world_size = 1 + + if self._trainer is not None and self._cfg.get('augmentor', None) is not None: + self.augmentor = process_augmentations(self._cfg.augmentor) + else: + self.augmentor = None + super().__init__(cfg=self._cfg, trainer=trainer) + self.preprocessor = SortformerEncLabelModel.from_config_dict(self._cfg.preprocessor) + + if hasattr(self._cfg, 'spec_augment') and self._cfg.spec_augment is not None: + self.spec_augmentation = SortformerEncLabelModel.from_config_dict(self._cfg.spec_augment) + else: + self.spec_augmentation = None + + self.encoder = SortformerEncLabelModel.from_config_dict(self._cfg.encoder).to(self.device) + self.sortformer_modules = SortformerEncLabelModel.from_config_dict(self._cfg.sortformer_modules).to( + self.device + ) + self.transformer_encoder = SortformerEncLabelModel.from_config_dict(self._cfg.transformer_encoder).to( + self.device + ) + if self._cfg.encoder.d_model != self._cfg.model_defaults.tf_d_model: + self.sortformer_modules.encoder_proj = self.sortformer_modules.encoder_proj.to(self.device) + else: + self.sortformer_modules.encoder_proj = None + self._init_loss_weights() + + self.eps = 1e-3 + self.loss = instantiate(self._cfg.loss) + + self.streaming_mode = self._cfg.get("streaming_mode", False) + self.save_hyperparameters("cfg") + self._init_eval_metrics() + + speaker_inds = list(range(self._cfg.max_num_of_spks)) + self.speaker_permutations = torch.tensor(list(itertools.permutations(speaker_inds))) # Get all permutations + + def _init_loss_weights(self): + pil_weight = self._cfg.get("pil_weight", 0.0) + ats_weight = self._cfg.get("ats_weight", 1.0) + if pil_weight + ats_weight == 0: + raise ValueError(f"weights for PIL {pil_weight} and ATS {ats_weight} cannot sum to 0") + self.pil_weight = pil_weight / (pil_weight + ats_weight) + self.ats_weight = ats_weight / (pil_weight + ats_weight) + logging.info(f"Normalized weights for PIL {self.pil_weight} and ATS {self.ats_weight}") + + def _init_eval_metrics(self): + """ + If there is no label, then the evaluation metrics will be based on Permutation Invariant Loss (PIL). 
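The `_init_loss_weights` logic above rescales the configured PIL and ATS weights so they sum to one before the two losses are blended in the training step. A minimal, framework-free sketch; the loss values are placeholders, not real model outputs.

```python
import torch


def normalize_loss_weights(pil_weight: float, ats_weight: float):
    """Rescale PIL/ATS weights so they sum to 1.0, rejecting the degenerate all-zero case."""
    total = pil_weight + ats_weight
    if total == 0:
        raise ValueError(f"weights for PIL {pil_weight} and ATS {ats_weight} cannot sum to 0")
    return pil_weight / total, ats_weight / total


pil_w, ats_w = normalize_loss_weights(pil_weight=0.5, ats_weight=0.5)  # -> 0.5, 0.5

# Blend two placeholder loss tensors with the normalized weights, as the training step does.
ats_loss, pil_loss = torch.tensor(0.42), torch.tensor(0.37)
loss = ats_w * ats_loss + pil_w * pil_loss
print(loss)  # tensor(0.3950)
```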
+ """ + self._accuracy_test = MultiBinaryAccuracy() + self._accuracy_train = MultiBinaryAccuracy() + self._accuracy_valid = MultiBinaryAccuracy() + + self._accuracy_test_ats = MultiBinaryAccuracy() + self._accuracy_train_ats = MultiBinaryAccuracy() + self._accuracy_valid_ats = MultiBinaryAccuracy() + + def _reset_train_metrics(self): + self._accuracy_train.reset() + self._accuracy_train_ats.reset() + + def _reset_valid_metrics(self): + self._accuracy_valid.reset() + self._accuracy_valid_ats.reset() + + def __setup_dataloader_from_config(self, config): + # Switch to lhotse dataloader if specified in the config + if config.get("use_lhotse"): + return get_lhotse_dataloader_from_config( + config, + global_rank=self.global_rank, + world_size=self.world_size, + dataset=LhotseAudioToSpeechE2ESpkDiarDataset(cfg=config), + ) + + featurizer = WaveformFeaturizer( + sample_rate=config['sample_rate'], int_values=config.get('int_values', False), augmentor=self.augmentor + ) + + if 'manifest_filepath' in config and config['manifest_filepath'] is None: + logging.warning(f"Could not load dataset as `manifest_filepath` was None. Provided config : {config}") + return None + + logging.info(f"Loading dataset from {config.manifest_filepath}") + + if self._trainer is not None: + global_rank = self._trainer.global_rank + else: + global_rank = 0 + + dataset = AudioToSpeechE2ESpkDiarDataset( + manifest_filepath=config.manifest_filepath, + soft_label_thres=config.soft_label_thres, + session_len_sec=config.session_len_sec, + num_spks=config.num_spks, + featurizer=featurizer, + window_stride=self._cfg.preprocessor.window_stride, + global_rank=global_rank, + soft_targets=config.soft_targets if 'soft_targets' in config else False, + ) + + self.data_collection = dataset.collection + self.collate_ds = dataset + + dataloader_instance = torch.utils.data.DataLoader( + dataset=dataset, + batch_size=config.batch_size, + collate_fn=self.collate_ds.eesd_train_collate_fn, + drop_last=config.get('drop_last', False), + shuffle=False, + num_workers=config.get('num_workers', 1), + pin_memory=config.get('pin_memory', False), + ) + return dataloader_instance + + def setup_training_data(self, train_data_config: Optional[Union[DictConfig, Dict]]): + self._train_dl = self.__setup_dataloader_from_config( + config=train_data_config, + ) + + def setup_validation_data(self, val_data_layer_config: Optional[Union[DictConfig, Dict]]): + self._validation_dl = self.__setup_dataloader_from_config( + config=val_data_layer_config, + ) + + def setup_test_data(self, test_data_config: Optional[Union[DictConfig, Dict]]): + self._test_dl = self.__setup_dataloader_from_config( + config=test_data_config, + ) + + def test_dataloader(self): + if self._test_dl is not None: + return self._test_dl + return None + + @property + def input_types(self) -> Optional[Dict[str, NeuralType]]: + if hasattr(self.preprocessor, '_sample_rate'): + audio_eltype = AudioSignal(freq=self.preprocessor._sample_rate) + else: + audio_eltype = AudioSignal() + return { + "audio_signal": NeuralType(('B', 'T'), audio_eltype), + "audio_signal_length": NeuralType(('B',), LengthsType()), + } + + @property + def output_types(self) -> Dict[str, NeuralType]: + return OrderedDict( + { + "preds": NeuralType(('B', 'T', 'C'), ProbsType()), + } + ) + + def frontend_encoder(self, processed_signal, processed_signal_length): + """ + Generate encoder outputs from frontend encoder. + + Args: + processed_signal (torch.Tensor): tensor containing audio-feature (mel spectrogram, mfcc, etc.) 
+ processed_signal_length (torch.Tensor): tensor containing lengths of audio signal in integers + + Returns: + emb_seq (torch.Tensor): tensor containing encoder outputs + emb_seq_length (torch.Tensor): tensor containing lengths of encoder outputs + """ + # Spec augment is not applied during evaluation/testing + if self.spec_augmentation is not None and self.training: + processed_signal = self.spec_augmentation(input_spec=processed_signal, length=processed_signal_length) + emb_seq, emb_seq_length = self.encoder(audio_signal=processed_signal, length=processed_signal_length) + emb_seq = emb_seq.transpose(1, 2) + if self.sortformer_modules.encoder_proj is not None: + emb_seq = self.sortformer_modules.encoder_proj(emb_seq) + return emb_seq, emb_seq_length + + def forward_infer(self, emb_seq): + """ + The main forward pass for diarization for offline diarization inference. + + Args: + emb_seq (torch.Tensor): tensor containing FastConformer encoder states (embedding vectors). + Dimension: (batch_size, diar_frame_count, emb_dim) + + Returns: + preds (torch.Tensor): Sorted tensor containing Sigmoid values for predicted speaker labels. + Dimension: (batch_size, diar_frame_count, num_speakers) + """ + encoder_mask = self.sortformer_modules.length_to_mask(emb_seq) + trans_emb_seq = self.transformer_encoder(encoder_states=emb_seq, encoder_mask=encoder_mask) + preds = self.sortformer_modules.forward_speaker_sigmoids(trans_emb_seq) + return preds + + def process_signal(self, audio_signal, audio_signal_length): + """ + Extract audio features from time-series signal for further processing in the model. + + This function performs the following steps: + 1. Moves the audio signal to the correct device. + 2. Normalizes the time-series audio signal. + 3. Extrac audio feature from from the time-series audio signal using the model's preprocessor. + + Args: + audio_signal (torch.Tensor): The input audio signal. + Shape: (batch_size, num_samples) + audio_signal_length (torch.Tensor): The length of each audio signal in the batch. + Shape: (batch_size,) + + Returns: + tuple: A tuple containing: + - processed_signal (torch.Tensor): The preprocessed audio signal. + Shape: (batch_size, num_features, num_frames) + - processed_signal_length (torch.Tensor): The length of each processed signal. + Shape: (batch_size,) + """ + audio_signal = audio_signal.to(self.device) + audio_signal = (1 / (audio_signal.max() + self.eps)) * audio_signal + processed_signal, processed_signal_length = self.preprocessor( + input_signal=audio_signal, length=audio_signal_length + ) + return processed_signal, processed_signal_length + + def forward( + self, + audio_signal, + audio_signal_length, + ): + """ + Forward pass for training and inference. 
+ + Args: + audio_signal (torch.Tensor): tensor containing audio waveform + Dimension: (batch_size, num_samples) + audio_signal_length (torch.Tensor): tensor containing lengths of audio waveforms + Dimension: (batch_size,) + + Returns: + preds (torch.Tensor): Sorted tensor containing predicted speaker labels + Dimension: (batch_size, diar_frame_count, num_speakers) + """ + processed_signal, processed_signal_length = self.process_signal( + audio_signal=audio_signal, audio_signal_length=audio_signal_length + ) + processed_signal = processed_signal[:, :, : processed_signal_length.max()] + if self._cfg.get("streaming_mode", False): + raise NotImplementedError("Streaming mode is not implemented yet.") + else: + emb_seq, _ = self.frontend_encoder( + processed_signal=processed_signal, processed_signal_length=processed_signal_length + ) + preds = self.forward_infer(emb_seq) + return preds + + def _get_aux_train_evaluations(self, preds, targets, target_lens) -> dict: + """ + Compute auxiliary training evaluations including losses and metrics. + + This function calculates various losses and metrics for the training process, + including Arrival Time Sort (ATS) Loss and Permutation Invariant Loss (PIL) + based evaluations. + + Args: + preds (torch.Tensor): Predicted speaker labels. + Shape: (batch_size, diar_frame_count, num_speakers) + targets (torch.Tensor): Ground truth speaker labels. + Shape: (batch_size, diar_frame_count, num_speakers) + target_lens (torch.Tensor): Lengths of target sequences. + Shape: (batch_size,) + + Returns: + (dict): A dictionary containing the following training metrics. + """ + targets_ats = get_ats_targets(targets.clone(), preds, speaker_permutations=self.speaker_permutations) + targets_pil = get_pil_targets(targets.clone(), preds, speaker_permutations=self.speaker_permutations) + ats_loss = self.loss(probs=preds, labels=targets_ats, target_lens=target_lens) + pil_loss = self.loss(probs=preds, labels=targets_pil, target_lens=target_lens) + loss = self.ats_weight * ats_loss + self.pil_weight * pil_loss + + self._accuracy_train(preds, targets_pil, target_lens) + train_f1_acc, train_precision, train_recall = self._accuracy_train.compute() + + self._accuracy_train_ats(preds, targets_ats, target_lens) + train_f1_acc_ats, _, _ = self._accuracy_train_ats.compute() + + train_metrics = { + 'loss': loss, + 'ats_loss': ats_loss, + 'pil_loss': pil_loss, + 'learning_rate': self._optimizer.param_groups[0]['lr'], + 'train_f1_acc': train_f1_acc, + 'train_precision': train_precision, + 'train_recall': train_recall, + 'train_f1_acc_ats': train_f1_acc_ats, + } + return train_metrics + + def training_step(self, batch: list) -> dict: + """ + Performs a single training step. + + Args: + batch (list): A list containing the following elements: + - audio_signal (torch.Tensor): The input audio signal in time-series format. + - audio_signal_length (torch.Tensor): The length of each audio signal in the batch. + - targets (torch.Tensor): The target labels for the batch. + - target_lens (torch.Tensor): The length of each target sequence in the batch. + + Returns: + (dict): A dictionary containing the 'loss' key with the calculated loss value. 
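Arrival-time sorting (the ATS target computed in `_get_aux_train_evaluations` above) reorders the speaker channels of the label tensor so that the speaker who talks first occupies channel 0, the next arrival channel 1, and so on. The snippet below is a simplified illustration of that ordering only; the real `get_ats_targets` additionally resolves the choice of permutation against the model predictions via match scores.

```python
import torch

# Toy labels: (frames, speakers). Speaker 1 starts at frame 0, speaker 0 at frame 3.
labels = torch.tensor(
    [[0.0, 1.0],
     [0.0, 1.0],
     [0.0, 1.0],
     [1.0, 1.0],
     [1.0, 0.0]]
)

num_frames, num_spks = labels.shape
active = labels >= 0.5  # binarize, mirroring the 0.5 threshold used for the targets

# First active frame per speaker; speakers that never talk get num_frames as a sentinel.
first_frame = torch.full((num_spks,), num_frames)
for spk in range(num_spks):
    frames = torch.nonzero(active[:, spk]).flatten()
    if frames.numel() > 0:
        first_frame[spk] = frames[0]

order = torch.argsort(first_frame)  # arrival-time order of the speaker channels
ats_labels = labels[:, order]
print(order)       # tensor([1, 0]): original speaker 1 arrives first, so it becomes channel 0
print(ats_labels)  # columns reordered so the earliest speaker occupies the first channel
```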
+ """ + audio_signal, audio_signal_length, targets, target_lens = batch + preds = self.forward(audio_signal=audio_signal, audio_signal_length=audio_signal_length) + train_metrics = self._get_aux_train_evaluations(preds, targets, target_lens) + self._reset_train_metrics() + self.log_dict(train_metrics, sync_dist=True, on_step=True, on_epoch=False, logger=True) + return {'loss': train_metrics['loss']} + + def _get_aux_validation_evaluations(self, preds, targets, target_lens) -> dict: + """ + Compute auxiliary validation evaluations including losses and metrics. + + This function calculates various losses and metrics for the training process, + including Arrival Time Sort (ATS) Loss and Permutation Invariant Loss (PIL) + based evaluations. + + Args: + preds (torch.Tensor): Predicted speaker labels. + Shape: (batch_size, diar_frame_count, num_speakers) + targets (torch.Tensor): Ground truth speaker labels. + Shape: (batch_size, diar_frame_count, num_speakers) + target_lens (torch.Tensor): Lengths of target sequences. + Shape: (batch_size,) + + Returns: + val_metrics (dict): A dictionary containing the following validation metrics + """ + targets_ats = get_ats_targets(targets.clone(), preds, speaker_permutations=self.speaker_permutations) + targets_pil = get_pil_targets(targets.clone(), preds, speaker_permutations=self.speaker_permutations) + + val_ats_loss = self.loss(probs=preds, labels=targets_ats, target_lens=target_lens) + val_pil_loss = self.loss(probs=preds, labels=targets_pil, target_lens=target_lens) + val_loss = self.ats_weight * val_ats_loss + self.pil_weight * val_pil_loss + + self._accuracy_valid(preds, targets_pil, target_lens) + val_f1_acc, val_precision, val_recall = self._accuracy_valid.compute() + + self._accuracy_valid_ats(preds, targets_ats, target_lens) + valid_f1_acc_ats, _, _ = self._accuracy_valid_ats.compute() + + self._accuracy_valid.reset() + self._accuracy_valid_ats.reset() + + val_metrics = { + 'val_loss': val_loss, + 'val_ats_loss': val_ats_loss, + 'val_pil_loss': val_pil_loss, + 'val_f1_acc': val_f1_acc, + 'val_precision': val_precision, + 'val_recall': val_recall, + 'val_f1_acc_ats': valid_f1_acc_ats, + } + return val_metrics + + def validation_step(self, batch: list, dataloader_idx: int = 0): + """ + Performs a single validation step. + + This method processes a batch of data during the validation phase. It forward passes + the audio signal through the model, computes various validation metrics, and stores + these metrics for later aggregation. + + Args: + batch (list): A list containing the following elements: + - audio_signal (torch.Tensor): The input audio signal. + - audio_signal_length (torch.Tensor): The length of each audio signal in the batch. + - targets (torch.Tensor): The target labels for the batch. + - target_lens (torch.Tensor): The length of each target sequence in the batch. + batch_idx (int): The index of the current batch. + dataloader_idx (int, optional): The index of the dataloader in case of multiple + validation dataloaders. Defaults to 0. + + Returns: + dict: A dictionary containing various validation metrics for this batch. 
+ """ + audio_signal, audio_signal_length, targets, target_lens = batch + preds = self.forward( + audio_signal=audio_signal, + audio_signal_length=audio_signal_length, + ) + val_metrics = self._get_aux_validation_evaluations(preds, targets, target_lens) + if isinstance(self.trainer.val_dataloaders, list) and len(self.trainer.val_dataloaders) > 1: + self.validation_step_outputs[dataloader_idx].append(val_metrics) + else: + self.validation_step_outputs.append(val_metrics) + return val_metrics + + def multi_validation_epoch_end(self, outputs: list, dataloader_idx: int = 0): + if not outputs: + logging.warning(f"`outputs` is None; empty outputs for dataloader={dataloader_idx}") + return None + val_loss_mean = torch.stack([x['val_loss'] for x in outputs]).mean() + val_ats_loss_mean = torch.stack([x['val_ats_loss'] for x in outputs]).mean() + val_pil_loss_mean = torch.stack([x['val_pil_loss'] for x in outputs]).mean() + val_f1_acc_mean = torch.stack([x['val_f1_acc'] for x in outputs]).mean() + val_precision_mean = torch.stack([x['val_precision'] for x in outputs]).mean() + val_recall_mean = torch.stack([x['val_recall'] for x in outputs]).mean() + val_f1_acc_ats_mean = torch.stack([x['val_f1_acc_ats'] for x in outputs]).mean() + + self._reset_valid_metrics() + + multi_val_metrics = { + 'val_loss': val_loss_mean, + 'val_ats_loss': val_ats_loss_mean, + 'val_pil_loss': val_pil_loss_mean, + 'val_f1_acc': val_f1_acc_mean, + 'val_precision': val_precision_mean, + 'val_recall': val_recall_mean, + 'val_f1_acc_ats': val_f1_acc_ats_mean, + } + return {'log': multi_val_metrics} + + def _get_aux_test_batch_evaluations(self, batch_idx: int, preds, targets, target_lens): + """ + Compute auxiliary validation evaluations including losses and metrics. + + This function calculates various losses and metrics for the training process, + including Arrival Time Sort (ATS) Loss and Permutation Invariant Loss (PIL) + based evaluations. + + Args: + preds (torch.Tensor): Predicted speaker labels. + Shape: (batch_size, diar_frame_count, num_speakers) + targets (torch.Tensor): Ground truth speaker labels. + Shape: (batch_size, diar_frame_count, num_speakers) + target_lens (torch.Tensor): Lengths of target sequences. + Shape: (batch_size,) + """ + targets_ats = get_ats_targets(targets.clone(), preds, speaker_permutations=self.speaker_permutations) + targets_pil = get_pil_targets(targets.clone(), preds, speaker_permutations=self.speaker_permutations) + self._accuracy_test(preds, targets_pil, target_lens) + f1_acc, precision, recall = self._accuracy_test.compute() + self.batch_f1_accs_list.append(f1_acc) + self.batch_precision_list.append(precision) + self.batch_recall_list.append(recall) + logging.info(f"batch {batch_idx}: f1_acc={f1_acc}, precision={precision}, recall={recall}") + + self._accuracy_test_ats(preds, targets_ats, target_lens) + f1_acc_ats, precision_ats, recall_ats = self._accuracy_test_ats.compute() + self.batch_f1_accs_ats_list.append(f1_acc_ats) + logging.info( + f"batch {batch_idx}: f1_acc_ats={f1_acc_ats}, precision_ats={precision_ats}, recall_ats={recall_ats}" + ) + + self._accuracy_test.reset() + self._accuracy_test_ats.reset() + + def test_batch( + self, + ): + """ + Perform batch testing on the model. + + This method iterates through the test data loader, making predictions for each batch, + and calculates various evaluation metrics. It handles both single and multi-sample batches. 
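`multi_validation_epoch_end` above reduces the per-batch metric dictionaries by stacking each key and taking the mean. A tiny standalone sketch of that reduction; the metric values are placeholders.

```python
import torch

# Per-batch validation outputs as produced by validation_step (placeholder values).
outputs = [
    {"val_loss": torch.tensor(0.50), "val_f1_acc": torch.tensor(0.80)},
    {"val_loss": torch.tensor(0.40), "val_f1_acc": torch.tensor(0.90)},
]

epoch_metrics = {key: torch.stack([out[key] for out in outputs]).mean() for key in outputs[0]}
print(epoch_metrics)  # {'val_loss': tensor(0.4500), 'val_f1_acc': tensor(0.8500)}
```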
+ """ + ( + self.preds_total_list, + self.batch_f1_accs_list, + self.batch_precision_list, + self.batch_recall_list, + self.batch_f1_accs_ats_list, + ) = ([], [], [], [], []) + + with torch.no_grad(): + for batch_idx, batch in enumerate(tqdm(self._test_dl)): + audio_signal, audio_signal_length, targets, target_lens = batch + audio_signal = audio_signal.to(self.device) + audio_signal_length = audio_signal_length.to(self.device) + preds = self.forward( + audio_signal=audio_signal, + audio_signal_length=audio_signal_length, + ) + preds = preds.detach().to('cpu') + if preds.shape[0] == 1: # batch size = 1 + self.preds_total_list.append(preds) + else: + self.preds_total_list.extend(torch.split(preds, [1] * preds.shape[0])) + torch.cuda.empty_cache() + self._get_aux_test_batch_evaluations(batch_idx, preds, targets, target_lens) + + logging.info(f"Batch F1Acc. MEAN: {torch.mean(torch.tensor(self.batch_f1_accs_list))}") + logging.info(f"Batch Precision MEAN: {torch.mean(torch.tensor(self.batch_precision_list))}") + logging.info(f"Batch Recall MEAN: {torch.mean(torch.tensor(self.batch_recall_list))}") + logging.info(f"Batch ATS F1Acc. MEAN: {torch.mean(torch.tensor(self.batch_f1_accs_ats_list))}") + + def diarize( + self, + ): + """One-clieck runner function for diarization.""" + # TODO: A direct one-click runner function that generates + # speaker labels from audio file path lists. + raise NotImplementedError diff --git a/nemo/collections/asr/modules/sortformer_modules.py b/nemo/collections/asr/modules/sortformer_modules.py new file mode 100644 index 0000000000000..d99bf3b93e38e --- /dev/null +++ b/nemo/collections/asr/modules/sortformer_modules.py @@ -0,0 +1,108 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from nemo.core.classes.exportable import Exportable +from nemo.core.classes.module import NeuralModule + +__all__ = ['SortformerModules'] + + +class SortformerModules(NeuralModule, Exportable): + """ + A class including auxiliary functions for Sortformer models. + This class contains and will contain the following functions that performs streaming features, + and any neural layers that are not included in the NeMo neural modules (e.g. Transformer, Fast-Conformer). + """ + + def init_weights(self, m): + """Init weights for linear layers.""" + if type(m) == nn.Linear: + torch.nn.init.xavier_uniform_(m.weight) + m.bias.data.fill_(0.01) + + def __init__( + self, + num_spks: int = 4, + hidden_size: int = 192, + dropout_rate: float = 0.5, + fc_d_model: int = 512, + tf_d_model: int = 192, + ): + """ + Args: + num_spks (int): + Max number of speakers that are processed by the model. + hidden_size (int): + Number of hidden units in sequence models and intermediate layers. + dropout_rate (float): + Dropout rate for linear layers, CNN and LSTM. + fc_d_model (int): + Dimension of the embedding vectors. + tf_d_model (int): + Dimension of the embedding vectors. 
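As a shape-level sketch of how the dimensions listed above fit together: the Fast Conformer output width (`fc_d_model`, 512 in the config) is projected down to the Transformer width (`tf_d_model`, 192) before per-frame speaker sigmoids are produced. The layers below are plain `torch.nn` stand-ins, not the NeMo modules, and the head is a simplified proxy for `forward_speaker_sigmoids`.

```python
import torch
import torch.nn as nn

fc_d_model, tf_d_model, num_spks = 512, 192, 4

encoder_proj = nn.Linear(fc_d_model, tf_d_model)  # stands in for sortformer_modules.encoder_proj
spk_head = nn.Sequential(                          # simplified stand-in for the sigmoid head
    nn.Linear(tf_d_model, tf_d_model), nn.ReLU(), nn.Linear(tf_d_model, num_spks), nn.Sigmoid()
)

emb_seq = torch.randn(2, 1126, fc_d_model)   # (batch, diar_frames, fc_d_model) from the encoder
proj_seq = encoder_proj(emb_seq)             # (2, 1126, 192) fed to the Transformer encoder
preds = spk_head(proj_seq)                   # (2, 1126, 4) per-frame speaker probabilities in [0, 1]
print(proj_seq.shape, preds.shape)
```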
+ """ + super().__init__() + self.fc_d_model = fc_d_model + self.tf_d_model = tf_d_model + self.hidden_size = tf_d_model + self.unit_n_spks: int = num_spks + self.hidden_to_spks = nn.Linear(2 * self.hidden_size, self.unit_n_spks) + self.first_hidden_to_hidden = nn.Linear(self.hidden_size, self.hidden_size) + self.single_hidden_to_spks = nn.Linear(self.hidden_size, self.unit_n_spks) + self.dropout = nn.Dropout(dropout_rate) + self.encoder_proj = nn.Linear(self.fc_d_model, self.tf_d_model) + + def length_to_mask(self, context_embs): + """ + Convert length values to encoder mask input tensor. + + Args: + lengths (torch.Tensor): tensor containing lengths of sequences + max_len (int): maximum sequence length + + Returns: + mask (torch.Tensor): tensor of shape (batch_size, max_len) containing 0's + in the padded region and 1's elsewhere + """ + lengths = torch.tensor([context_embs.shape[1]] * context_embs.shape[0]) + batch_size = context_embs.shape[0] + max_len = context_embs.shape[1] + # create a tensor with the shape (batch_size, 1) filled with ones + row_vector = torch.arange(max_len).unsqueeze(0).expand(batch_size, -1).to(lengths.device) + # create a tensor with the shape (batch_size, max_len) filled with lengths + length_matrix = lengths.unsqueeze(1).expand(-1, max_len).to(lengths.device) + # create a mask by comparing the row vector and length matrix + mask = row_vector < length_matrix + return mask.float().to(context_embs.device) + + def forward_speaker_sigmoids(self, hidden_out): + """ + A set of layers for predicting speaker probabilities with a sigmoid activation function. + + Args: + hidden_out (torch.Tensor): tensor of shape (batch_size, seq_len, hidden_size) + + Returns: + preds (torch.Tensor): tensor of shape (batch_size, seq_len, num_spks) containing speaker probabilities + """ + hidden_out = self.dropout(F.relu(hidden_out)) + hidden_out = self.first_hidden_to_hidden(hidden_out) + hidden_out = self.dropout(F.relu(hidden_out)) + spk_preds = self.single_hidden_to_spks(hidden_out) + preds = nn.Sigmoid()(spk_preds) + return preds diff --git a/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py b/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py new file mode 100644 index 0000000000000..66cfcc75f49f4 --- /dev/null +++ b/nemo/collections/asr/parts/utils/asr_multispeaker_utils.py @@ -0,0 +1,416 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Optional, Union + +import torch +from lhotse import SupervisionSet +from lhotse.cut import MixedCut, MonoCut + + +def find_first_nonzero(mat: torch.Tensor, max_cap_val=-1, thres: float = 0.5) -> torch.Tensor: + """ + Finds the first nonzero value in the matrix, discretizing it to the specified maximum capacity. + + Args: + mat (Tensor): A torch tensor representing the matrix. + max_cap_val (int): The maximum capacity to which the matrix values will be discretized. + thres (float): The threshold value for discretizing the matrix values. 
+ + Returns: + mask_max_indices (Tensor): A torch tensor representing the discretized matrix with the first + nonzero value in each row. + """ + # Discretize the matrix to the specified maximum capacity + labels_discrete = mat.clone() + labels_discrete[labels_discrete < thres] = 0 + labels_discrete[labels_discrete >= thres] = 1 + + # non zero values mask + non_zero_mask = labels_discrete != 0 + # operations on the mask to find first nonzero values in the rows + mask_max_values, mask_max_indices = torch.max(non_zero_mask, dim=1) + # if the max-mask is zero, there is no nonzero value in the row + mask_max_indices[mask_max_values == 0] = max_cap_val + return mask_max_indices + + +def find_best_permutation(match_score: torch.Tensor, speaker_permutations: torch.Tensor) -> torch.Tensor: + """ + Finds the best permutation indices based on the match score. + + Args: + match_score (torch.Tensor): A tensor containing the match scores for each permutation. + Shape: (batch_size, num_permutations) + speaker_permutations (torch.Tensor): A tensor containing all possible speaker permutations. + Shape: (num_permutations, num_speakers) + + Returns: + torch.Tensor: A tensor containing the best permutation indices for each batch. + Shape: (batch_size, num_speakers) + """ + batch_best_perm = torch.argmax(match_score, axis=1) + rep_speaker_permutations = speaker_permutations.repeat(batch_best_perm.shape[0], 1).to(match_score.device) + perm_size = speaker_permutations.shape[0] + global_inds_vec = ( + torch.arange(0, perm_size * batch_best_perm.shape[0], perm_size).to(batch_best_perm.device) + batch_best_perm + ) + return rep_speaker_permutations[global_inds_vec.to(rep_speaker_permutations.device), :] + + +def reconstruct_labels(labels: torch.Tensor, batch_perm_inds: torch.Tensor) -> torch.Tensor: + """ + Reconstructs the labels using the best permutation indices with matrix operations. + + Args: + labels (torch.Tensor): A tensor containing the original labels. + Shape: (batch_size, num_frames, num_speakers) + batch_perm_inds (torch.Tensor): A tensor containing the best permutation indices for each batch. + Shape: (batch_size, num_speakers) + + Returns: + torch.Tensor: A tensor containing the reconstructed labels using the best permutation indices. + Shape: (batch_size, num_frames, num_speakers) + """ + # Expanding batch_perm_inds to align with labels dimensions + batch_size, num_frames, num_speakers = labels.shape + batch_perm_inds_exp = batch_perm_inds.unsqueeze(1).expand(-1, num_frames, -1) + + # Reconstructing the labels using advanced indexing + reconstructed_labels = torch.gather(labels, 2, batch_perm_inds_exp) + return reconstructed_labels + + +def get_ats_targets( + labels: torch.Tensor, + preds: torch.Tensor, + speaker_permutations: torch.Tensor, + thres: float = 0.5, + tolerance: float = 0, +) -> torch.Tensor: + """ + Sorts labels and predictions to get the optimal of all arrival-time ordered permutations. + + Args: + labels (torch.Tensor): A tensor containing the original labels. + Shape: (batch_size, num_frames, num_speakers) + preds (torch.Tensor): A tensor containing the predictions. + Shape: (batch_size, num_frames, num_speakers) + speaker_permutations (torch.Tensor): A tensor containing all possible speaker permutations. + Shape: (num_permutations, num_speakers) + thres (float): The threshold value for discretizing the matrix values. Default is 0.5. + tolerance (float): The tolerance for comparing the first speech frame indices. Default is 0. 
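`get_ats_targets` leans on `find_first_nonzero` (defined above) to locate each speaker's first active frame. A quick usage example; the import path assumes this utility lands at `nemo.collections.asr.parts.utils.asr_multispeaker_utils` as added in this patch, and the expected output follows the docstring (rows with no value above the threshold fall back to `max_cap_val`).

```python
import torch

from nemo.collections.asr.parts.utils.asr_multispeaker_utils import find_first_nonzero

# Two rows of soft labels: the first crosses the 0.5 threshold at index 2, the second never does.
mat = torch.tensor(
    [[0.1, 0.3, 0.9, 0.2],
     [0.0, 0.2, 0.1, 0.3]]
)
first_idx = find_first_nonzero(mat, max_cap_val=mat.shape[1], thres=0.5)
print(first_idx)  # tensor([2, 4]): row 1 has no active frame, so it gets the cap value 4
```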
+ + Returns: + torch.Tensor: A tensor containing the reconstructed labels using the best permutation indices. + Shape: (batch_size, num_frames, num_speakers) + """ + # Find the first nonzero frame index for each speaker in each batch + nonzero_ind = find_first_nonzero( + mat=labels, max_cap_val=labels.shape[1], thres=thres + ) # (batch_size, num_speakers) + + # Sort the first nonzero frame indices for arrival-time ordering + sorted_values = torch.sort(nonzero_ind)[0] # (batch_size, num_speakers) + perm_size = speaker_permutations.shape[0] # Scalar value (num_permutations) + permed_labels = labels[:, :, speaker_permutations] # (batch_size, num_frames, num_permutations, num_speakers) + permed_nonzero_ind = find_first_nonzero( + mat=permed_labels, max_cap_val=labels.shape[1] + ) # (batch_size, num_permutations, num_speakers) + + # Compare the first frame indices of sorted labels with those of the permuted labels using tolerance + perm_compare = ( + torch.abs(sorted_values.unsqueeze(1) - permed_nonzero_ind) <= tolerance + ) # (batch_size, num_permutations, num_speakers) + perm_mask = torch.all(perm_compare, dim=2).float() # (batch_size, num_permutations) + preds_rep = torch.unsqueeze(preds, 2).repeat( + 1, 1, perm_size, 1 + ) # Exapnd the preds: (batch_size, num_frames, num_permutations, num_speakers) + + # Compute the match score for each permutation by comparing permuted labels with preds + match_score = ( + torch.sum(permed_labels * preds_rep, axis=1).sum(axis=2) * perm_mask + ) # (batch_size, num_permutations) + batch_perm_inds = find_best_permutation(match_score, speaker_permutations) # (batch_size, num_speakers) + max_score_permed_labels = reconstruct_labels(labels, batch_perm_inds) # (batch_size, num_frames, num_speakers) + return max_score_permed_labels # (batch_size, num_frames, num_speakers) + + +def get_pil_targets(labels: torch.Tensor, preds: torch.Tensor, speaker_permutations: torch.Tensor) -> torch.Tensor: + """ + Sorts labels and predictions to get the optimal permutation based on the match score. + + Args: + labels (torch.Tensor): A tensor containing the ground truth labels. + Shape: (batch_size, num_speakers, num_classes) + preds (torch.Tensor): A tensor containing the predicted values. + Shape: (batch_size, num_speakers, num_classes) + speaker_permutations (torch.Tensor): A tensor containing all possible speaker permutations. + Shape: (num_permutations, num_speakers) + + Returns: + torch.Tensor: A tensor of permuted labels that best match the predictions. 
+ Shape: (batch_size, num_speakers, num_classes) + """ + permed_labels = labels[:, :, speaker_permutations] # (batch_size, num_classes, num_permutations, num_speakers) + # Repeat preds to match permutations for comparison + preds_rep = torch.unsqueeze(preds, 2).repeat( + 1, 1, speaker_permutations.shape[0], 1 + ) # (batch_size, num_speakers, num_permutations, num_classes) + match_score = torch.sum(permed_labels * preds_rep, axis=1).sum(axis=2) # (batch_size, num_permutations) + batch_perm_inds = find_best_permutation(match_score, speaker_permutations) # (batch_size, num_speakers) + # Reconstruct labels based on the best permutation for each batch + max_score_permed_labels = reconstruct_labels(labels, batch_perm_inds) # (batch_size, num_speakers, num_classes) + return max_score_permed_labels # (batch_size, num_speakers, num_classes) + + +def find_segments_from_rttm( + recording_id: str, + rttms: SupervisionSet, + start_after: float, + end_before: float, + adjust_offset: bool = True, + tolerance: float = 0.001, +): + """ + Finds segments from the given rttm file. + This function is designed to replace rttm + + Args: + recording_id (str): The recording ID in string format. + rttms (SupervisionSet): The SupervisionSet instance. + start_after (float): The start time after which segments are selected. + end_before (float): The end time before which segments are selected. + adjust_offset (bool): Whether to adjust the offset of the segments. + tolerance (float): The tolerance for time matching. 0.001 by default. + + Returns: + segments (List[SupervisionSegment]): A list of SupervisionSegment instances. + """ + segment_by_recording_id = rttms._segments_by_recording_id + if segment_by_recording_id is None: + from cytoolz import groupby + + segment_by_recording_id = groupby(lambda seg: seg.recording_id, rttms) + + return [ + # We only modify the offset - the duration remains the same, as we're only shifting the segment + # relative to the Cut's start, and not truncating anything. + segment.with_offset(-start_after) if adjust_offset else segment + for segment in segment_by_recording_id.get(recording_id, []) + if segment.start < end_before + tolerance and segment.end > start_after + tolerance + ] + + +def get_mask_from_segments( + segments: list, + a_cut: Optional[Union[MonoCut, MixedCut]], + speaker_to_idx_map: torch.Tensor, + num_speakers: int = 4, + feat_per_sec: int = 100, + ignore_num_spk_mismatch: bool = False, +): + """ + Generate mask matrix from segments list. + This function is needed for speaker diarization with ASR model trainings. + + Args: + segments: A list of Lhotse Supervision segments iterator. + cut (MonoCut, MixedCut): Lhotse MonoCut or MixedCut instance. + speaker_to_idx_map (dict): A dictionary mapping speaker names to indices. + num_speakers (int): max number of speakers for all cuts ("mask" dim0), 4 by default + feat_per_sec (int): number of frames per second, 100 by default, 0.01s frame rate + ignore_num_spk_mismatch (bool): This is a temporary solution to handle speaker mismatch. + Will be removed in the future. + + Returns: + mask (Tensor): A numpy array of shape (num_speakers, encoder_hidden_len). 
+ Dimension: (num_speakers, num_frames) + """ + # get targets with 0.01s frame rate + num_samples = round(a_cut.duration * feat_per_sec) + mask = torch.zeros((num_samples, num_speakers)) + for rttm_sup in segments: + speaker_idx = speaker_to_idx_map[rttm_sup.speaker] + if speaker_idx >= num_speakers: + if ignore_num_spk_mismatch: + continue + else: + raise ValueError(f"Speaker Index {speaker_idx} exceeds the max index: {num_speakers-1}") + stt = max(rttm_sup.start, 0) + ent = min(rttm_sup.end, a_cut.duration) + stf = int(stt * feat_per_sec) + enf = int(ent * feat_per_sec) + mask[stf:enf, speaker_idx] = 1.0 + return mask + + +def get_soft_mask(feat_level_target, num_frames, stride): + """ + Get soft mask from feat_level_target with stride. + This function is needed for speaker diarization with ASR model trainings. + + Args: + feat_level_target (Tensor): A numpy array of shape (num_frames, num_speakers). + Dimension: (num_frames, num_speakers) + num_sample (int): The total number of samples. + stride (int): The stride for the mask. + + Returns: + mask: The soft mask of shape (num_frames, num_speakers). + Dimension: (num_frames, num_speakers) + """ + + num_speakers = feat_level_target.shape[1] + mask = torch.zeros(num_frames, num_speakers) + + for index in range(num_frames): + if index == 0: + seg_stt_feat = 0 + else: + seg_stt_feat = stride * index - 1 - int(stride / 2) + if index == num_frames - 1: + seg_end_feat = feat_level_target.shape[0] + else: + seg_end_feat = stride * index - 1 + int(stride / 2) + mask[index] = torch.mean(feat_level_target[seg_stt_feat : seg_end_feat + 1, :], axis=0) + return mask + + +def get_hidden_length_from_sample_length( + num_samples: int, num_sample_per_mel_frame: int = 160, num_mel_frame_per_asr_frame: int = 8 +) -> int: + """ + Calculate the hidden length from the given number of samples. + This function is needed for speaker diarization with ASR model trainings. + + This function computes the number of frames required for a given number of audio samples, + considering the number of samples per mel frame and the number of mel frames per ASR frame. + + Parameters: + num_samples (int): The total number of audio samples. + num_sample_per_mel_frame (int, optional): The number of samples per mel frame. Default is 160. + num_mel_frame_per_asr_frame (int, optional): The number of mel frames per ASR frame. Default is 8. + + Returns: + hidden_length (int): The calculated hidden length in terms of the number of frames. + """ + mel_frame_count = math.ceil((num_samples + 1) / num_sample_per_mel_frame) + hidden_length = math.ceil(mel_frame_count / num_mel_frame_per_asr_frame) + return int(hidden_length) + + +def speaker_to_target( + a_cut, + num_speakers: int = 4, + num_sample_per_mel_frame: int = 160, + num_mel_frame_per_asr_frame: int = 8, + spk_tar_all_zero: bool = False, + boundary_segments: bool = False, + soft_label: bool = False, + ignore_num_spk_mismatch: bool = True, + soft_thres: float = 0.5, +): + """ + Get rttm samples corresponding to one cut, generate speaker mask numpy.ndarray with shape + (num_speaker, hidden_length). This function is needed for speaker diarization with ASR model trainings. + + Args: + a_cut (MonoCut, MixedCut): + Lhotse Cut instance which is MonoCut or MixedCut instance. 
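The frame accounting used by `get_hidden_length_from_sample_length` above (and relied on by `speaker_to_target`) works out as follows for the defaults of 160 samples per mel frame and an 8x encoder subsampling factor. The helper below simply mirrors that arithmetic for a worked example.

```python
import math


def hidden_length(num_samples: int, samples_per_mel_frame: int = 160, mel_frames_per_asr_frame: int = 8) -> int:
    # Mirrors get_hidden_length_from_sample_length: mel frames first, then encoder subsampling.
    mel_frames = math.ceil((num_samples + 1) / samples_per_mel_frame)
    return math.ceil(mel_frames / mel_frames_per_asr_frame)


print(hidden_length(16000))       # 13 diarization frames for 1 s of 16 kHz audio
print(hidden_length(90 * 16000))  # 1126 frames for a 90 s session
```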
+ num_speakers (int): + Max number of speakers for all cuts ("mask" dim0), 4 by default + num_sample_per_mel_frame (int): + Number of sample per mel frame, sample_rate / 1000 * window_stride, 160 by default (10ms window stride) + num_mel_frame_per_asr_frame (int): + Encoder subsampling_factor, 8 by default + spk_tar_all_zero (Tensor): + Set to True gives all zero "mask" + boundary_segments (bool): + Set to True to include segments containing the boundary of the cut, + False by default for multi-speaker ASR training + soft_label (bool): + Set to True to use soft label that enables values in [0, 1] range, + False by default and leads to binary labels. + ignore_num_spk_mismatch (bool): + This is a temporary solution to handle speaker mismatch. Will be removed in the future. + + Returns: + mask (Tensor): Speaker mask with shape (num_speaker, hidden_lenght) + """ + # get cut-related segments from rttms + if isinstance(a_cut, MixedCut): + cut_list = [track.cut for track in a_cut.tracks if isinstance(track.cut, MonoCut)] + offsets = [track.offset for track in a_cut.tracks if isinstance(track.cut, MonoCut)] + elif isinstance(a_cut, MonoCut): + cut_list = [a_cut] + offsets = [0] + else: + raise ValueError(f"Unsupported cut type type{a_cut}: only MixedCut and MonoCut are supported") + + segments_total = [] + for i, cut in enumerate(cut_list): + rttms = SupervisionSet.from_rttm(cut.rttm_filepath) + if boundary_segments: # segments with seg_start < total_end and seg_end > total_start are included + segments_iterator = find_segments_from_rttm( + recording_id=cut.recording_id, rttms=rttms, start_after=cut.start, end_before=cut.end, tolerance=0.0 + ) + else: # segments with seg_start > total_start and seg_end < total_end are included + segments_iterator = rttms.find( + recording_id=cut.recording_id, start_after=cut.start, end_before=cut.end, adjust_offset=True + ) + + for seg in segments_iterator: + if seg.start < 0: + seg.duration += seg.start + seg.start = 0 + if seg.end > cut.duration: + seg.duration -= seg.end - cut.duration + seg.start += offsets[i] + segments_total.append(seg) + + # apply arrival time sorting to the existing segments + segments_total.sort(key=lambda rttm_sup: rttm_sup.start) + + seen = set() + seen_add = seen.add + speaker_ats = [s.speaker for s in segments_total if not (s.speaker in seen or seen_add(s.speaker))] + + speaker_to_idx_map = {spk: idx for idx, spk in enumerate(speaker_ats)} + if len(speaker_to_idx_map) > num_speakers and not ignore_num_spk_mismatch: # raise error if number of speakers + raise ValueError( + f"Number of speakers {len(speaker_to_idx_map)} is larger than " + f"the maximum number of speakers {num_speakers}" + ) + + # initialize mask matrices (num_speaker, encoder_hidden_len) + feat_per_sec = int(a_cut.sampling_rate / num_sample_per_mel_frame) # 100 by default + num_samples = get_hidden_length_from_sample_length( + a_cut.num_samples, num_sample_per_mel_frame, num_mel_frame_per_asr_frame + ) + if spk_tar_all_zero: + frame_mask = torch.zeros((num_samples, num_speakers)) + else: + frame_mask = get_mask_from_segments( + segments_total, a_cut, speaker_to_idx_map, num_speakers, feat_per_sec, ignore_num_spk_mismatch + ) + soft_mask = get_soft_mask(frame_mask, num_samples, num_mel_frame_per_asr_frame) + + if soft_label: + mask = soft_mask + else: + mask = (soft_mask > soft_thres).float() + + return mask diff --git a/nemo/collections/asr/parts/utils/manifest_utils.py b/nemo/collections/asr/parts/utils/manifest_utils.py index e9f91045c9a2a..418f95832f486 100644 --- 
a/nemo/collections/asr/parts/utils/manifest_utils.py +++ b/nemo/collections/asr/parts/utils/manifest_utils.py @@ -24,7 +24,7 @@ from nemo.collections.asr.parts.utils.speaker_utils import ( audio_rttm_map, - get_subsegments, + get_subsegments_scriptable, get_uniqname_from_filepath, rttm_to_labels, segments_manifest_to_subsegments_manifest, @@ -66,13 +66,15 @@ def get_ctm_line( output_precision: int = 2, ) -> str: """ - Get a line in Conversation Time Mark (CTM) format. Following CTM format appeared in `Rich Transcription Meeting Eval Plan: RT09` document. - - CTM Format: + Get a line in Conversation Time Mark (CTM) format. Following CTM format appeared in + `Rich Transcription Meeting Eval Plan: RT09` document. + + CTM Format: - - Reference: - https://web.archive.org/web/20170119114252/http://www.itl.nist.gov/iad/mig/tests/rt/2009/docs/rt09-meeting-eval-plan-v2.pdf + + Reference: + https://web.archive.org/web/20170119114252/ + http://www.itl.nist.gov/iad/mig/tests/rt/2009/docs/rt09-meeting-eval-plan-v2.pdf Args: source (str): is name of the source file, session name or utterance ID @@ -80,11 +82,14 @@ def get_ctm_line( start_time (float): is the begin time of the word, which we refer to as `start_time` in NeMo. duration (float): is duration of the word token (str): Token or word for the current entry - conf (float): is a floating point number between 0 (no confidence) and 1 (certainty). A value of “NA” is used (in CTM format data) - when no confidence is computed and in the reference data. - type_of_token (str): is the token type. The legal values of are “lex”, “frag”, “fp”, “un-lex”, “for-lex”, “non-lex”, “misc”, or “noscore” - speaker (str): is a string identifier for the speaker who uttered the token. This should be “null” for non-speech tokens and “unknown” when - the speaker has not been determined. + conf (float): is a floating point number between 0 (no confidence) and 1 (certainty). + A value of “NA” is used (in CTM format data) + when no confidence is computed and in the reference data. + type_of_token (str): is the token type. The legal values of are + “lex”, “frag”, “fp”, “un-lex”, “for-lex”, “non-lex”, “misc”, or “noscore” + speaker (str): is a string identifier for the speaker who uttered the token. + This should be “null” for non-speech tokens and “unknown” when + the speaker has not been determined. NA_token (str, optional): A token for . Defaults to ''. output_precision (int, optional): The precision of the output floating point number. Defaults to 3. 
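For illustration, a minimal CTM-line formatter built from the fields `get_ctm_line` documents above. The field order shown follows the conventional RT09 layout (source, channel, start, duration, token, confidence, token type, speaker) and is an assumption here, not necessarily the exact string NeMo emits; the function name is hypothetical.

```python
def format_ctm_line(source, channel, start_time, duration, token, conf, type_of_token, speaker,
                    NA_token='NA', output_precision=2):
    """Assemble one CTM entry; confidence falls back to NA_token when it is not computed."""
    conf_str = NA_token if conf is None else f"{conf:.{output_precision}f}"
    fields = [source, str(channel), f"{start_time:.{output_precision}f}",
              f"{duration:.{output_precision}f}", token, conf_str, type_of_token, speaker]
    return " ".join(fields)


print(format_ctm_line("session01", 1, 13.25, 0.42, "hello", 0.91, "lex", "speaker_0"))
# session01 1 13.25 0.42 hello 0.91 lex speaker_0
```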
@@ -179,7 +184,7 @@ def get_subsegment_dict(subsegments_manifest_file: str, window: float, shift: fl segment = segment.strip() dic = json.loads(segment) audio, offset, duration, label = dic['audio_filepath'], dic['offset'], dic['duration'], dic['label'] - subsegments = get_subsegments(offset=offset, window=window, shift=shift, duration=duration) + subsegments = get_subsegments_scriptable(offset=offset, window=window, shift=shift, duration=duration) if dic['uniq_id'] is not None: uniq_id = dic['uniq_id'] else: @@ -368,7 +373,11 @@ def create_segment_manifest( segments_manifest_file = write_rttm2manifest(AUDIO_RTTM_MAP, segment_manifest_path, deci) subsegments_manifest_file = subsegment_manifest_path segments_manifest_to_subsegments_manifest( - segments_manifest_file, subsegments_manifest_file, window, shift, min_subsegment_duration, + segments_manifest_file, + subsegments_manifest_file, + window, + shift, + min_subsegment_duration, ) subsegments_dict = get_subsegment_dict(subsegments_manifest_file, window, shift, deci) write_truncated_subsegments(input_manifest_dict, subsegments_dict, output_manifest_path, step_count, deci) @@ -505,7 +514,9 @@ def write_manifest(output_path: Union[Path, str], target_manifest: List[dict], e Args: output_path (str or Path): Path to output manifest file target_manifest (list): List of manifest file entries - ensure_ascii (bool): default is True, meaning the output is guaranteed to have all incoming non-ASCII characters escaped. If ensure_ascii is false, these characters will be output as-is. + ensure_ascii (bool): default is True, meaning the output is guaranteed to have all incoming + non-ASCII characters escaped. If ensure_ascii is false, these characters + will be output as-is. """ with open(output_path, "w", encoding="utf-8") as outfile: for tgt in target_manifest: diff --git a/nemo/collections/asr/parts/utils/speaker_utils.py b/nemo/collections/asr/parts/utils/speaker_utils.py index 5d3a0bf4274e9..223916e60a761 100644 --- a/nemo/collections/asr/parts/utils/speaker_utils.py +++ b/nemo/collections/asr/parts/utils/speaker_utils.py @@ -21,21 +21,17 @@ from typing import Dict, List, Tuple, Union import numpy as np -import omegaconf import soundfile as sf import torch -from pyannote.core import Annotation, Segment +from omegaconf.listconfig import ListConfig +from pyannote.core import Annotation, Segment, Timeline from tqdm import tqdm from nemo.collections.asr.data.audio_to_label import repeat_signal from nemo.collections.asr.parts.utils.longform_clustering import LongFormSpeakerClustering -from nemo.collections.asr.parts.utils.offline_clustering import SpeakerClustering, get_argmin_mat, split_input_data +from nemo.collections.asr.parts.utils.offline_clustering import get_argmin_mat, split_input_data from nemo.utils import logging -""" -This file contains all the utility functions required for speaker embeddings part in diarization scripts -""" - def get_uniqname_from_filepath(filepath): """ @@ -81,10 +77,13 @@ def audio_rttm_map(manifest, attach_dur=False): """ This function creates AUDIO_RTTM_MAP which is used by all diarization components to extract embeddings, cluster and unify time stamps - Args: manifest file that contains keys audio_filepath, rttm_filepath if exists, text, num_speakers if known and uem_filepath if exists - returns: - AUDIO_RTTM_MAP (dict) : A dictionary with keys of uniq id, which is being used to map audio files and corresponding rttm files + Args: + manifest (str): Path to the manifest file + attach_dur (bool, optional): If True, 
attach duration information to the unique name. Defaults to False. + + Returns: + AUDIO_RTTM_MAP (dict) : Dictionary with unique names as keys and corresponding metadata as values. """ AUDIO_RTTM_MAP = {} @@ -108,15 +107,17 @@ def audio_rttm_map(manifest, attach_dur=False): if attach_dur: uniqname = get_uniq_id_with_dur(meta) else: - uniqname = get_uniqname_from_filepath(filepath=meta['audio_filepath']) + if "uniq_id" in dic.keys(): + uniqname = dic['uniq_id'] + else: + uniqname = get_uniqname_from_filepath(filepath=meta['audio_filepath']) if uniqname not in AUDIO_RTTM_MAP: AUDIO_RTTM_MAP[uniqname] = meta else: raise KeyError( - "file {} is already part of AUDIO_RTTM_MAP, it might be duplicated, Note: file basename must be unique".format( - meta['audio_filepath'] - ) + f"file {meta['audio_filepath']} is already part of AUDIO_RTTM_MAP, it might be duplicated, " + "Note: file basename must be unique" ) return AUDIO_RTTM_MAP @@ -144,7 +145,7 @@ def parse_scale_configs(window_lengths_in_sec, shift_lengths_in_sec, multiscale_ """ check_float_config = [isinstance(var, float) for var in (window_lengths_in_sec, shift_lengths_in_sec)] check_list_config = [ - isinstance(var, (omegaconf.listconfig.ListConfig, list, tuple)) + isinstance(var, (ListConfig, list, tuple)) for var in (window_lengths_in_sec, shift_lengths_in_sec, multiscale_weights) ] if all(check_list_config) or all(check_float_config): @@ -247,7 +248,8 @@ def get_embs_and_timestamps(multiscale_embeddings_and_timestamps, multiscale_arg def get_timestamps(multiscale_timestamps, multiscale_args_dict): """ The timestamps in `multiscale_timestamps` dictionary are indexed by scale index. - This function rearranges the extracted speaker embedding and timestamps by unique ID to make the further processing more convenient. + This function rearranges the extracted speaker embedding and timestamps by unique ID + to make the further processing more convenient. Args: multiscale_timestamps (dict): @@ -441,13 +443,20 @@ def perform_clustering( 'embeddings' : Tensor containing embeddings. Dimensions:(# of embs) x (emb. dimension) 'timestamps' : Tensor containing ime stamps list for each audio recording 'multiscale_segment_counts' : Tensor containing the number of segments for each scale - AUDIO_RTTM_MAP (dict): AUDIO_RTTM_MAP for mapping unique id with audio file path and rttm path - out_rttm_dir (str): Path to write predicted rttms - clustering_params (dict): clustering parameters provided through config that contains max_num_speakers (int), - oracle_num_speakers (bool), max_rp_threshold(float), sparse_search_volume(int) and enhance_count_threshold (int) - use_torch_script (bool): Boolean that determines whether to use torch.jit.script for speaker clustering - device (torch.device): Device we are running on ('cpu', 'cuda'). - verbose (bool): Enable TQDM progress bar. + AUDIO_RTTM_MAP (dict): + AUDIO_RTTM_MAP for mapping unique id with audio file path and rttm path + out_rttm_dir (str): + Path to write predicted rttms + clustering_params (dict): + Clustering parameters provided through config that contains max_num_speakers (int), + oracle_num_speakers (bool), max_rp_threshold(float), sparse_search_volume(int) + and enhance_count_threshold (int). + use_torch_script (bool): + Boolean that determines whether to use torch.jit.script for speaker clustering + device (torch.device): + Device we are running on ('cpu', 'cuda'). + verbose (bool): + Enable TQDM progress bar. 
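The `clustering_params` argument described above is typically supplied through the diarizer inference config. A hedged sketch of such a block using the keys listed in the docstring; the numeric values are illustrative placeholders only, not authoritative defaults, so take real settings from the shipped diar_infer YAML configs.

```python
from omegaconf import OmegaConf

# Illustrative values only; consult the diar_infer_*.yaml configs for real defaults.
clustering_params = OmegaConf.create(
    {
        "max_num_speakers": 8,          # upper bound on the estimated speaker count
        "oracle_num_speakers": False,   # set True to trust num_speakers from the manifest
        "max_rp_threshold": 0.25,
        "sparse_search_volume": 30,
        "enhance_count_threshold": 80,
    }
)
print(OmegaConf.to_yaml(clustering_params))
```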
Returns: all_reference (list[uniq_name,Annotation]): reference annotations for score calculation @@ -585,7 +594,7 @@ def write_overlap_segments(outfile, AUDIO_RTTM_MAP, uniq_id, overlap_range_list, Number of decimals to round the offset and duration values. """ audio_path = AUDIO_RTTM_MAP[uniq_id]['audio_filepath'] - for (stt, end) in overlap_range_list: + for stt, end in overlap_range_list: meta = { "audio_filepath": audio_path, "offset": round(stt, decimals), @@ -614,9 +623,8 @@ def read_rttm_lines(rttm_file_path): lines = f.readlines() else: raise FileNotFoundError( - "Requested to construct manifest from rttm with oracle VAD option or from NeMo VAD but received filename as {}".format( - rttm_file_path - ) + "Requested to construct manifest from rttm with oracle VAD option " + f"or from NeMo VAD but received filename as {rttm_file_path}" ) return lines @@ -745,14 +753,14 @@ def fl2int(x: float, decimals: int = 3) -> int: """ Convert floating point number to integer. """ - return torch.round(torch.tensor([x * (10 ** decimals)]), decimals=0).int().item() + return torch.round(torch.tensor([x * (10**decimals)]), decimals=0).int().item() def int2fl(x: int, decimals: int = 3) -> float: """ Convert integer to floating point number. """ - return torch.round(torch.tensor([x / (10 ** decimals)]), decimals=decimals).item() + return torch.round(torch.tensor([x / (10**decimals)]), decimals=decimals).item() def merge_float_intervals(ranges: List[List[float]], decimals: int = 5, margin: int = 2) -> List[List[float]]: @@ -886,7 +894,8 @@ def segments_manifest_to_subsegments_manifest( Generate subsegments manifest from segments manifest file Args: segments_manifest file (str): path to segments manifest file, typically from VAD output - subsegments_manifest_file (str): path to output subsegments manifest file (default (None) : writes to current working directory) + subsegments_manifest_file (str): path to output subsegments manifest file + (default (None) : writes to current working directory) window (float): window length for segments to subsegments length shift (float): hop length for subsegments shift min_subsegments_duration (float): exclude subsegments smaller than this duration value @@ -898,15 +907,16 @@ def segments_manifest_to_subsegments_manifest( pwd = os.getcwd() subsegments_manifest_file = os.path.join(pwd, 'subsegments.json') - with open(segments_manifest_file, 'r') as segments_manifest, open( - subsegments_manifest_file, 'w' - ) as subsegments_manifest: + with ( + open(segments_manifest_file, 'r') as segments_manifest, + open(subsegments_manifest_file, 'w') as subsegments_manifest, + ): segments = segments_manifest.readlines() for segment in segments: segment = segment.strip() dic = json.loads(segment) audio, offset, duration, label = dic['audio_filepath'], dic['offset'], dic['duration'], dic['label'] - subsegments = get_subsegments(offset=offset, window=window, shift=shift, duration=duration) + subsegments = get_subsegments_scriptable(offset=offset, window=window, shift=shift, duration=duration) if include_uniq_id and 'uniq_id' in dic: uniq_id = dic['uniq_id'] else: @@ -928,16 +938,82 @@ def segments_manifest_to_subsegments_manifest( return subsegments_manifest_file -def get_subsegments(offset: float, window: float, shift: float, duration: float) -> List[List[float]]: +def get_subsegments( + offset: float, + window: float, + shift: float, + duration: float, + min_subsegment_duration: float = 0.01, + decimals: int = 2, + use_asr_style_frame_count: bool = False, + sample_rate: int = 16000, 
+ feat_per_sec: int = 100, +) -> List[List[float]]: + """ + Return subsegments from a segment of an audio file. + + Example: + (window, shift) = 1.5, 0.75 + Segment: [12.05, 14.45] + Subsegments: [[12.05, 13.55], [12.8, 14.3], [13.55, 14.45], [14.3, 14.45]] + + Args: + offset (float): Start time of audio segment + window (float): Window length for segments to subsegments length + shift (float): Hop length for subsegments shift + duration (float): Duration of segment + min_subsegment_duration (float): Exclude subsegments smaller than this duration value + decimals (int): Number of decimal places to round to + use_asr_style_frame_count (bool): If True, use ASR-style frame count to generate subsegments. + For example, if duration is 10 secs and frame_shift is 0.08 secs, + it results in (10/0.08)+1 = 126 frames. + + Returns: + subsegments (List[tuple[float, float]]): subsegments generated for the segments as + list of tuple of start and duration of each subsegment + """ + subsegments: List[List[float]] = [] + start = offset + slice_end = start + duration + if min_subsegment_duration <= duration <= shift: + slices = 1 + elif use_asr_style_frame_count is True: + num_feat_frames = np.ceil((1 + duration * sample_rate) / int(sample_rate / feat_per_sec)).astype(int) + slices = np.ceil(num_feat_frames / int(feat_per_sec * shift)).astype(int) + slice_end = start + shift * slices + else: + slices = np.ceil(1 + (duration - window) / shift).astype(int) + if slices == 1: + if min(duration, window) >= min_subsegment_duration: + subsegments.append([start, min(duration, window)]) + elif slices > 0: # if slices <= 0, no subsegment is generated and the empty list is returned + start_col = torch.arange(offset, slice_end, shift)[:slices] + dur_col_raw = torch.min( + slice_end * torch.ones_like(start_col) - start_col, window * torch.ones_like(start_col) + ) + dur_col = torch.round(dur_col_raw, decimals=decimals) + valid_mask = dur_col >= min_subsegment_duration + valid_subsegments = torch.stack([start_col[valid_mask], dur_col[valid_mask]], dim=1) + subsegments = valid_subsegments.tolist() + return subsegments + + +def get_subsegments_scriptable(offset: float, window: float, shift: float, duration: float) -> List[List[float]]: """ - Return subsegments from a segment of audio file + This function returns subsegments from a segment of an audio file. + Although this implementation is inefficient due to the use of a for-loop for segmentation, + it is designed to be torch-jit-scriptable. + Use `get_subsegments` for a more efficient implementation. + Args: offset (float): start time of audio segment window (float): window length for segments to subsegments length shift (float): hop length for subsegments shift duration (float): duration of segment Returns: - subsegments (List[tuple[float, float]]): subsegments generated for the segments as list of tuple of start and duration of each subsegment + subsegments (List[tuple[float, float]]): subsegments generated for the segments + as list of tuple of start and duration of + each subsegment """ subsegments: List[List[float]] = [] start = offset @@ -953,7 +1029,13 @@ def get_subsegments(offset: float, window: float, shift: float, duration: float) return subsegments -def get_target_sig(sig, start_sec: float, end_sec: float, slice_length: int, sample_rate: int,) -> torch.Tensor: +def get_target_sig( + sig, + start_sec: float, + end_sec: float, + slice_length: int, + sample_rate: int, +) -> torch.Tensor: """ Extract time-series signal from the given audio buffer based on the start and end timestamps.
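# Editor's usage sketch (not part of this diff): the vectorized `get_subsegments` above and the
# renamed jit-scriptable `get_subsegments_scriptable` are called the same way; values mirror the
# docstring example, and the return value is a list of [start, duration] pairs (not asserted here).
from nemo.collections.asr.parts.utils.speaker_utils import get_subsegments, get_subsegments_scriptable

offset, duration = 12.05, 14.45 - 12.05
fast = get_subsegments(offset=offset, window=1.5, shift=0.75, duration=duration)
loop = get_subsegments_scriptable(offset=offset, window=1.5, shift=0.75, duration=duration)
print(fast)  # torch-vectorized implementation
print(loop)  # for-loop implementation kept for torch.jit.script compatibility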
@@ -1000,6 +1082,34 @@ def tensor_to_list(range_tensor: torch.Tensor) -> List[List[float]]: return [[float(range_tensor[k][0]), float(range_tensor[k][1])] for k in range(range_tensor.shape[0])] +def generate_diarization_output_lines(speaker_timestamps: List[List[float]], model_spk_num: int) -> List[str]: + """ + Generate diarization output lines list from the speaker timestamps list by merging overlapping intervals. + + Args: + speaker_timestamps (list): + List containing the start and end time of the speech intervals for each speaker. + Example: + >>> speaker_timestamps = [[0.5, 3.12], [3.51, 7.26],... ] + model_spk_num (int): + Number of speakers in the model. + + Returns: + speaker_lines_total (list): + List containing the diarization output lines in the format: + "start_time end_time speaker_id" + Example: + >>> speaker_lines_total = ["0.5 3.12 speaker_0", "3.51 7.26 speaker_1",...] + """ + speaker_lines_total = [] + for spk_idx in range(model_spk_num): + ts_invervals = speaker_timestamps[spk_idx] + merged_ts_intervals = merge_float_intervals(ts_invervals) + for ts_interval in merged_ts_intervals: + speaker_lines_total.extend([f"{ts_interval[0]:.3f} {ts_interval[1]:.3f} speaker_{int(spk_idx)}"]) + return speaker_lines_total + + def get_speech_labels_for_update( frame_start: float, buffer_end: float, @@ -1067,9 +1177,12 @@ def get_speech_labels_for_update( return speech_label_for_new_segments, cumulative_speech_labels -def get_new_cursor_for_update(frame_start: float, segment_range_ts: List[List[float]],) -> Tuple[float, int]: +def get_new_cursor_for_update( + frame_start: float, + segment_range_ts: List[List[float]], +) -> Tuple[float, int]: """ - Function for updating a cursor online speaker diarization. + Function for updating a cursor online speaker diarization. Remove the old segments that overlap with the new frame (self.frame_start) cursor_for_old_segments is set to the onset of the t_range popped lastly. @@ -1226,8 +1339,11 @@ def get_online_subsegments_from_buffer( range_offs = [float(range_spl[0].item() - buffer_start), float(range_spl[1].item() - buffer_start)] range_t = [max(0, range_offs[0]), range_offs[1]] - subsegments = get_subsegments( - offset=range_t[0], window=window, shift=shift, duration=(range_t[1] - range_t[0]), + subsegments = get_subsegments_scriptable( + offset=range_t[0], + window=window, + shift=shift, + duration=(range_t[1] - range_t[0]), ) ind_offset, sigs, ranges, inds = get_online_segments_from_slices( sig=audio_buffer, @@ -1277,20 +1393,22 @@ def get_scale_mapping_argmat(uniq_embs_and_timestamps: Dict[str, dict]) -> Dict[ def get_overlap_stamps(cont_stamps: List[str], ovl_spk_idx: List[str]): """ - Generate timestamps that include overlap speech. Overlap-including timestamps are created based on the segments that are - created for clustering diarizer. Overlap speech is assigned to the existing speech segments in `cont_stamps`. + Generate timestamps that include overlap speech. Overlap-including timestamps are created based on + the segments that are created for clustering diarizer. Overlap speech is assigned to the existing + speech segments in `cont_stamps`. Args: cont_stamps (list): - Non-overlapping (single speaker per segment) diarization output in string format. - Each line contains the start and end time of segments and corresponding speaker labels. + Non-overlapping (single speaker per segment) diarization output in string format. Each line + contains the start and end time of segments and corresponding speaker labels. 
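# Editor's usage sketch (not part of this diff): the new `generate_diarization_output_lines`
# takes one [start, end] interval list per speaker, merges overlapping intervals per speaker,
# and renders "start end speaker_id" strings; values mirror the docstring example.
from nemo.collections.asr.parts.utils.speaker_utils import generate_diarization_output_lines

speaker_timestamps = [[[0.5, 3.12]], [[3.51, 7.26]]]  # speaker_0 and speaker_1 intervals
lines = generate_diarization_output_lines(speaker_timestamps, model_spk_num=2)
# expected: ['0.500 3.120 speaker_0', '3.510 7.260 speaker_1']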
ovl_spk_idx (list): - List containing segment index of the estimated overlapped speech. The start and end of segments are based on the - single-speaker (i.e., non-overlap-aware) RTTM generation. + List containing segment index of the estimated overlapped speech. The start and end of + segments are based on the single-speaker (i.e., non-overlap-aware) RTTM generation. + Returns: total_ovl_cont_list (list): - Rendered diarization output in string format. Each line contains the start and end time of segments and - corresponding speaker labels. This format is identical to `cont_stamps`. + Rendered diarization output in string format. Each line contains the start and end time of + segments and corresponding speaker labels. This format is identical to `cont_stamps`. """ ovl_spk_cont_list = [[] for _ in range(len(ovl_spk_idx))] for spk_idx in range(len(ovl_spk_idx)): @@ -1307,18 +1425,21 @@ def get_overlap_stamps(cont_stamps: List[str], ovl_spk_idx: List[str]): def get_adaptive_threshold(estimated_num_of_spks: int, min_threshold: float, overlap_infer_spk_limit: int): """ - This function controls the magnitude of the sigmoid threshold based on the estimated number of speakers. As the number of - speakers becomes larger, diarization error rate is very sensitive on overlap speech detection. This function linearly increases - the threshold in proportion to the estimated number of speakers so more confident overlap speech results are reflected when - the number of estimated speakers are relatively high. + This function controls the magnitude of the sigmoid threshold based on the estimated number of + speakers. As the number of speakers becomes larger, diarization error rate is very sensitive + to overlap speech detection. This function linearly increases the threshold in proportion to + the estimated number of speakers so more confident overlap speech results are reflected when + the number of estimated speakers is relatively high. Args: estimated_num_of_spks (int): Estimated number of speakers from the clustering result. min_threshold (float): - Sigmoid threshold value from the config file. This threshold value is minimum threshold value when `estimated_num_of_spks=2` + Sigmoid threshold value from the config file. This threshold value is the minimum + threshold when `estimated_num_of_spks=2`. overlap_infer_spk_limit (int): - If the `estimated_num_of_spks` is less then `overlap_infer_spk_limit`, overlap speech estimation is skipped. + If the `estimated_num_of_spks` is less than `overlap_infer_spk_limit`, overlap speech + estimation is skipped. Returns: adaptive_threshold (float): @@ -1333,37 +1454,41 @@ def get_adaptive_threshold(estimated_num_of_spks: int, min_threshold: float, ove def generate_speaker_timestamps( clus_labels: List[Union[float, int]], msdd_preds: List[torch.Tensor], **params ) -> Tuple[List[str], List[str]]: - ''' - Generate speaker timestamps from the segmentation information. If `use_clus_as_main=True`, use clustering result for main speaker - labels and use timestamps from the predicted sigmoid values. In this function, the main speaker labels in `maj_labels` exist for - every subsegment steps while overlap speaker labels in `ovl_labels` only exist for segments where overlap-speech is occuring. + """ + Generate speaker timestamps from the segmentation information. If `use_clus_as_main=True`, use + clustering result for main speaker labels and use timestamps from the predicted sigmoid values. 
+ In this function, the main speaker labels in `maj_labels` exist for every subsegment step, while + overlap speaker labels in `ovl_labels` only exist for segments where overlap speech occurs. Args: clus_labels (list): List containing integer-valued speaker clustering results. msdd_preds (list): - List containing tensors of the predicted sigmoid values. - Each tensor has shape of: (Session length, estimated number of speakers). + List containing tensors of the predicted sigmoid values. Each tensor has shape of: + (Session length, estimated number of speakers). params: Parameters for generating RTTM output and evaluation. Parameters include: - infer_overlap (bool): If False, overlap-speech will not be detected. - use_clus_as_main (bool): Add overlap-speech detection from MSDD to clustering results. If False, only MSDD output - is used for constructing output RTTM files. + infer_overlap (bool): If False, overlap speech will not be detected. + use_clus_as_main (bool): Add overlap-speech detection from MSDD to clustering results. + If False, only MSDD output is used for constructing output + RTTM files. overlap_infer_spk_limit (int): Above this limit, overlap-speech detection is bypassed. - use_adaptive_thres (bool): Boolean that determines whehther to use adaptive_threshold depending on the estimated - number of speakers. + use_adaptive_thres (bool): Boolean that determines whether to use adaptive thresholds + depending on the estimated number of speakers. max_overlap_spks (int): Maximum number of overlap speakers detected. Default is 2. threshold (float): Sigmoid threshold for MSDD output. Returns: maj_labels (list): - List containing string-formated single-speaker speech segment timestamps and corresponding speaker labels. + List containing string-formatted single-speaker speech segment timestamps and corresponding + speaker labels. Example: [..., '551.685 552.77 speaker_1', '552.99 554.43 speaker_0', '554.97 558.19 speaker_0', ...] ovl_labels (list): - List containing string-formated additional overlapping speech segment timestamps and corresponding speaker labels. - Note that `ovl_labels` includes only overlapping speech that is not included in `maj_labels`. + List containing string-formatted additional overlapping speech segment timestamps and + corresponding speaker labels. Note that `ovl_labels` includes only overlapping speech that + is not included in `maj_labels`. Example: [..., '152.495 152.745 speaker_1', '372.71 373.085 speaker_0', '554.97 555.885 speaker_1', ...] - ''' + """ msdd_preds.squeeze(0) estimated_num_of_spks = msdd_preds.shape[-1] overlap_speaker_list = [[] for _ in range(estimated_num_of_spks)] @@ -1398,8 +1523,7 @@ def generate_speaker_timestamps( def get_uniq_id_list_from_manifest(manifest_file: str): - """Retrieve `uniq_id` values from the given manifest_file and save the IDs to a list. - """ + """Retrieve `uniq_id` values from the given manifest_file and save the IDs to a list.""" uniq_id_list = [] with open(manifest_file, 'r', encoding='utf-8') as manifest: for i, line in enumerate(manifest.readlines()): @@ -1418,7 +1542,8 @@ def get_id_tup_dict(uniq_id_list: List[str], test_data_collection, preds_list: L uniq_id_list (list): List containing the `uniq_id` values. test_data_collection (collections.DiarizationLabelEntity): - Class instance that is containing session information such as targeted speaker indices, audio filepath and RTTM filepath. 
+ Class instance that is containing session information such as targeted speaker indices, + audio filepath and RTTM filepath. preds_list (list): List containing tensors of predicted sigmoid values. @@ -1447,11 +1572,14 @@ def prepare_split_data(manifest_filepath, _out_dir, multiscale_args_dict, global Returns: multiscale_args_dict (dict): - - Dictionary containing two types of arguments: multi-scale weights and subsegment timestamps for each data sample. + - Dictionary containing two types of arguments: multi-scale weights and subsegment timestamps + for each data sample. - Each data sample has two keys: `multiscale_weights` and `scale_dict`. - `multiscale_weights` key contains a list containing multiscale weights. - `scale_dict` is indexed by integer keys which are scale index. - - Each data sample is indexed by using the following naming convention: `__` + - Each data sample is indexed by using the following naming convention: + `__` + Example: `fe_03_00106_mixed_626310_642300` """ speaker_dir = os.path.join(_out_dir, 'speaker_outputs') @@ -1580,6 +1708,86 @@ def make_rttm_with_overlap( return all_reference, all_hypothesis +def timestamps_to_pyannote_object( + speaker_timestamps: List[Tuple[float, float]], + uniq_id: str, + audio_rttm_values: Dict[str, str], + all_hypothesis: List[Tuple[str, Timeline]], + all_reference: List[Tuple[str, Timeline]], + all_uems: List[Tuple[str, Timeline]], + out_rttm_dir: str | None, +): + """ + Convert speaker timestamps to pyannote.core.Timeline object. + + Args: + speaker_timestamps (List[Tuple[float, float]]): + Timestamps of each speaker: start time and end time of each speaker. + uniq_id (str): + Unique ID of each speaker. + audio_rttm_values (Dict[str, str]): + Dictionary of manifest values. + all_hypothesis (List[Tuple[str, pyannote.core.Timeline]]): + List of hypothesis in pyannote.core.Timeline object. + all_reference (List[Tuple[str, pyannote.core.Timeline]]): + List of reference in pyannote.core.Timeline object. + all_uems (List[Tuple[str, pyannote.core.Timeline]]): + List of uems in pyannote.core.Timeline object. + out_rttm_dir (str | None): + Directory to save RTTMs + + Returns: + all_hypothesis (List[Tuple[str, pyannote.core.Timeline]]): + List of hypothesis in pyannote.core.Timeline object with an added Timeline object. + all_reference (List[Tuple[str, pyannote.core.Timeline]]): + List of reference in pyannote.core.Timeline object with an added Timeline object. + all_uems (List[Tuple[str, pyannote.core.Timeline]]): + List of uems in pyannote.core.Timeline object with an added Timeline object. 
+ """ + offset, dur = float(audio_rttm_values.get('offset', None)), float(audio_rttm_values.get('duration', None)) + hyp_labels = generate_diarization_output_lines( + speaker_timestamps=speaker_timestamps, model_spk_num=len(speaker_timestamps) + ) + hypothesis = labels_to_pyannote_object(hyp_labels, uniq_name=uniq_id) + if out_rttm_dir is not None and os.path.exists(out_rttm_dir): + with open(f'{out_rttm_dir}/{uniq_id}.rttm', 'w') as f: + hypothesis.write_rttm(f) + all_hypothesis.append([uniq_id, hypothesis]) + rttm_file = audio_rttm_values.get('rttm_filepath', None) + if rttm_file is not None and os.path.exists(rttm_file): + uem_lines = [[offset, dur + offset]] + org_ref_labels = rttm_to_labels(rttm_file) + ref_labels = org_ref_labels + reference = labels_to_pyannote_object(ref_labels, uniq_name=uniq_id) + uem_obj = get_uem_object(uem_lines, uniq_id=uniq_id) + all_uems.append(uem_obj) + all_reference.append([uniq_id, reference]) + return all_hypothesis, all_reference, all_uems + + +def get_uem_object(uem_lines: List[List[float]], uniq_id: str): + """ + Generate pyannote timeline segments for uem file. + + file format + UNIQ_SPEAKER_ID CHANNEL START_TIME END_TIME + + Args: + uem_lines (list): list of session ID and start, end times. + Example: + [[0.0, 30.41], [60.04, 165.83]] + uniq_id (str): Unique session ID. + + Returns: + timeline (pyannote.core.Timeline): pyannote timeline object. + """ + timeline = Timeline(uri=uniq_id) + for uem_stt_end in uem_lines: + start_time, end_time = uem_stt_end + timeline.add(Segment(float(start_time), float(end_time))) + return timeline + + def embedding_normalize(embs, use_std=False, eps=1e-10): """ Mean and l2 length normalize the input speaker embeddings @@ -1635,7 +1843,7 @@ def run_online_segmentation( segment_indexes: List[int], window: float, shift: float, - ): + ) -> Tuple[List[torch.Tensor], List[List[float]], List[int]]: """ Remove the old segments that overlap with the new frame (self.frame_start) cursor_for_old_segments is pointing at the onset of the t_range popped most recently. diff --git a/nemo/collections/asr/parts/utils/vad_utils.py b/nemo/collections/asr/parts/utils/vad_utils.py index aea04b8cafcf8..83a811ee4adbe 100644 --- a/nemo/collections/asr/parts/utils/vad_utils.py +++ b/nemo/collections/asr/parts/utils/vad_utils.py @@ -23,31 +23,22 @@ from pathlib import Path from typing import Dict, List, Optional, Tuple, Union +import IPython.display as ipd import librosa import matplotlib.pyplot as plt import numpy as np import pandas as pd import torch -from omegaconf import DictConfig +from omegaconf import DictConfig, OmegaConf from pyannote.core import Annotation, Segment from pyannote.metrics import detection from sklearn.metrics import roc_auc_score from sklearn.model_selection import ParameterGrid from tqdm import tqdm - from nemo.collections.asr.models import EncDecClassificationModel, EncDecFrameClassificationModel from nemo.collections.common.parts.preprocessing.manifest import get_full_path from nemo.utils import logging -HAVE_IPYTHON = False -try: - import IPython.display as ipd - - HAVE_IPYTHON = True -except: - HAVE_IPYTHON = False - - """ This file contains all the utility functions required for voice activity detection. 
""" @@ -74,8 +65,8 @@ def prepare_manifest(config: dict) -> str: input_list = config['input'] else: raise ValueError( - "The input for manifest preparation would either be a string of the filepath to \ - manifest or a list of {'audio_filepath': i, 'offset': 0, 'duration': null} " + "The input for manifest preparation would either be a string of the filepath to manifest " + "or a list of {'audio_filepath': i, 'offset': 0, 'duration': null}." ) args_func = { @@ -204,8 +195,7 @@ def write_vad_infer_manifest(file: dict, args_func: dict) -> list: def get_vad_stream_status(data: list) -> list: """ - Generate a list of status for each snippet in manifest. - A snippet should be in single, start, next or end status. + Generate a list of status for each snippet in manifest. A snippet should be in single, start, next or end status. Used for concatenating to full audio file. Args: data (list): list of filepath of audio snippet @@ -321,9 +311,8 @@ def generate_overlap_vad_seq_per_tensor( frame: torch.Tensor, per_args: Dict[str, float], smoothing_method: str ) -> torch.Tensor: """ - Use generated frame prediction (generated by shifting window of shift_length_in_sec (10ms)) - to generate prediction with overlapping input window/segments - See description in generate_overlap_vad_seq. + Use generated frame prediction (generated by shifting window of shift_length_in_sec (10ms)) to generate + prediction with overlapping input window/segments. See description in generate_overlap_vad_seq. Use this for single instance pipeline. """ # This function will be refactor for vectorization but this is okay for now @@ -484,8 +473,8 @@ def binarization(sequence: torch.Tensor, per_args: Dict[str, float]) -> torch.Te Binarize predictions to speech and non-speech Reference - Paper: Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of RNN-based Voice Activity Detection", \ - InterSpeech 2015. + Paper: Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of RNN-based Voice + Activity Detection", InterSpeech 2015. Implementation: https://github.com/pyannote/pyannote-audio/blob/master/pyannote/audio/utils/signal.py Args: @@ -498,8 +487,8 @@ def binarization(sequence: torch.Tensor, per_args: Dict[str, float]) -> torch.Te frame_length_in_sec (float): length of frame. Returns: - speech_segments(torch.Tensor): A tensor of speech segment in torch.Tensor([[start1, end1], [start2, end2]]) \ - format. + speech_segments(torch.Tensor): A tensor of speech segment in the form of: + `torch.Tensor([[start1, end1], [start2, end2]])`. """ frame_length_in_sec = per_args.get('frame_length_in_sec', 0.01) @@ -549,11 +538,10 @@ def binarization(sequence: torch.Tensor, per_args: Dict[str, float]) -> torch.Te def remove_segments(original_segments: torch.Tensor, to_be_removed_segments: torch.Tensor) -> torch.Tensor: """ Remove speech segments list in to_be_removed_segments from original_segments. 
- For example, - remove torch.Tensor([[start2, end2],[start4, end4]]) from torch.Tensor([[start1, end1],[start2, end2],\ - [start3, end3], [start4, end4]]), - -> - torch.Tensor([[start1, end1],[start3, end3]]) + (Example) Remove torch.Tensor([[start2, end2],[start4, end4]]) + from torch.Tensor([[start1, end1],[start2, end2],[start3, end3], [start4, end4]]), + -> + torch.Tensor([[start1, end1],[start3, end3]]) """ for y in to_be_removed_segments: original_segments = original_segments[original_segments.eq(y).all(dim=1).logical_not()] @@ -574,24 +562,30 @@ def get_gap_segments(segments: torch.Tensor) -> torch.Tensor: @torch.jit.script def filtering(speech_segments: torch.Tensor, per_args: Dict[str, float]) -> torch.Tensor: """ - Filter out short non_speech and speech segments. + Filter out short non-speech and speech segments. + + Reference: + Paper: Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of RNN-based Voice + Activity Detection", InterSpeech 2015. + Implementation: + https://github.com/pyannote/pyannote-audio/blob/master/pyannote/audio/utils/signal.py - Reference - Paper: Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of RNN-based Voice Activity Detection", \ - InterSpeech 2015. - Implementation: https://github.com/pyannote/pyannote-audio/blob/master/pyannote/audio/utils/signal.py Args: - speech_segments (torch.Tensor): A tensor of speech segment in torch.Tensor([[start1, end1], \ - [start2, end2]]) format. + speech_segments (torch.Tensor): + A tensor of speech segments in the format + torch.Tensor([[start1, end1], [start2, end2]]). per_args: - min_duration_on (float): threshold for small non_speech deletion - min_duration_off (float): threshold for short speech segment deletion - filter_speech_first (float): Whether to perform short speech segment deletion first. \ - Use 1.0 to represent True. + min_duration_on (float): + Threshold for small non-speech deletion. + min_duration_off (float): + Threshold for short speech segment deletion. + filter_speech_first (float): + Whether to perform short speech segment deletion first. Use 1.0 to represent True. Returns: - speech_segments(torch.Tensor): A tensor of filtered speech segment in \ - torch.Tensor([[start1, end1], [start2, end2]]) format. + speech_segments (torch.Tensor): + A tensor of filtered speech segments in the format + torch.Tensor([[start1, end1], [start2, end2]]). """ if speech_segments.shape == torch.Size([0]): return speech_segments @@ -840,18 +834,19 @@ def vad_tune_threshold_on_dev( num_workers: int = 20, ) -> Tuple[dict, dict]: """ - Tune thresholds on dev set. Return best thresholds which gives the lowest detection error rate - (DetER) in thresholds. + Tune thresholds on dev set. Return best thresholds which gives the lowest + detection error rate (DetER) in thresholds. + Args: params (dict): dictionary of parameters to be tuned on. vad_pred_method (str): suffix of prediction file. Use to locate file. - Should be either in "frame", "mean" or "median". - groundtruth_RTTM_dir (str): directory of ground-truth rttm files or a file contains the paths of them. - focus_metric (str): metrics we care most when tuning threshold. Should be either in "DetER", "FA", "MISS" - frame_length_in_sec (float): frame length. - num_workers (int): number of workers. + Should be either in "frame", "mean" or "median". + groundtruth_RTTM_dir (str): Directory of ground-truth rttm files or a file contains the paths of them. + focus_metric (str): Metrics we care most when tuning threshold. 
Should be either in "DetER", "FA", "MISS" + frame_length_in_sec (float): Frame length. + num_workers (int): Number of workers. Returns: - best_threshold (float): threshold that gives lowest DetER. + best_threshold (float): Threshold that gives lowest DetER. """ min_score = 100 all_perf = {} @@ -936,8 +931,7 @@ def check_if_param_valid(params: dict) -> bool: for j in params[i]: if not j >= 0: raise ValueError( - "Invalid inputs! All float parameters except pad_onset and pad_offset should be \ - larger than 0!" + "Invalid inputs! All float parameters except pad_onset and pad_offset should be larger than 0!" ) if not (all(i <= 1 for i in params['onset']) and all(i <= 1 for i in params['offset'])): @@ -995,7 +989,7 @@ def plot( unit_frame_len: float = 0.01, label_repeat: int = 1, xticks_step: int = 5, -) -> "ipd.Audio": +) -> ipd.Audio: """ Plot Audio and/or VAD output and/or groundtruth labels for visualization Args: @@ -1009,13 +1003,10 @@ def plot( threshold (float): threshold for prediction score (from 0 to 1). per_args(dict): a dict that stores the thresholds for postprocessing. unit_frame_len (float): unit frame length in seconds for VAD predictions. - label_repeat (int): repeat the label for this number of times to match different \ - frame lengths in preds and labels. + label_repeat (int): repeat the label for this number of times to match different + frame lengths in preds and labels. xticks_step (int): step size for xticks. """ - if HAVE_IPYTHON is False: - raise ImportError("IPython is not installed. Please install IPython to use this function.") - plt.figure(figsize=[20, 2]) audio, sample_rate = librosa.load( @@ -1281,8 +1272,8 @@ def stitch_segmented_asr_output( fout.flush() logging.info( - f"Finish stitch segmented ASR output to {stitched_output_manifest}, \ - the speech segments info has been stored in directory {speech_segments_tensor_dir}" + f"Finish stitch segmented ASR output to {stitched_output_manifest}, " + f"the speech segments info has been stored in directory {speech_segments_tensor_dir}" ) return stitched_output_manifest @@ -1462,13 +1453,10 @@ def plot_sample_from_rttm( show: bool = True, offset: float = 0.0, unit_frame_len: float = 0.01, -) -> "ipd.Audio": +): """ Plot audio signal and frame-level labels from RTTM file """ - if HAVE_IPYTHON is False: - raise ImportError("IPython is not installed. Please install IPython to use this function.") - plt.figure(figsize=[20, 2]) audio, sample_rate = librosa.load(path=audio_file, sr=16000, mono=True, offset=offset, duration=max_duration) @@ -1502,17 +1490,22 @@ def plot_sample_from_rttm( def align_labels_to_frames(probs, labels, threshold=0.2): """ - Aligns labels to frames when the frame length (e.g., 10ms) is different from the label length (e.g., 20ms). - The threshold 0.2 is not important, since the actual ratio will always be close to an integer - unless using frame/label. lengths that are not multiples of each other - (e.g., 15ms frame length and 20ms label length), which is not valid. - The value 0.2 here is just for easier unit testing. + Aligns labels to frames when the frame length (e.g., 10ms) is different from the label length + (e.g., 20ms). The threshold 0.2 is not critical, as the actual ratio will always be close to an + integer unless using frame/label lengths that are not multiples of each other (e.g., 15ms frame + length and 20ms label length), which is not valid. The value 0.2 is chosen for easier unit testing. 
+ Args: - probs (List[float]): list of probabilities - labels (List[int]): list of labels - threshold (float): threshold for rounding ratio to integer + probs (List[float]): + List of probabilities. + labels (List[int]): + List of labels. + threshold (float): + Threshold for rounding the ratio to an integer. + Returns: - labels (List[int]): list of labels aligned to frames + labels (List[int]): + List of labels aligned to frames. """ frames_len = len(probs) labels_len = len(labels) @@ -1543,13 +1536,13 @@ def align_labels_to_frames(probs, labels, threshold=0.2): ratio = frames_len / labels_len res = frames_len % labels_len if ceil(ratio) - ratio < threshold: - # e.g., ratio is 1.83, ceil(ratio) = 2, then we repeat labels to make it a - # multiple of 2, and discard the redundant labels + # e.g., ratio is 1.83, ceil(ratio) = 2, then we repeat labels + # to make it a multiple of 2, and discard the redundant labels labels = labels.repeat_interleave(ceil(ratio), dim=0).long().tolist() labels = labels[:frames_len] else: - # e.g., ratio is 2.02, floor(ratio) = 2, then we repeat labels to make it a multiple of - # 2 and add additional labels + # e.g., ratio is 2.02, floor(ratio) = 2, then we repeat labels + # to make it a multiple of 2 and add additional labels labels = labels.repeat_interleave(floor(ratio), dim=0).long().tolist() if res > 0: labels += labels[-res:] @@ -1743,3 +1736,52 @@ def frame_vad_eval_detection_error( auroc = roc_auc_score(y_true=all_labels, y_score=all_probs) report = metric.report(display=False) return auroc, report + + +def ts_vad_post_processing( + ts_vad_binary_vec: torch.Tensor, + cfg_vad_params: OmegaConf, + unit_10ms_frame_count: int = 8, + bypass_postprocessing: bool = False, +): + """ + Post-processing on diarization results using VAD style post-processing methods. + These post-processing methods are inspired by the following paper: + Medennikov, Ivan, et al. "Target-Speaker Voice Activity Detection: + a Novel Approach for Multi-Speaker Diarization in a Dinner Party Scenario." (2020). + + Args: + ts_vad_binary_vec (Tensor): + Sigmoid values of each frame and each speaker. + Dimension: (num_frames,) + cfg_vad_params (OmegaConf): + Configuration (omega config) of VAD parameters. + unit_10ms_frame_count (int, optional): + an integer indicating the number of 10ms frames in a unit. + For example, if unit_10ms_frame_count is 8, then each frame is 0.08 seconds. + bypass_postprocessing (bool, optional): + If True, diarization post-processing will be bypassed. + + Returns: + speech_segments (Tensor): + start and end of each speech segment. + Dimension: (num_segments, 2) + + Example: + tensor([[ 0.0000, 3.0400], + [ 6.0000, 6.0800], + ... 
+ [587.3600, 591.0400], + [591.1200, 597.7600]]) + """ + ts_vad_binary_frames = torch.repeat_interleave(ts_vad_binary_vec, unit_10ms_frame_count) + if not bypass_postprocessing: + speech_segments = binarization(ts_vad_binary_frames, cfg_vad_params) + speech_segments = filtering(speech_segments, cfg_vad_params) + else: + cfg_vad_params.onset = 0.5 + cfg_vad_params.offset = 0.5 + cfg_vad_params.pad_onset = 0.0 + cfg_vad_params.pad_offset = 0.0 + speech_segments = binarization(ts_vad_binary_frames, cfg_vad_params) + return speech_segments diff --git a/nemo/collections/common/parts/preprocessing/collections.py b/nemo/collections/common/parts/preprocessing/collections.py index 915f406a3e881..d54c807f26377 100644 --- a/nemo/collections/common/parts/preprocessing/collections.py +++ b/nemo/collections/common/parts/preprocessing/collections.py @@ -313,7 +313,10 @@ class InstructionTuningAudioText(_Collection): OUTPUT_TYPE = collections.namedtuple( typename='InstructionTuningText', - field_names='id context context_type context_duration question question_type answer answer_type answer_duration speaker', + field_names=( + 'id context context_type context_duration question ' + 'question_type answer answer_type answer_duration speaker' + ), ) def __init__( @@ -478,7 +481,10 @@ def __init__(self, manifests_files: Union[str, List[str]], parse_func: Optional[ class SpeechLLMAudioTextEntity(object): + """Class for SpeechLLM dataloader instance.""" + def __init__(self, sid, audio_file, duration, context, answer, offset, speaker, orig_sr, lang) -> None: + """Initialize the AudioTextEntity for a SpeechLLM dataloader instance.""" self.id = sid self.audio_file = audio_file self.duration = duration @@ -559,7 +565,6 @@ def __init__( ): """Instantiates audio-context-answer manifest with filters and preprocessing. - Args: ids: List of examples positions. audio_files: List of audio files. @@ -770,7 +775,8 @@ def __parse_item(self, line: str, manifest_file: str) -> Dict[str, Any]: elif 'question' in item: # compatability with old manifests that uses 'question' as context key logging.warning( - f"Neither `{self.context_key}` is found nor `context_file` is set, but found `question` in item: {item}", + f"Neither `{self.context_key}` is found nor" + f"`context_file` is set, but found `question` in item: {item}", mode=logging_mode.ONCE, ) item['context'] = item.pop('question') @@ -867,7 +873,8 @@ def __init__( else: logging.info(f"Filtered duration for loading collection is {duration_filtered / 3600: .2f} hours.") logging.info( - f"Dataset successfully loaded with {len(data)} items and total duration provided from manifest is {total_duration / 3600: .2f} hours." + f"Dataset successfully loaded with {len(data)} items " + f"and total duration provided from manifest is {total_duration / 3600: .2f} hours." ) self.uniq_labels = sorted(set(map(lambda x: x.label, data))) @@ -1008,13 +1015,15 @@ def __init__( if len(data) == max_number: break - logging.info("# {} files loaded including # {} unique labels".format(len(data), len(self.uniq_labels))) + logging.info(f"# {len(data)} files loaded including # {len(self.uniq_labels)} unique labels") super().__init__(data) def relative_speaker_parser(self, seq_label): """Convert sequence of speaker labels to relative labels. Convert sequence of absolute speaker to sequence of relative speaker [E A C A E E C] -> [0 1 2 1 0 0 2] - In this seq of label , if label do not appear before, assign new relative labels len(pos); else reuse previous assigned relative labels. 
+ In this seq of label , if label do not appear before, assign new relative labels len(pos); + else reuse previous assigned relative labels. + Args: seq_label (str): A string of a sequence of labels. @@ -1051,10 +1060,13 @@ def __init__( """Parse lists of feature files and sequences of labels. Args: - manifests_files: Either single string file or list of such - - manifests to yield items from. - max_number: Maximum number of samples to collect; pass to `FeatureSequenceLabel` constructor. - index_by_file_id: If True, saves a mapping from filename base (ID) to index in data; pass to `FeatureSequenceLabel` constructor. + manifests_files: + Either single string file or list of such manifests to yield items from. + max_number: + Maximum number of samples to collect; pass to `FeatureSequenceLabel` constructor. + index_by_file_id: + If True, saves a mapping from filename base (ID) to index in data; + pass to `FeatureSequenceLabel` constructor. """ feature_files, seq_labels = [], [] @@ -1209,35 +1221,37 @@ def __init__( manifests_files: Union[str, List[str]], emb_dict: Dict, clus_label_dict: Dict, - round_digit=2, + round_digits: int = 2, seq_eval_mode=False, pairwise_infer=False, *args, **kwargs, ): """ - Parse lists of audio files, durations, RTTM (Diarization annotation) files. Since diarization model infers only - two speakers, speaker pairs are generated from the total number of speakers in the session. + Parse lists of audio files, durations, RTTM (Diarization annotation) files. Since the diarization + model infers only two speakers, speaker pairs are generated from the total number of speakers in + the session. Args: manifest_filepath (str): - Path to input manifest json files. + Path to input manifest JSON files. emb_dict (Dict): Dictionary containing cluster-average embeddings and speaker mapping information. clus_label_dict (Dict): Segment-level speaker labels from clustering results. round_digit (int): - Number of digits to be rounded. + Number of digits to round. seq_eval_mode (bool): If True, F1 score will be calculated for each speaker pair during inference mode. pairwise_infer (bool): - If True, this dataset class operates in inference mode. In inference mode, a set of speakers in the input audio - is split into multiple pairs of speakers and speaker tuples (e.g. 3 speakers: [(0,1), (1,2), (0,2)]) and then - fed into the diarization system to merge the individual results. + If True, this dataset class operates in inference mode. In inference mode, a set of + speakers in the input audio is split into multiple pairs of speakers and speaker tuples + (e.g., 3 speakers: [(0,1), (1,2), (0,2)]) and then fed into the diarization system to + merge the individual results. *args: Args to pass to `SpeechLabel` constructor. **kwargs: Kwargs to pass to `SpeechLabel` constructor. 
""" - self.round_digit = round_digit + self.round_digits = round_digits self.emb_dict = emb_dict self.clus_label_dict = clus_label_dict self.seq_eval_mode = seq_eval_mode @@ -1371,6 +1385,188 @@ def __parse_item_rttm(self, line: str, manifest_file: str) -> Dict[str, Any]: return item +class EndtoEndDiarizationLabel(_Collection): + """List of end-to-end diarization audio-label correspondence with preprocessing.""" + + OUTPUT_TYPE = collections.namedtuple( + typename='DiarizationLabelEntity', + field_names='audio_file uniq_id duration rttm_file offset', + ) + + def __init__( + self, + audio_files: List[str], + uniq_ids: List[str], + durations: List[float], + rttm_files: List[str], + offsets: List[float], + max_number: Optional[int] = None, + do_sort_by_duration: bool = False, + index_by_file_id: bool = False, + ): + """ + Instantiates audio-label manifest with filters and preprocessing. + + This method initializes the EndtoEndDiarizationLabel object by processing the input data + and applying optional filters and sorting. + + Args: + audio_files (List[str]): List of audio file paths. + uniq_ids (List[str]): List of unique identifiers for each audio file. + durations (List[float]): List of float durations for each audio file. + rttm_files (List[str]): List of RTTM path strings (Groundtruth diarization annotation file). + offsets (List[float]): List of offsets or None for each audio file. + max_number (Optional[int]): Maximum number of samples to collect. Defaults to None. + do_sort_by_duration (bool): If True, sort samples list by duration. Defaults to False. + index_by_file_id (bool): If True, saves a mapping from filename base (ID) to index in data. + Defaults to False. + + """ + if index_by_file_id: + self.mapping = {} + output_type = self.OUTPUT_TYPE + data, duration_filtered = [], 0.0 + + zipped_items = zip(audio_files, uniq_ids, durations, rttm_files, offsets) + for ( + audio_file, + uniq_id, + duration, + rttm_file, + offset, + ) in zipped_items: + + if duration is None: + duration = 0 + + data.append( + output_type( + audio_file, + uniq_id, + duration, + rttm_file, + offset, + ) + ) + + if index_by_file_id: + if isinstance(audio_file, list): + if len(audio_file) == 0: + raise ValueError(f"Empty audio file list: {audio_file}") + file_id, _ = os.path.splitext(os.path.basename(audio_file)) + self.mapping[file_id] = len(data) - 1 + + # Max number of entities filter. + if len(data) == max_number: + break + + if do_sort_by_duration: + if index_by_file_id: + logging.warning("Tried to sort dataset by duration, but cannot since index_by_file_id is set.") + else: + data.sort(key=lambda entity: entity.duration) + + logging.info( + "Filtered duration for loading collection is %f.", + duration_filtered, + ) + logging.info(f"Total {len(data)} session files loaded accounting to # {len(audio_files)} audio clips") + + super().__init__(data) + + +class EndtoEndDiarizationSpeechLabel(EndtoEndDiarizationLabel): + """End-to-end speaker diarization data sample collector from structured json files.""" + + def __init__( + self, + manifests_files: Union[str, List[str]], + round_digits: int = 2, + *args, + **kwargs, + ): + """ + Parse lists of audio files, durations, RTTM (Diarization annotation) files. + Since diarization model infers only two speakers, speaker pairs are generated + from the total number of speakers in the session. + + Args: + manifest_filepath (str): + Path to input manifest json files. + round_digit (int): + Number of digits to be rounded. + *args: Args to pass to `SpeechLabel` constructor. 
+ **kwargs: Kwargs to pass to `SpeechLabel` constructor. + """ + self.round_digits = round_digits + audio_files, uniq_ids, durations, rttm_files, offsets = ( + [], + [], + [], + [], + [], + ) + + for item in manifest.item_iter(manifests_files, parse_func=self.__parse_item_rttm): + # Training mode + rttm_labels = [] + with open(item['rttm_file'], 'r') as f: + for index, rttm_line in enumerate(f.readlines()): + rttm = rttm_line.strip().split() + start = round(float(rttm[3]), round_digits) + end = round(float(rttm[4]), round_digits) + round(float(rttm[3]), round_digits) + speaker = rttm[7] + rttm_labels.append('{} {} {}'.format(start, end, speaker)) + audio_files.append(item['audio_file']) + uniq_ids.append(item['uniq_id']) + durations.append(item['duration']) + rttm_files.append(item['rttm_file']) + offsets.append(item['offset']) + + super().__init__( + audio_files, + uniq_ids, + durations, + rttm_files, + offsets, + *args, + **kwargs, + ) + + def __parse_item_rttm(self, line: str, manifest_file: str) -> Dict[str, Any]: + """Parse each rttm file and save it to in Dict format""" + item = json.loads(line) + if 'audio_filename' in item: + item['audio_file'] = item.pop('audio_filename') + elif 'audio_filepath' in item: + item['audio_file'] = item.pop('audio_filepath') + else: + raise ValueError( + f"Manifest file has invalid json line " f"structure: {line} without proper audio file key." + ) + if isinstance(item['audio_file'], list): + item['audio_file'] = [os.path.expanduser(audio_file_path) for audio_file_path in item['audio_file']] + else: + item['audio_file'] = os.path.expanduser(item['audio_file']) + + if not isinstance(item['audio_file'], list): + if 'uniq_id' not in item: + item['uniq_id'] = os.path.splitext(os.path.basename(item['audio_file']))[0] + elif 'uniq_id' not in item: + raise ValueError(f"Manifest file has invalid json line " f"structure: {line} without proper uniq_id key.") + + if 'duration' not in item: + raise ValueError(f"Manifest file has invalid json line " f"structure: {line} without proper duration key.") + item = dict( + audio_file=item['audio_file'], + uniq_id=item['uniq_id'], + duration=item['duration'], + rttm_file=item['rttm_filepath'], + offset=item.get('offset', None), + ) + return item + + class Audio(_Collection): """Prepare a list of all audio items, filtered by duration.""" @@ -1641,7 +1837,8 @@ def __init__( manifests_files: Either single string file or list of such - manifests to yield items from. max_number: Maximum number of samples to collect; pass to `FeatureSequenceLabel` constructor. - index_by_file_id: If True, saves a mapping from filename base (ID) to index in data; pass to `FeatureSequenceLabel` constructor. + index_by_file_id: If True, saves a mapping from filename base (ID) to index in data; + pass to `FeatureSequenceLabel` constructor. 
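# Editor's sketch (not part of this diff): a manifest line accepted by the new
# EndtoEndDiarizationSpeechLabel parser above; paths are placeholders, and `uniq_id` is optional,
# defaulting to the audio file's basename for single-file entries.
import json

entry = {
    "audio_filepath": "/path/to/audio01.wav",
    "duration": 90.0,
    "rttm_filepath": "/path/to/audio01.rttm",
    "offset": 0.0,
}
print(json.dumps(entry))  # one JSON line of the diarization manifest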
""" feature_files, labels, durations = [], [], [] diff --git a/nemo/collections/llm/__init__.py b/nemo/collections/llm/__init__.py index c36da39b43c70..f17128cdb36d3 100644 --- a/nemo/collections/llm/__init__.py +++ b/nemo/collections/llm/__init__.py @@ -22,7 +22,7 @@ AlpacaDataModule, DollyDataModule, FineTuningDataModule, - HfDatasetDataModule, + HFDatasetDataModule, MockDataModule, PreTrainingDataModule, SquadDataModule, @@ -64,7 +64,7 @@ GPTConfig126M, GPTConfig175B, GPTModel, - HfAutoModelForCausalLM, + HFAutoModelForCausalLM, Llama2Config7B, Llama2Config13B, Llama2Config70B, @@ -218,7 +218,7 @@ "dolly", "peft", "hf_dataset", - "HfAutoModelForCausalLM", + "HFAutoModelForCausalLM", ] diff --git a/nemo/collections/llm/api.py b/nemo/collections/llm/api.py index 4bafdd97ba217..adf98747059cf 100644 --- a/nemo/collections/llm/api.py +++ b/nemo/collections/llm/api.py @@ -265,12 +265,11 @@ def validate( @run.cli.entrypoint(name="ptq", namespace="llm") def ptq( nemo_checkpoint: str, + export_config: ExportConfig, calib_tp: int = 1, calib_pp: int = 1, quantization_config: Annotated[Optional[QuantizationConfig], run.Config[QuantizationConfig]] = None, - export_config: Optional[Union[ExportConfig, run.Config[ExportConfig]]] = None, ) -> Path: - # TODO: Fix "nemo_run.cli.cli_parser.CLIException: An unexpected error occurred (Argument: , Context: {})" """ Applies Post-Training Quantization (PTQ) for a model using the specified quantization and export configs. It runs calibration for a small dataset to collect scaling factors low-precision GEMMs used by desired quantization method. @@ -297,6 +296,9 @@ def ptq( Returns: Path: The path where the quantized checkpoint has been saved after calibration. """ + if not quantization_config: + quantization_config = QuantizationConfig() + if export_config.path is None: raise ValueError("The export_config.path needs to be specified, got None.") diff --git a/nemo/collections/llm/gpt/data/__init__.py b/nemo/collections/llm/gpt/data/__init__.py index b42c350bcaba5..c8690fd0668fc 100644 --- a/nemo/collections/llm/gpt/data/__init__.py +++ b/nemo/collections/llm/gpt/data/__init__.py @@ -15,7 +15,7 @@ from nemo.collections.llm.gpt.data.alpaca import AlpacaDataModule from nemo.collections.llm.gpt.data.dolly import DollyDataModule from nemo.collections.llm.gpt.data.fine_tuning import FineTuningDataModule -from nemo.collections.llm.gpt.data.hf_dataset import HfDatasetDataModule +from nemo.collections.llm.gpt.data.hf_dataset import HFDatasetDataModule from nemo.collections.llm.gpt.data.mock import MockDataModule from nemo.collections.llm.gpt.data.pre_training import PreTrainingDataModule, build_pretraining_datamodule from nemo.collections.llm.gpt.data.squad import SquadDataModule @@ -28,5 +28,5 @@ "MockDataModule", "PreTrainingDataModule", "build_pretraining_datamodule", - "HfDatasetDataModule", + "HFDatasetDataModule", ] diff --git a/nemo/collections/llm/gpt/data/api.py b/nemo/collections/llm/gpt/data/api.py index 2ebb30e781d19..374bee83b8b2a 100644 --- a/nemo/collections/llm/gpt/data/api.py +++ b/nemo/collections/llm/gpt/data/api.py @@ -16,7 +16,7 @@ import nemo_run as run from nemo.collections.llm.gpt.data.dolly import DollyDataModule -from nemo.collections.llm.gpt.data.hf_dataset import HfDatasetDataModule +from nemo.collections.llm.gpt.data.hf_dataset import HFDatasetDataModule from nemo.collections.llm.gpt.data.mock import MockDataModule from nemo.collections.llm.gpt.data.squad import SquadDataModule @@ -42,7 +42,7 @@ def dolly() -> pl.LightningDataModule: 
@run.cli.factory @run.autoconvert def hf_dataset(dataset: str) -> pl.LightningDataModule: - return HfDatasetDataModule(dataset=dataset, global_batch_size=16, micro_batch_size=2) + return HFDatasetDataModule(dataset=dataset, global_batch_size=16, micro_batch_size=2) __all__ = ["mock", "squad", "dolly", "hf_dataset"] diff --git a/nemo/collections/llm/gpt/data/hf_dataset.py b/nemo/collections/llm/gpt/data/hf_dataset.py index 46562b6e72c8b..0f45ecf265b79 100644 --- a/nemo/collections/llm/gpt/data/hf_dataset.py +++ b/nemo/collections/llm/gpt/data/hf_dataset.py @@ -18,7 +18,7 @@ from nemo.lightning.pytorch.plugins import MegatronDataSampler -class HfDatasetDataModule(pl.LightningDataModule): +class HFDatasetDataModule(pl.LightningDataModule): def __init__( self, dataset, @@ -88,7 +88,7 @@ def train_dataloader(self, collate_fn=None): from nemo.lightning.data import add_megatron_sampler if collate_fn is None: - collate_fn = lambda x: HfDatasetDataModule.collate_fn(x, pad_token_id=self.pad_token_id) + collate_fn = lambda x: HFDatasetDataModule.collate_fn(x, pad_token_id=self.pad_token_id) return DataLoader( self.dataset, diff --git a/nemo/collections/llm/gpt/model/__init__.py b/nemo/collections/llm/gpt/model/__init__.py index 9f186ebba90fb..4e9448eaef2c0 100644 --- a/nemo/collections/llm/gpt/model/__init__.py +++ b/nemo/collections/llm/gpt/model/__init__.py @@ -45,7 +45,7 @@ Gemma2Config27B, Gemma2Model, ) -from nemo.collections.llm.gpt.model.hf_auto_model_for_causal_lm import HfAutoModelForCausalLM +from nemo.collections.llm.gpt.model.hf_auto_model_for_causal_lm import HFAutoModelForCausalLM from nemo.collections.llm.gpt.model.llama import ( CodeLlamaConfig7B, CodeLlamaConfig13B, @@ -191,5 +191,5 @@ "transformer_engine_layer_spec", "transformer_engine_full_layer_spec", "local_layer_spec", - "HfAutoModelForCausalLM", + "HFAutoModelForCausalLM", ] diff --git a/nemo/collections/llm/gpt/model/hf_auto_model_for_causal_lm.py b/nemo/collections/llm/gpt/model/hf_auto_model_for_causal_lm.py index 26e4604adc437..481dd9a0e187f 100644 --- a/nemo/collections/llm/gpt/model/hf_auto_model_for_causal_lm.py +++ b/nemo/collections/llm/gpt/model/hf_auto_model_for_causal_lm.py @@ -31,7 +31,7 @@ def masked_cross_entropy(logits, targets, mask=None): return F.cross_entropy(logits, targets) -class HfAutoModelForCausalLM(pl.LightningModule, io.IOMixin, fn.FNMixin): +class HFAutoModelForCausalLM(pl.LightningModule, io.IOMixin, fn.FNMixin): def __init__( self, model_name='gpt2', @@ -41,6 +41,7 @@ def __init__( model_transform=None, model_accelerator=None, trust_remote_code=False, + default_dtype=torch.bfloat16, ): super().__init__() self.save_hyperparameters() @@ -53,11 +54,12 @@ def __init__( self.model_transform = model_transform self.model_accelerator = model_accelerator self.trust_remote_code = trust_remote_code + self.default_dtype = default_dtype @property def tokenizer(self): if self._tokenizer is None: - self._tokenizer = HfAutoModelForCausalLM.configure_tokenizer(self.model_name, self.trust_remote_code) + self._tokenizer = HFAutoModelForCausalLM.configure_tokenizer(self.model_name, self.trust_remote_code) return self._tokenizer @tokenizer.setter @@ -79,7 +81,10 @@ def configure_model(self): from transformers import AutoConfig config = AutoConfig.from_pretrained(self.model_name, trust_remote_code=self.trust_remote_code) - self.model = AutoModelForCausalLM.from_config(config, trust_remote_code=self.trust_remote_code) + dtype = getattr(config, 'torch_dtype', self.default_dtype) + self.model = 
AutoModelForCausalLM.from_config( + config, torch_dtype=dtype, trust_remote_code=self.trust_remote_code + ) if self.model_accelerator is not None: self.model_accelerator(self.model) diff --git a/nemo/collections/llm/recipes/hf_auto_model_for_causal_lm.py b/nemo/collections/llm/recipes/hf_auto_model_for_causal_lm.py index d93b167b45b67..5d2bea23686ce 100644 --- a/nemo/collections/llm/recipes/hf_auto_model_for_causal_lm.py +++ b/nemo/collections/llm/recipes/hf_auto_model_for_causal_lm.py @@ -23,7 +23,7 @@ from nemo import lightning as nl from nemo.collections.llm.api import finetune, pretrain from nemo.collections.llm.gpt.data.mock import MockDataModule -from nemo.collections.llm.gpt.model.hf_auto_model_for_causal_lm import HfAutoModelForCausalLM +from nemo.collections.llm.gpt.model.hf_auto_model_for_causal_lm import HFAutoModelForCausalLM from nemo.collections.llm.peft.lora import LoRA from nemo.collections.llm.recipes.log.default import default_log, default_resume, tensorboard_logger from nemo.collections.llm.recipes.optim.adam import pytorch_adam_with_cosine_annealing @@ -35,23 +35,23 @@ @run.cli.factory(name=NAME) def model(model_name, load_pretrained_weights) -> run.Config[pl.LightningModule]: """ - Factory function to create HfAutoModelForCausalLM model configurations. + Factory function to create HFAutoModelForCausalLM model configurations. Args: model_name (str): Model id on HF. Returns: - run.Config[pl.LightningModule]: Configuration for the HfAutoModelForCausalLM. + run.Config[pl.LightningModule]: Configuration for the HFAutoModelForCausalLM. Examples: CLI usage: - $ nemo llm pretrain --factory 'HfAutoModelForCausalLM(model_name="mistralai/Mistral-Nemo-Instruct-2407")' + $ nemo llm pretrain --factory 'HFAutoModelForCausalLM(model_name="mistralai/Mistral-Nemo-Instruct-2407")' Python API usage: >>> model_config = model(model_name="mistralai/Mistral-Nemo-Instruct-2407") >>> print(model_config) """ - return run.Config(HfAutoModelForCausalLM, model_name=model_name, load_pretrained_weights=load_pretrained_weights) + return run.Config(HFAutoModelForCausalLM, model_name=model_name, load_pretrained_weights=load_pretrained_weights) def trainer( @@ -69,7 +69,7 @@ def trainer( gradient_clip_val: float = 1.0, ) -> run.Config[nl.Trainer]: """ - Configure the NeMo Lightning Trainer for HfAutoModelForCausalLM. + Configure the NeMo Lightning Trainer for HFAutoModelForCausalLM. This function sets up the distributed training strategy and other training parameters. @@ -91,7 +91,7 @@ def trainer( Examples: CLI usage: - $ nemo llm pretrain trainer=HfAutoModelForCausalLM ... + $ nemo llm pretrain trainer=HFAutoModelForCausalLM ... Python API usage: >>> trainer_config = trainer(num_nodes=2, num_gpus_per_node=8) @@ -131,7 +131,7 @@ def pretrain_recipe( model_name: str = '', ) -> run.Partial: """ - Create a pre-training recipe for a HfAutoModelForCausalLM model. + Create a pre-training recipe for a HFAutoModelForCausalLM model. This function sets up a complete configuration for pre-training, including model, trainer, data, logging, optimization, and resumption settings. 
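# Editor's sketch (not part of this diff): with the new `default_dtype` argument, configure_model()
# materializes weights in config.torch_dtype when the HF config defines it, otherwise in the
# fallback dtype given here; "gpt2" is simply the class default reused as an example.
import torch
from nemo.collections.llm import HFAutoModelForCausalLM

model = HFAutoModelForCausalLM(model_name="gpt2", default_dtype=torch.bfloat16)
# Lightning calls model.configure_model() during setup, which now forwards torch_dtype.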
@@ -148,7 +148,7 @@ def pretrain_recipe( Examples: CLI usage: - $ nemo llm pretrain --factory 'HfAutoModelForCausalLM(model_name="mistralai/Mistral-Nemo-Instruct-2407")' + $ nemo llm pretrain --factory 'HFAutoModelForCausalLM(model_name="mistralai/Mistral-Nemo-Instruct-2407")' Python API usage: >>> recipe = pretrain_recipe(name="auto_pretrain", num_nodes=2, model_name="mistralai/Mistral-Nemo-Instruct-2407") @@ -179,7 +179,7 @@ def finetune_recipe( model_name: str = '', ) -> run.Partial: """ - Create a fine-tuning recipe for a HfAutoModelForCausalLM model. + Create a fine-tuning recipe for a HFAutoModelForCausalLM model. This function sets up a complete configuration for fine-tuning, including model, trainer, data, logging, optimization, and resumption settings. diff --git a/nemo/collections/llm/recipes/t5_11b.py b/nemo/collections/llm/recipes/t5_11b.py index ee7323aa044fb..c54bf48b96134 100644 --- a/nemo/collections/llm/recipes/t5_11b.py +++ b/nemo/collections/llm/recipes/t5_11b.py @@ -175,7 +175,8 @@ def pretrain_recipe( guide in the `examples/llm/pretrain/` directory. """ - opt_config = OptimizerConfig( + opt_config = run.Config( + OptimizerConfig, optimizer='adam', lr=0.0001, use_distributed_optimizer=True, @@ -183,7 +184,8 @@ def pretrain_recipe( weight_decay=0.01, ) - lr_scheduler = WarmupAnnealingScheduler( + lr_scheduler = run.Config( + WarmupAnnealingScheduler, warmup_steps=None, warmup_ratio=0.01, max_steps=1000000, @@ -202,7 +204,7 @@ def pretrain_recipe( MockDataModule, seq_length=512, seq_length_dec=128, global_batch_size=1920, micro_batch_size=24 ), log=default_log(dir=dir, name=name, tensorboard_logger=tensorboard_logger(name=name)), - optim=MegatronOptimizerModule(config=opt_config, lr_scheduler=lr_scheduler), + optim=run.Config(MegatronOptimizerModule, config=opt_config, lr_scheduler=lr_scheduler), resume=default_resume(), ) @@ -248,15 +250,17 @@ def finetune_recipe( on fine-tuning LLMs with NeMo, see the fine-tuning guide in the `examples/llm/finetune/` directory. """ - opt_config = OptimizerConfig( + opt_config = run.Config( + OptimizerConfig, optimizer='adam', - lr=1e-4, + lr=0.0001, use_distributed_optimizer=True, bf16=True, weight_decay=0.01, ) - lr_scheduler = WarmupAnnealingScheduler( + lr_scheduler = run.Config( + WarmupAnnealingScheduler, warmup_steps=50, max_steps=2000, min_lr=0.00001, @@ -273,7 +277,7 @@ def finetune_recipe( SquadDataModule, seq_length=512, seq_length_dec=128, global_batch_size=128, micro_batch_size=1 ), log=default_log(dir=dir, name=name, tensorboard_logger=tensorboard_logger(name=name)), - optim=MegatronOptimizerModule(config=opt_config, lr_scheduler=lr_scheduler), + optim=run.Config(MegatronOptimizerModule, config=opt_config, lr_scheduler=lr_scheduler), resume=nemo_resume(checkpoint_path), ) diff --git a/nemo/collections/llm/recipes/t5_3b.py b/nemo/collections/llm/recipes/t5_3b.py index 82772e1b865ab..b1783594d2f7b 100644 --- a/nemo/collections/llm/recipes/t5_3b.py +++ b/nemo/collections/llm/recipes/t5_3b.py @@ -175,7 +175,8 @@ def pretrain_recipe( guide in the `examples/llm/pretrain/` directory. 
""" - opt_config = OptimizerConfig( + opt_config = run.Config( + OptimizerConfig, optimizer='adam', lr=0.0001, use_distributed_optimizer=True, @@ -183,7 +184,8 @@ def pretrain_recipe( weight_decay=0.01, ) - lr_scheduler = WarmupAnnealingScheduler( + lr_scheduler = run.Config( + WarmupAnnealingScheduler, warmup_steps=None, warmup_ratio=0.01, max_steps=1000000, @@ -202,7 +204,7 @@ def pretrain_recipe( MockDataModule, seq_length=512, seq_length_dec=128, global_batch_size=1920, micro_batch_size=24 ), log=default_log(dir=dir, name=name, tensorboard_logger=tensorboard_logger(name=name)), - optim=MegatronOptimizerModule(config=opt_config, lr_scheduler=lr_scheduler), + optim=run.Config(MegatronOptimizerModule, config=opt_config, lr_scheduler=lr_scheduler), resume=default_resume(), ) @@ -248,15 +250,17 @@ def finetune_recipe( on fine-tuning LLMs with NeMo, see the fine-tuning guide in the `examples/llm/finetune/` directory. """ - opt_config = OptimizerConfig( + opt_config = run.Config( + OptimizerConfig, optimizer='adam', - lr=1e-4, + lr=0.0001, use_distributed_optimizer=True, bf16=True, weight_decay=0.01, ) - lr_scheduler = WarmupAnnealingScheduler( + lr_scheduler = run.Config( + WarmupAnnealingScheduler, warmup_steps=50, max_steps=2000, min_lr=0.00001, @@ -273,7 +277,7 @@ def finetune_recipe( SquadDataModule, seq_length=512, seq_length_dec=128, global_batch_size=128, micro_batch_size=1 ), log=default_log(dir=dir, name=name, tensorboard_logger=tensorboard_logger(name=name)), - optim=MegatronOptimizerModule(config=opt_config, lr_scheduler=lr_scheduler), + optim=run.Config(MegatronOptimizerModule, config=opt_config, lr_scheduler=lr_scheduler), resume=nemo_resume(checkpoint_path), ) diff --git a/nemo/collections/multimodal/data/energon/task_encoder.py b/nemo/collections/multimodal/data/energon/task_encoder.py index 23758b3a43dbf..7a8d0f0ab0336 100644 --- a/nemo/collections/multimodal/data/energon/task_encoder.py +++ b/nemo/collections/multimodal/data/energon/task_encoder.py @@ -48,7 +48,8 @@ class MultiModalTaskEncoder( and similarity interleaved samples. This class extends the DefaultTaskEncoder and provides a flexible mechanism to handle and encode - different types of multimodal data. Support for VQA, captioning and interleaved samples is provided by default. It supports registering custom encoders for each sample type + different types of multimodal data. Support for VQA, captioning and interleaved samples is provided by default. + It supports registering custom encoders for each sample type and provides methods for encoding individual samples, batching them, and further processing the batch for model input. """ @@ -59,8 +60,8 @@ def __init__(self, tokenizer, image_processor, multimodal_sample_config): Parameters: tokenizer (Tokenizer): The tokenizer used for processing text across different sample types. - image_processor (ImageProcessor): The image processor used for preprocessing images across different sample types. - multimodal_sample_config (MultiModalSampleConfig): Configuration object for multimodal samples, including tokens and placeholders. + image_processor (ImageProcessor): The image processor used for preprocessing images. + multimodal_sample_config (MultiModalSampleConfig): MultiModalSampleConfig object. 
""" self.tokenizer = tokenizer self.encoders: Dict[str, SampleEncoder] = { @@ -173,5 +174,6 @@ def encode_batch(self, batch_data: ImageTextRawBatch) -> dict: position_ids = torch.arange(seq_length, dtype=torch.long) position_ids = position_ids.unsqueeze(0).repeat(micro_batch_size, 1) batch_dict['position_ids'] = position_ids - batch_dict['attention_mask'] = None + if 'attention_mask' not in batch_dict: + batch_dict['attention_mask'] = None return batch_dict diff --git a/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py b/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py index 898ddb7d716be..9da2419520c21 100644 --- a/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py +++ b/nemo/collections/nlp/data/language_modeling/megatron/gpt_sft_dataset.py @@ -573,15 +573,23 @@ def _build_samples_mapping(self): self.samples_mapping = None def _build_loss_mask(self, processed_example): + seq_boundaries = processed_example['seq_boundaries'] if self.answer_only_loss: - seq_boundaries = processed_example['seq_boundaries'] return np.concatenate( [ processed_example['loss_mask'][seq_boundaries[i] + 1 : seq_boundaries[i + 1]] for i in range(len(seq_boundaries) - 1) ] ) - return [1.0] * (len(processed_example['input_ids']) - len(processed_example['seq_boundaries']) + 1) + return np.concatenate( + [ + [ + 0 if x == self.tokenizer.eos_id else 1.0 + for x in processed_example['input_ids'][seq_boundaries[i] : seq_boundaries[i + 1] - 1] + ] + for i in range(len(seq_boundaries) - 1) + ] + ) def _maybe_cast_to_list(self, x): return [item.tolist() if isinstance(item, np.ndarray) else item for item in x] @@ -622,16 +630,40 @@ def collate_fn(self, batch): position_ids: List[List[int]] = [] cu_seqlens: List[List[int]] = [] + cu_seqlens_unpadded: List[List[int]] = [] for item in batch: position_ids.append([]) cu_seqlens.append([0]) + cu_seqlens_unpadded.append([0]) seqlens = np.array(item['seq_boundaries'][1:]) - np.array(item['seq_boundaries'][:-1]) for l in seqlens: # length minus 1 because input_ids is truncated by 1 for labels position_ids[-1].extend(list(range(l - 1))) cu_seqlens[-1].append(cu_seqlens[-1][-1] + l - 1) - # set last seq to the max seq len because rope and attn kernels expect no padding - cu_seqlens[-1][-1] = max_length + + # the last seq needs to be the max seq len because rope and attn kernels expect no padding + assert cu_seqlens[-1][-1] <= max_length + + # since data is prepadded when cp_size > 1, there may be some extra padding at the end + # of the packed sequence. In this case, we need to add the max seq len to the end. + if cu_seqlens[-1][-1] != max_length: + cu_seqlens[-1].append(max_length) + + for i in range(len(item['seq_boundaries']) - 1): + current_seq = item['input_ids'][item['seq_boundaries'][i] : item['seq_boundaries'][i + 1] - 1] + + # since the data could be prepadded with tokenizer's eos_id, we can find out the index of all the eos_id + eos_idx = np.where(np.array(current_seq) == self.tokenizer.eos_id) + + # The second eos_id index marks the length of the original unpadded sequence if the sequence is + # prepadded for cp_size > 1. Otherwise, there is no extra padding. 
+ seqlen_unpadded = eos_idx[0][0] + 1 if eos_idx[0].any() else len(current_seq) + cu_seqlens_unpadded[-1].append(cu_seqlens_unpadded[-1][-1] + seqlen_unpadded) + + # if extra paddings are added in the packed sequence, they can't be counted as + # actual tokens for training + if len(cu_seqlens[-1]) > len(cu_seqlens_unpadded[-1]): + cu_seqlens_unpadded[-1].append(cu_seqlens_unpadded[-1][-1]) assert len(input_ids[0]) == len( position_ids[0] @@ -652,12 +684,16 @@ def collate_fn(self, batch): if self.return_cu_seqlen: cu_seqlens = self._collate_item(cu_seqlens, max_length=max(len(l) for l in cu_seqlens) + 1, pad_id=-1) - + cu_seqlens_unpadded = self._collate_item( + cu_seqlens_unpadded, max_length=max(len(l) for l in cu_seqlens_unpadded) + 1, pad_id=-1 + ) # Pre-generate `cu_seqlens_argmin` and `max_seqlen` as CPU tensor to avoid device-to-host copies. cu_seqlens = torch.IntTensor(cu_seqlens) cu_seqlens_argmin = torch.argmin(cu_seqlens, dim=1, keepdim=True) seqlens = cu_seqlens[:, 1:] - cu_seqlens[:, :-1] max_seqlen, _ = seqlens.max(dim=1, keepdim=True) + cu_seqlens_unpadded = torch.IntTensor(cu_seqlens_unpadded) + cu_seqlens_unpadded_argmin = torch.argmin(cu_seqlens_unpadded, dim=1, keepdim=True) processed_batch.update( { @@ -667,6 +703,8 @@ def collate_fn(self, batch): 'cu_seqlens': torch.IntTensor(cu_seqlens), # cu_seqlens_q must be in dtype torch.int32 'cu_seqlens_argmin': cu_seqlens_argmin, # only required for perf 'max_seqlen': max_seqlen, # only required for perf + 'cu_seqlens_unpadded': torch.IntTensor(cu_seqlens_unpadded), + 'cu_seqlens_unpadded_argmin': cu_seqlens_unpadded_argmin, } ) else: diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index f71b1ad13c6db..a4b8242e01858 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -1231,22 +1231,23 @@ def get_batch_on_this_context_parallel_rank(self, batch): cp_size = parallel_state.get_context_parallel_world_size() if cp_size > 1: cp_rank = parallel_state.get_context_parallel_rank() - for key, val in batch.items(): - if val is not None and key != "context_lengths": - seq_dim = 1 if key != 'attention_mask' else 2 - val = val.view( - *val.shape[0:seq_dim], - 2 * cp_size, - val.shape[seq_dim] // (2 * cp_size), - *val.shape[(seq_dim + 1) :], - ) - index = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], device="cpu", pin_memory=True).cuda( - non_blocking=True - ) - val = val.index_select(seq_dim, index) - val = val.view(*val.shape[0:seq_dim], -1, *val.shape[(seq_dim + 2) :]) - batch[key] = val - + # check if the batch is not in THD format + if 'cu_seqlens' not in batch: + for key, val in batch.items(): + if val is not None and key != "context_lengths": + seq_dim = 1 if key != 'attention_mask' else 2 + val = val.view( + *val.shape[0:seq_dim], + 2 * cp_size, + val.shape[seq_dim] // (2 * cp_size), + *val.shape[(seq_dim + 1) :], + ) + index = torch.tensor( + [cp_rank, (2 * cp_size - cp_rank - 1)], device="cpu", pin_memory=True + ).cuda(non_blocking=True) + val = val.index_select(seq_dim, index) + val = val.view(*val.shape[0:seq_dim], -1, *val.shape[(seq_dim + 2) :]) + batch[key] = val batch['num_valid_tokens_in_ub'] = num_valid_tokens_in_ub return batch @@ -1261,12 +1262,17 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_ required_keys = set() max_seqlen = batch['max_seqlen'].squeeze() if 'max_seqlen' in 
batch else None cu_seqlens_argmin = batch['cu_seqlens_argmin'] if 'cu_seqlens_argmin' in batch else None + cu_seqlens_unpadded_argmin = ( + batch['cu_seqlens_unpadded_argmin'] if 'cu_seqlens_unpadded_argmin' in batch else None + ) if parallel_state.get_pipeline_model_parallel_world_size() == 1: required_keys.update(batch.keys()) else: required_keys.add('attention_mask') if 'cu_seqlens' in batch: required_keys.add('cu_seqlens') + if 'cu_seqlens_unpadded' in batch: + required_keys.add('cu_seqlens_unpadded') if parallel_state.is_pipeline_first_stage(): required_keys.update(('tokens', 'position_ids')) if parallel_state.is_pipeline_last_stage(): @@ -1301,12 +1307,16 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_ if 'cu_seqlens' in batch: # packed sequence from GPTSFTPackedDataset # these args are passed eventually into TEDotProductAttention.forward() cu_seqlens = batch['cu_seqlens'].squeeze() # remove batch size dimension (mbs=1) + cu_seqlens_unpadded = batch['cu_seqlens_unpadded'].squeeze() # remove -1 "paddings" added in collate_fn if cu_seqlens_argmin is not None: cu_seqlens = cu_seqlens[: cu_seqlens_argmin.item()] else: cu_seqlens = cu_seqlens[: torch.argmin(cu_seqlens)] - + if cu_seqlens_unpadded_argmin is not None: + cu_seqlens_unpadded = cu_seqlens_unpadded[: cu_seqlens_unpadded_argmin.item()] + else: + cu_seqlens_unpadded = cu_seqlens_unpadded[: torch.argmin(cu_seqlens_unpadded)] try: from megatron.core.packed_seq_params import PackedSeqParams except (ImportError, ModuleNotFoundError) as e: @@ -1317,9 +1327,42 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_ ) raise e + # get packed sequences for this context parallel rank + cp_size = parallel_state.get_context_parallel_world_size() + if cp_size > 1: + try: + import transformer_engine_torch as tex + except ModuleNotFoundError as e: + logging.error( + "Please update Transformer Engine to >= 1.10 to use Context Parallel with THD format data" + ) + raise e + cp_rank = parallel_state.get_context_parallel_rank() + for key in required_keys: + val = batch[key] + if key not in { + "cu_seqlens", + "cu_seqlens_unpadded", + "cu_seqlens_argmin", + "cu_seqlens_unpadded_argmin", + "max_seqlen", + "token_count", + }: + index = tex.thd_get_partitioned_indices(cu_seqlens, val.size(1), cp_size, cp_rank) + val = val.index_select(1, index) + batch[key] = val + forward_args = { + 'input_ids': batch['tokens'], + 'position_ids': batch['position_ids'], + 'attention_mask': None if self.get_attention_mask_from_fusion else batch['attention_mask'], + 'labels': batch['labels'] if 'labels' in batch else None, + } + forward_args['packed_seq_params'] = PackedSeqParams( - cu_seqlens_q=cu_seqlens, - cu_seqlens_kv=cu_seqlens, + cu_seqlens_q=cu_seqlens_unpadded, + cu_seqlens_kv=cu_seqlens_unpadded, + cu_seqlens_q_padded=cu_seqlens, + cu_seqlens_kv_padded=cu_seqlens, max_seqlen_q=max_seqlen, max_seqlen_kv=max_seqlen, qkv_format='thd', diff --git a/nemo/collections/vlm/__init__.py b/nemo/collections/vlm/__init__.py index 266790f3af71f..b5e693830fa54 100644 --- a/nemo/collections/vlm/__init__.py +++ b/nemo/collections/vlm/__init__.py @@ -12,6 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
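Editor's note on the packed-sequence changes above: the collate and forward paths now track two sets of cumulative offsets. `cu_seqlens` ends at the padded `max_length`, which is what the RoPE and attention kernels expect, while `cu_seqlens_unpadded` stops at the last real token so the extra padding is neither attended to nor counted as trainable tokens. A self-contained numeric sketch of that bookkeeping (toy lengths; the label shift and the eos-based length detection of the real `collate_fn` are omitted):

```python
import torch

# Toy packed sample: three sequences of lengths 5, 3 and 2 packed into a
# bucket of max_length 16, leaving 6 trailing pad tokens.
seqlens = [5, 3, 2]
max_length = 16

cu_seqlens = [0]
cu_seqlens_unpadded = [0]
for length in seqlens:
    cu_seqlens.append(cu_seqlens[-1] + length)
    cu_seqlens_unpadded.append(cu_seqlens_unpadded[-1] + length)

# The padded offsets must end exactly at max_length ...
if cu_seqlens[-1] != max_length:
    cu_seqlens.append(max_length)
    # ... while the unpadded offsets repeat their last value, so the extra
    # padding region contributes zero tokens to attention and to the loss.
    cu_seqlens_unpadded.append(cu_seqlens_unpadded[-1])

print(torch.IntTensor(cu_seqlens))           # tensor([ 0,  5,  8, 10, 16], dtype=torch.int32)
print(torch.IntTensor(cu_seqlens_unpadded))  # tensor([ 0,  5,  8, 10, 10], dtype=torch.int32)
# In the model these map onto cu_seqlens_q_padded / cu_seqlens_kv_padded and
# cu_seqlens_q / cu_seqlens_kv of PackedSeqParams, respectively.
```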
+from nemo.collections.vlm.llava_next.data import LlavaNextMockDataModule, LlavaNextTaskEncoder +from nemo.collections.vlm.llava_next.model.base import LlavaNextConfig +from nemo.collections.vlm.llava_next.model.llava_next import LlavaNextConfig7B, LlavaNextConfig13B, LlavaNextModel from nemo.collections.vlm.mllama.data import MLlamaLazyDataModule, MLlamaMockDataModule from nemo.collections.vlm.mllama.model.base import ( CrossAttentionTextConfig, @@ -29,7 +32,6 @@ DataConfig, ImageDataConfig, ImageToken, - LlavaNextTaskEncoder, MultiModalToken, NevaLazyDataModule, NevaMockDataModule, @@ -81,4 +83,10 @@ "MLlamaConfig90BInstruct", "mllama_11b", "mllama_90b", + "llava_next_7b", + "LlavaNextConfig7B", + "LlavaNextConfig13B", + "LlavaNextModel", + "LlavaNextMockDataModule", + "LlavaNextTaskEncoder", ] diff --git a/nemo/collections/vlm/llava_next/__init__.py b/nemo/collections/vlm/llava_next/__init__.py new file mode 100644 index 0000000000000..d9155f923f186 --- /dev/null +++ b/nemo/collections/vlm/llava_next/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/collections/vlm/llava_next/data/__init__.py b/nemo/collections/vlm/llava_next/data/__init__.py new file mode 100644 index 0000000000000..1c71e5355f4be --- /dev/null +++ b/nemo/collections/vlm/llava_next/data/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from nemo.collections.vlm.llava_next.data.energon import LlavaNextTaskEncoder +from nemo.collections.vlm.llava_next.data.mock import MockDataModule as LlavaNextMockDataModule + +__all__ = [ + "LlavaNextMockDataModule", + "LlavaNextTaskEncoder", +] diff --git a/nemo/collections/vlm/neva/data/llava_next_energon.py b/nemo/collections/vlm/llava_next/data/energon.py similarity index 71% rename from nemo/collections/vlm/neva/data/llava_next_energon.py rename to nemo/collections/vlm/llava_next/data/energon.py index c45ee50e5be34..effa3236ade71 100644 --- a/nemo/collections/vlm/neva/data/llava_next_energon.py +++ b/nemo/collections/vlm/llava_next/data/energon.py @@ -13,7 +13,7 @@ # limitations under the License. 
from dataclasses import dataclass, field -from typing import Dict, List +from typing import Dict, List, Optional import torch from megatron.energon import VQASample, batch_list, batch_pad_stack @@ -25,16 +25,48 @@ from nemo.utils import logging +@dataclass class LlavaNextTextSample(ImageTextSample): + ''' + Sample type for LLaVA-Next, extending ImageTextSample to support tiled image data. + + This class adds additional attributes for handling high-resolution images processed as tiles, + along with metadata about the tiled images. + + Attributes: + num_media_tiles (int): The number of tiles used to represent the high-resolution image. + image_sizes (torch.Tensor): A tensor representing the sizes of the tiled images. + attention_mask (Optional[torch.Tensor]): An optional attention mask for the sample, + used to determine which tokens or tiles are attended to during processing. Defaults to None. + ''' + num_media_tiles: int = 0 + image_sizes: torch.tensor = field(default_factory=lambda: torch.tensor([])) + attention_mask: Optional[torch.tensor] = None @dataclass class LlavaNextTextRawBatch(ImageTextRawBatch): + """ + Batch type for raw LLaVA-Next samples, supporting tiled image data. + + This class aggregates multiple `LlavaNextTextSample` instances into a batch for processing. + It includes attributes for managing tiled images and associated metadata for each sample in the batch. + + Attributes: + num_media_tiles (List[int]): A list containing the number of tiles for each image in the batch. + image_sizes (torch.Tensor): A tensor containing the sizes of all tiled images in the batch. + attention_mask (Optional[torch.Tensor]): Attention mask. Defaults to None. + """ + num_media_tiles: List[int] = field(default_factory=list) + image_sizes: torch.tensor = field(default_factory=lambda: torch.tensor([])) + attention_mask: Optional[torch.tensor] = None class LlavaNextSampleEncoder(VQASampleEncoder): + """LlavaNextSampleEncoder""" + def __init__(self, tokenizer, image_processor, multimodal_sample_config=MultiModalSampleConfig()): """ Initialize the LlavaNextSampleEncoder, inherited from VQASampleEncoder for multimodal samples @@ -81,15 +113,14 @@ def encode(self, input_sample: VQASample, output_sample: LlavaNextTextSample): images, loss masks, and metadata. 
""" conversation_prompt = self.apply_prompt_template(input_sample) - logging.debug(f"task encoder encode_sample conversation_prompt {conversation_prompt}") + logging.debug(f"[Energon] task encoder encode_sample conversation_prompt {conversation_prompt}") # tokenize prompt tokens = self.tokenize(conversation_prompt) labels = self.compute_labels(tokens, input_sample) - tokens = tokens[:-1].contiguous() labels = labels[1:].contiguous() - logging.debug(f"task encoder encode_sample after tokenize prompt tokens {tokens}") - logging.debug(f"task encoder encode_sample lables {labels}") + logging.debug(f"[Energon] task encoder encode_sample after tokenize prompt tokens {tokens}") + logging.debug(f"[Energon] task encoder encode_sample lables {labels}") loss_mask = self.compute_loss_mask(labels) processed_image = self.process_image(input_sample.image) output_sample.__key__ = input_sample.__key__ @@ -98,10 +129,16 @@ def encode(self, input_sample: VQASample, output_sample: LlavaNextTextSample): output_sample.labels = labels output_sample.loss_mask = loss_mask output_sample.num_media_tiles = processed_image.shape[0] + output_sample.attention_mask = torch.ones(len(tokens), dtype=torch.long) + height = input_sample.image.shape[1] + width = input_sample.image.shape[2] + output_sample.image_sizes = torch.tensor([[height, width]], dtype=torch.long) return output_sample class LlavaNextTaskEncoder(MultiModalTaskEncoder): + """LlavaNextTaskEncoder""" + def __init__(self, tokenizer, image_processor, multimodal_sample_config): """ Initialize the LlavaNextTaskEncoder. @@ -133,7 +170,16 @@ def batch(self, samples: List[LlavaNextTextSample]) -> LlavaNextTextRawBatch: LlavaNextTextRawBatch: A batch containing all input samples' images, tokens, labels, loss masks, and other metadata prepared for model processing. """ - keys, images, tokens, labels, loss_mask, num_media_tiles = [], [], [], [], [], [] + keys, images, tokens, labels, loss_mask, num_media_tiles, image_sizes, attention_mask = ( + [], + [], + [], + [], + [], + [], + [], + [], + ) for sample in samples: keys.append(sample.__key__) images.append(sample.images) @@ -141,6 +187,8 @@ def batch(self, samples: List[LlavaNextTextSample]) -> LlavaNextTextRawBatch: labels.append(sample.labels) loss_mask.append(sample.loss_mask) num_media_tiles.append(sample.num_media_tiles) + image_sizes.append(sample.image_sizes) + attention_mask.append(sample.attention_mask) batch_keys = batch_list(keys) @@ -148,8 +196,9 @@ def batch(self, samples: List[LlavaNextTextSample]) -> LlavaNextTextRawBatch: batch_tokens = pad_sequence(tokens, batch_first=True) batch_labels = pad_sequence(labels, batch_first=True) - + image_sizes = torch.cat(image_sizes, dim=0) batch_loss_mask = batch_pad_stack(loss_mask) + batch_attention_mask = batch_pad_stack(attention_mask) batch_num_media_tiles = torch.tensor(batch_list(num_media_tiles), dtype=torch.int) return LlavaNextTextRawBatch( __keys__=batch_keys, @@ -158,4 +207,6 @@ def batch(self, samples: List[LlavaNextTextSample]) -> LlavaNextTextRawBatch: labels=batch_labels, loss_mask=batch_loss_mask, num_media_tiles=batch_num_media_tiles, + image_sizes=image_sizes, + attention_mask=batch_attention_mask, ) diff --git a/nemo/collections/vlm/llava_next/data/mock.py b/nemo/collections/vlm/llava_next/data/mock.py new file mode 100644 index 0000000000000..f61df7336e6fd --- /dev/null +++ b/nemo/collections/vlm/llava_next/data/mock.py @@ -0,0 +1,311 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, List, Optional + +import lightning.pytorch as pl +import numpy as np +import torch +from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS +from torch.utils import data +from torch.utils.data import DataLoader, Dataset + +from nemo.collections.vlm.neva.data.multimodal_tokens import IMAGE_TOKEN_INDEX +from nemo.lightning.pytorch.plugins import MegatronDataSampler +from nemo.utils import logging + + +class MockDataModule(pl.LightningDataModule): + """ + A mock data module for LLaVA-Next training, validation, and testing. + + Provides datasets and data loaders for training, validation, and testing phases. + Includes data sampling and preprocessing for multimodal tasks. + """ + + def __init__( + self, + seq_length: int = 2048, + decoder_seq_length: Optional[int] = None, + tokenizer=None, + image_processor=None, + micro_batch_size: int = 4, + global_batch_size: int = 8, + rampup_batch_size: Optional[List[int]] = None, + num_train_samples: int = 10_000_000, + num_val_samples: int = 10_000_000, + num_test_samples: int = 10_000_000, + num_workers: int = 8, + pin_memory: bool = True, + persistent_workers: bool = False, + ): + """ + Initializes the mock data module with data sampling and preprocessing configurations. + + Args: + seq_length (int): Maximum sequence length for tokens. + decoder_seq_length (Optional[int]): Sequence length for the decoder. + tokenizer: Tokenizer for text processing. + image_processor: Processor for image preprocessing. + micro_batch_size (int): Batch size per GPU. + global_batch_size (int): Total batch size across GPUs. + rampup_batch_size (Optional[List[int]]): Batch size ramp-up schedule. + num_train_samples (int): Number of training samples. + num_val_samples (int): Number of validation samples. + num_test_samples (int): Number of testing samples. + num_workers (int): Number of workers for data loading. + pin_memory (bool): Whether to pin memory for data loaders. + persistent_workers (bool): Whether to keep workers alive after the first iteration. + """ + super().__init__() + self.seq_length = seq_length + self.decoder_seq_len = decoder_seq_length + self.num_train_samples = num_train_samples + self.num_val_samples = num_val_samples + self.num_test_samples = num_test_samples + self.num_workers = num_workers + self.pin_memory = pin_memory + self.persistent_workers = persistent_workers + + if tokenizer is None or image_processor is None: + logging.warning( + f"Processor or tokenizer are not provided! Fall back to `llava-hf/llava-v1.6-vicuna-7b-hf`." 
+ ) + from transformers import AutoProcessor + + from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer + + model_name = "llava-hf/llava-v1.6-vicuna-7b-hf" + + processor = AutoProcessor.from_pretrained(model_name) + self.tokenizer = tokenizer or AutoTokenizer(model_name) + self.image_processor = image_processor or processor.image_processor + self.data_sampler = MegatronDataSampler( + seq_len=self.seq_length, + decoder_seq_len=self.decoder_seq_len, + micro_batch_size=micro_batch_size, + global_batch_size=global_batch_size, + rampup_batch_size=rampup_batch_size, + ) + + def setup(self, stage: str = "") -> None: + """ + Sets up the training, validation, and testing datasets. + + Args: + stage (str): Stage of the setup ('train', 'valid', 'test'). + """ + self._train_ds = _MockLlavaNextDataset( + self.tokenizer, self.image_processor, "train", self.num_train_samples, self.seq_length + ) + self._validation_ds = _MockLlavaNextDataset( + self.tokenizer, self.image_processor, "valid", self.num_val_samples, self.seq_length + ) + self._test_ds = _MockLlavaNextDataset( + self.tokenizer, self.image_processor, "test", self.num_test_samples, self.seq_length + ) + + def train_dataloader(self) -> TRAIN_DATALOADERS: + """ + Creates the training data loader. + + Returns: + TRAIN_DATALOADERS: Training data loader. + """ + if not hasattr(self, "_train_ds"): + self.setup() + return self._create_dataloader(self._train_ds) + + def val_dataloader(self) -> EVAL_DATALOADERS: + """ + Creates the validation data loader. + + Returns: + EVAL_DATALOADERS: Validation data loader. + """ + if not hasattr(self, "_validation_ds"): + self.setup() + return self._create_dataloader(self._validation_ds) + + def test_dataloader(self) -> EVAL_DATALOADERS: + """ + Creates the testing data loader. + + Returns: + TEST_DATALOADERS: Testing data loader. + """ + if not hasattr(self, "_test_ds"): + self.setup() + return self._create_dataloader(self._test_ds) + + def _create_dataloader(self, dataset, **kwargs) -> DataLoader: + """ + Creates a generic data loader for the given dataset. + + Args: + dataset: The dataset for which the data loader is created. + **kwargs: Additional arguments for the DataLoader. + + Returns: + DataLoader: The created data loader. + """ + return DataLoader( + dataset, + num_workers=self.num_workers, + pin_memory=self.pin_memory, + persistent_workers=self.persistent_workers, + collate_fn=dataset.collate_fn, + **kwargs, + ) + + +class _MockLlavaNextDataset(Dataset): + """ + A mock dataset for LLaVA-Next, generating synthetic multimodal data. + + Attributes: + tokenizer: Tokenizer for text inputs. + image_processor: Processor for image inputs. + name (str): Name of the dataset ('train', 'valid', 'test'). + num_samples (int): Number of samples in the dataset. + seq_length (int): Sequence length for text tokens. + seed (int): Random seed for reproducibility. + """ + + def __init__( + self, + tokenizer, + image_processor, + name: str, + num_samples: int, + seq_length: int, + seed: int = 42, + ) -> None: + """ + Initializes the mock dataset with synthetic multimodal data. + + Args: + tokenizer: Tokenizer for text inputs. + image_processor: Processor for image inputs. + name (str): Dataset name ('train', 'valid', 'test'). + num_samples (int): Total number of samples in the dataset. + seq_length (int): Sequence length for text tokens. + seed (int): Random seed for data generation. 
+ """ + super().__init__() + self.name = name + self.seq_length = seq_length + + self.vocab_size = tokenizer.vocab_size + + crop_size = image_processor.crop_size + self.image_height, self.image_width = crop_size["height"], crop_size["width"] + + self.length = num_samples + self.seed = seed + + self.loss_mask = torch.ones(self.seq_length, dtype=torch.float) + self.position_ids = torch.arange(self.seq_length, dtype=torch.int64) + self.tokenizer = tokenizer + self.image_processor = image_processor + + def __len__(self) -> int: + """ + Returns the length of the dataset. + + Returns: + int: Number of samples in the dataset. + """ + return self.length + + def _get_text(self, idx: int) -> np.ndarray: + """ + Generates synthetic text data. + + Args: + idx (int): Index of the sample. + + Returns: + np.ndarray: Synthetic text token IDs. + """ + np_gen = np.random.default_rng(seed=(self.seed + idx)) + return np_gen.integers(self.vocab_size, size=[self.seq_length], dtype=np.int64) + + def __getitem__(self, idx) -> Dict[str, torch.Tensor]: + """ + Generates a synthetic multimodal sample. + + Args: + idx (int): Index of the sample. + + Returns: + Dict[str, torch.Tensor]: A dictionary containing synthetic tokens, images, and metadata. + """ + # Generate data of the expected size and datatype (based on GPTDataset). + np_gen = np.random.default_rng(seed=(self.seed + idx)) + tokens = torch.from_numpy(np_gen.integers(self.vocab_size, size=[self.seq_length + 1], dtype=np.int64)) + tokens[2] = IMAGE_TOKEN_INDEX # ImageToken token index + labels = tokens.clone() + images = torch.from_numpy(np_gen.random(size=[3, self.image_height, self.image_width], dtype=np.float32)) + tokens = tokens[:-1] + labels = labels[1:] + + # attention_mask, image_sizes, num_media_tiles required for llava-next. Neva model will ignore these + attention_mask = torch.ones(len(tokens), dtype=torch.long) + image_sizes = torch.tensor([[self.image_height, self.image_width]], dtype=torch.long) + image_array = self.image_processor.preprocess(images, return_tensors='pt', do_rescale=False)['pixel_values'][0] + num_media_tiles = image_array.shape[0] + return { + "media": image_array, + "tokens": tokens, + "labels": labels, + "loss_mask": self.loss_mask, + "position_ids": self.position_ids, + "image_sizes": image_sizes, + "num_media_tiles": num_media_tiles, + "attention_mask": attention_mask, + } + + def _collate_fn(self, batch): + """ + A default implementation of a collation function. + Users should override this method to define custom data loaders. + """ + collated_batch = data.dataloader.default_collate(batch) + + collated_batch['media'] = collated_batch['media'].contiguous().view(-1, *collated_batch['media'].shape[2:]) + collated_batch['image_sizes'] = ( + collated_batch['image_sizes'].contiguous().view(-1, *collated_batch['image_sizes'].shape[2:]) + ) + return collated_batch + + def collate_fn(self, batch): + """Method that user pass as functor to DataLoader. + + The method optionally performs neural type checking and add types to the outputs. + + Please note, subclasses of Dataset should not implement `input_types`. + + # Usage: + dataloader = torch.utils.data.DataLoader( + ...., + collate_fn=dataset.collate_fn, + .... + ) + + Returns + ------- + Collated batch, with or without types. 
+ """ + return self._collate_fn(batch) diff --git a/nemo/collections/vlm/llava_next/model/__init__.py b/nemo/collections/vlm/llava_next/model/__init__.py new file mode 100644 index 0000000000000..6d7b02482f62e --- /dev/null +++ b/nemo/collections/vlm/llava_next/model/__init__.py @@ -0,0 +1,24 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from nemo.collections.vlm.llava_next.model.base import LlavaNextConfig +from nemo.collections.vlm.llava_next.model.llava_next import LlavaNextConfig7B, LlavaNextConfig13B, LlavaNextModel + +__all__ = [ + "LlavaNextConfig", + "LlavaNextModel", + "LlavaNextConfig7B", + "LlavaNextConfig13B", +] diff --git a/nemo/collections/vlm/llava_next/model/base.py b/nemo/collections/vlm/llava_next/model/base.py new file mode 100644 index 0000000000000..7968c720db0e9 --- /dev/null +++ b/nemo/collections/vlm/llava_next/model/base.py @@ -0,0 +1,373 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from dataclasses import dataclass, field +from typing import Callable, Dict, List, Optional + +import torch +import torch.distributed +from megatron.core import parallel_state as ps +from megatron.core.inference_params import InferenceParams +from megatron.core.tensor_parallel import gather_from_sequence_parallel_region + +from nemo.collections.llm.gpt.model.base import get_batch_on_this_context_parallel_rank, get_packed_seq_params +from nemo.collections.vlm.llava_next.model.utils import merge_input_ids_with_image_features, pack_image_features +from nemo.collections.vlm.neva.data.multimodal_tokens import IMAGE_TOKEN_INDEX + + +def llava_next_data_step(dataloader_iter) -> Dict[str, torch.Tensor]: + """ + Processes a batch of data from the dataloader for the LLaVA Next model. + + Args: + dataloader_iter (Iterator): An iterator that provides batches of data from the dataloader. + + Returns: + Dict[str, torch.Tensor]: A dictionary containing the processed batch, ready for input into the model. + + Notes: + - Filters and moves required keys to the appropriate device. + - Slices the batch along the sequence dimension for context parallelism. 
+ """ + from megatron.core import parallel_state + + # Based on: https://github.com/NVIDIA/Megatron-LM/blob/main/pretrain_gpt.py#L87 + # https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/nlp/models/language_modeling/ + # megatron_gpt_model.py#L828-L842 + batch = next(dataloader_iter) + _batch: dict + if isinstance(batch, tuple) and len(batch) == 3: + _batch = batch[0] + else: + _batch = batch + + required_keys = set() + required_keys.update( + ( + "tokens", + "attention_mask", + "media", + "num_media_tiles", + "image_sizes", + ) + ) + if parallel_state.is_pipeline_first_stage(): + required_keys.update(("position_ids", "attention_mask")) + if parallel_state.is_pipeline_last_stage(): + required_keys.update(("labels", "loss_mask", "attention_mask")) + + _batch = { + key: val.cuda(non_blocking=True) if key in required_keys and val is not None else None + for key, val in _batch.items() + } + # slice batch along sequence dimension for context parallelism + output = get_batch_on_this_context_parallel_rank(_batch) + + return output + + +def llava_next_forward_step(model, batch) -> torch.Tensor: + """ + Performs the forward step for the LLaVA Next model. + + Args: + model (torch.nn.Module): The LLaVA Next model instance. + batch (Dict[str, torch.Tensor]): A dictionary containing input tensors for the forward step. + + Returns: + torch.Tensor: The output from the model's forward computation. + + Notes: + - Constructs the forward arguments based on the provided batch. + - Includes optional parameters like packed sequence parameters if available. + """ + forward_args = { + "media": batch["media"], + "input_ids": batch["tokens"], + "position_ids": batch["position_ids"], + "attention_mask": batch.get("attention_mask", None), + "loss_mask": batch.get("loss_mask", None), + "labels": batch.get("labels", None), + "image_sizes": batch.get("image_sizes", None), + "num_media_tiles": batch.get("num_media_tiles", None), + } + + if 'cu_seqlens' in batch: + forward_args['packed_seq_params'] = get_packed_seq_params(batch) + return model(**forward_args) + + +from nemo.collections.vlm.neva.model.base import MCoreNevaModel, NevaConfig + + +@dataclass +class LlavaNextConfig(NevaConfig): + """ + Configuration class for the LLaVA Next model. + Overrides NevaConfig and modifies forward and data step fn. + + """ + + forward_step_fn: Callable = field(default=llava_next_forward_step) + data_step_fn: Callable = field(default=llava_next_data_step) + + def configure_model(self, tokenizer) -> "MCoreLlavaNextModel": + """ + Configures the LLaVA Next model with the appropriate settings. + + Args: + tokenizer: Tokenizer instance to be used with the model. + + Returns: + MCoreLlavaNextModel: An instance of the LLaVA Next model. + """ + + self.language_transformer_config.tensor_model_parallel_size = self.tensor_model_parallel_size + self.language_transformer_config.sequence_parallel = self.sequence_parallel + self.vision_transformer_config.tensor_model_parallel_size = self.tensor_model_parallel_size + self.vision_projection_config.tensor_model_parallel_size = self.tensor_model_parallel_size + self.language_transformer_config.pipeline_model_parallel_size = self.pipeline_model_parallel_size + + if self.encoder_pipeline_model_parallel_size > 0: + assert self.encoder_pipeline_model_parallel_size == 1, "ViT can only live on 1 pipeline stage." 
+ self.vision_transformer_config.pipeline_model_parallel_size = self.encoder_pipeline_model_parallel_size + self.vision_projection_config.pipeline_model_parallel_size = self.encoder_pipeline_model_parallel_size + self.language_transformer_config.encoder_pipeline_model_parallel_size = ( + self.encoder_pipeline_model_parallel_size + ) + if self.encoder_tensor_model_parallel_size > 0: + self.vision_transformer_config.tensor_model_parallel_size = self.encoder_tensor_model_parallel_size + self.vision_projection_config.tensor_model_parallel_size = self.encoder_tensor_model_parallel_size + + model = MCoreLlavaNextModel( + config=self, + tokenizer=tokenizer, + pre_process=ps.is_pipeline_first_stage() + or ps.get_pipeline_model_parallel_rank() == self.encoder_pipeline_model_parallel_size, + post_process=ps.is_pipeline_last_stage(), + add_encoder=ps.is_pipeline_first_stage(), + add_decoder=ps.is_pipeline_last_stage() + or ps.get_pipeline_model_parallel_rank() >= self.encoder_pipeline_model_parallel_size, + drop_vision_class_token=self.drop_vision_class_token, + ) + + return model + + +class MCoreLlavaNextModel(MCoreNevaModel): + """ + The LLaVA Next model class, extending MCoreNevaModel. + + Attributes: + image_newline (torch.nn.Parameter): A learnable parameter for handling image newlines. + """ + + def __init__( + self, + config: LlavaNextConfig, + tokenizer=None, + pre_process: bool = True, + post_process: bool = True, + add_encoder: bool = True, + add_decoder: bool = True, + drop_vision_class_token: bool = False, + ) -> None: + """ + Initializes the LLaVA Next model. + Calls the super class init and initialize image_newline parameter + + Args: + config (LlavaNextConfig): Model configuration instance. + tokenizer: Optional tokenizer instance. + pre_process (bool): Whether to enable preprocessing. + post_process (bool): Whether to enable postprocessing. + add_encoder (bool): Whether to add the encoder module. + add_decoder (bool): Whether to add the decoder module. + drop_vision_class_token (bool): Whether to drop the vision class token. + """ + super().__init__( + config=config, + tokenizer=tokenizer, + pre_process=pre_process, + post_process=post_process, + add_encoder=add_encoder, + add_decoder=add_decoder, + drop_vision_class_token=drop_vision_class_token, + ) + # extra image_newline learnable parameter for llava_next + embed_std = 1 / math.sqrt(config.vision_projection_config.hidden_size) + self.image_newline = torch.nn.Parameter(torch.randn(config.vision_projection_config.hidden_size) * embed_std) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + image_sizes: torch.Tensor, + loss_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + media: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + inference_params: Optional[InferenceParams] = None, + num_media_tiles: Optional[List[int]] = None, + media_token_index: Optional[int] = IMAGE_TOKEN_INDEX, + runtime_gather_output: Optional[bool] = None, + ) -> torch.Tensor: + """Forward function of the LLaVA Next model. + + Args: + images (torch.Tensor): input image of shape [num_tiles, img_h, img_w]. + num_tiles means the number of image tiles in this batch. + input_ids (torch.Tensor): input text ids [batch, text_seq_len]. + position_ids (torch.Tensor): input text position ids [batch, text_seq_len]. + image_sizes (torch.Tensor): Raw image sizes before tiling (N,2). + attention_mask (torch.Tensor): Attention mask for the language model [batch, text seq length]. 
+ labels (torch.Tensor): Optional target text labels [batch, combined_seq_len]. + loss_mask (torch.Tensor): Text loss mask [batch, text_seq_len]. + inference_params (InferenceParams): Inference-time parameters including KV cache. + num_media_tiles (list of int): Number of tiles per image. Default None assumes 1 tile per image. + image_token_index (int): ID for input images. + + Returns: + output (torch.Tensor): Loss ([b, s]) if labels are provided; logits ([b, s, vocab_size]) otherwise. + loss_mask (torch.Tensor): Loss mask expanded to combined sequence length. Shape [b, s]. + """ + use_inference_kv_cache = ( + inference_params is not None and "image_tokens_count" in inference_params.key_value_memory_dict + ) + has_images = media.shape[0] > 0 + + # If running inference, we can skip media token computation + # if they were computed already earlier for this sample. + if use_inference_kv_cache: + media_embeddings = None + elif self.add_encoder and not has_images: + # If no images provided, use an empty image embeddings tensor. + media_embeddings = torch.tensor([], dtype=media.dtype, device=media.device).reshape(0, 0, 0) + elif self.add_encoder and has_images: + # media is in shape of (num_images_in_mbs, c, h, w) + # note num_images_in_mbs is not mbs but total images in this mbs. + if self.vision_model_from_hf: + self.vision_model = self.vision_model.eval() + media_embeddings = self.vision_model(media, output_hidden_states=True) + media_embeddings = media_embeddings[-1][ + self.config.vision_feature_layer + ] # [num_images, img_seq_len, h_vision] + else: + # TODO(yuya): MCore Clip path not yet support taking a specific layer hidden states + media = media.to(next(self.vision_model.parameters()).dtype) + media_embeddings = self.vision_model(media, num_unused_layers=-self.config.vision_feature_layer - 1) + if self._drop_vision_class_token: + class_token_len = getattr(self.vision_model, "class_token_len", 1) + media_embeddings = media_embeddings[:, class_token_len:, :] + + # contiguous() required as `permute` can sparsify the tensor and this breaks pipelining + media_embeddings = media_embeddings.contiguous() + # map vision model output size to language model input size. + media_embeddings = self.vision_projection(media_embeddings) # [img_seq_len, num_tiles, h_language] + # TODO: Support batched inference. + # In inference, the language model KV cache will be updated for image token positions. + # Store the image tokens sequence length to be used as an offset to the KV cache later. + if inference_params is not None: + inference_params.key_value_memory_dict["media_tokens_count"] = ( + media_embeddings.shape[0] * media_embeddings.shape[1] + ) + else: + media_embeddings = self.encoder_hidden_state + + if not self.add_decoder: + return media_embeddings + + language_embeddings = None + if self.pre_process: + input_ids_text = input_ids.clone() + # MultiModal Token indices are assumed to be values + input_ids_text[input_ids_text < 0] = 0 + # Note: This adds absolute position embedding but not RoPE. + # Each image is counted as one position. + # RoPE is added in language_model forward. Each image embedding is one position. + if self.sequence_parallel_lm: + # Pad to nearest multiple of TP world size for embedding. 
+ tp_world_size = ps.get_tensor_model_parallel_world_size() + padded_seq_len = ( + int((input_ids_text.shape[1] + tp_world_size - 1) // tp_world_size * tp_world_size) + - input_ids_text.shape[1] + ) + if padded_seq_len != 0: + input_ids_text = torch.nn.functional.pad(input_ids_text, (0, padded_seq_len)) + if position_ids is not None: + position_ids = torch.nn.functional.pad(position_ids, (0, padded_seq_len)) + language_embeddings = self.language_model.embedding( + input_ids=input_ids_text, position_ids=position_ids + ) # [text_seq_len, b, h_language] + if self.sequence_parallel_lm: + # Gather the language embeddings back. + # We use the full embedding to insert image embeddings + # and then scatter to avoid load imbalance. + language_embeddings = gather_from_sequence_parallel_region( + language_embeddings, tensor_parallel_output_grad=False + ) + # Remove the padding done for SP as we'll need new padding calculation + # after image embeddings are inserted. + if padded_seq_len != 0: + language_embeddings = language_embeddings[:-padded_seq_len] + language_embeddings = language_embeddings.transpose(1, 0).contiguous() # [b, text_seq_len, h_language] + + # Assume 1 tile per image if the number of tiles is not provided. + if num_media_tiles is None: + num_media_tiles = torch.ones(media.shape[0], dtype=torch.int, device=input_ids.device) + elif isinstance(num_media_tiles, list): + num_media_tiles = torch.tensor(num_media_tiles, dtype=torch.int, device=input_ids.device) + + media_embeddings = torch.split(media_embeddings, num_media_tiles.tolist(), dim=0) + media_embeddings, feature_lens = pack_image_features( + media_embeddings, + image_sizes, + vision_feature_select_strategy='default', + image_newline=self.image_newline, + ) + + combined_embeddings, attention_mask, position_ids, final_labels, final_input_ids, final_loss_mask = ( + merge_input_ids_with_image_features( + media_embeddings, + feature_lens, + language_embeddings, + input_ids, + attention_mask, + position_ids, + labels=labels, + image_token_index=media_token_index, + ) + ) + combined_embeddings = combined_embeddings.permute(1, 0, 2) + combined_embeddings = combined_embeddings.contiguous() + output = self.language_model( + input_ids=None, + position_ids=None, + attention_mask=attention_mask, + decoder_input=combined_embeddings, + labels=final_labels, + inference_params=inference_params, + runtime_gather_output=runtime_gather_output, + ) + + if labels is None or loss_mask is None: + return output + + return output, final_loss_mask.contiguous() + + +__all__ = [ + "LlavaNextConfig", +] diff --git a/nemo/collections/vlm/llava_next/model/llava_next.py b/nemo/collections/vlm/llava_next/model/llava_next.py new file mode 100644 index 0000000000000..fac5d5dd08715 --- /dev/null +++ b/nemo/collections/vlm/llava_next/model/llava_next.py @@ -0,0 +1,267 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
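Editor's note before the concrete LLaVA-Next model classes: a shape-only sketch of the tile handling in the forward pass above. `num_media_tiles` records how many tiles each image contributed, `torch.split` regroups the flat tile dimension into per-image chunks, and `pack_image_features` (adapted from the Hugging Face LLaVA-NeXT implementation) lays each image's tiles back out spatially and appends the learnable `image_newline` embedding to every feature row. All sizes below are illustrative.

```python
import torch

# Hypothetical batch: 2 images producing 3 and 2 tiles respectively,
# each tile yielding 576 patch embeddings with a hidden size of 8.
media_embeddings = torch.randn(5, 576, 8)
num_media_tiles = torch.tensor([3, 2], dtype=torch.int)

per_image = torch.split(media_embeddings, num_media_tiles.tolist(), dim=0)
print([t.shape for t in per_image])
# [torch.Size([3, 576, 8]), torch.Size([2, 576, 8])]

# pack_image_features then arranges each image's tiles into its original grid
# (using image_sizes), appends image_newline after every spatial row, and
# flattens the result before it is merged with the text embeddings.
```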
+ +from dataclasses import dataclass, field +from pathlib import Path +from typing import Callable, List, Optional, Union + +import torch +import torch.distributed +from megatron.core.inference_params import InferenceParams +from megatron.core.optimizer import OptimizerConfig +from megatron.core.transformer.transformer_config import TransformerConfig +from transformers import LlavaNextForConditionalGeneration + +from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec +from nemo.collections.llm import Llama2Config7B, Llama2Config13B, LlamaConfig +from nemo.collections.vlm.llava_next.model.base import LlavaNextConfig, MCoreLlavaNextModel +from nemo.collections.vlm.neva.model.base import HFCLIPVisionConfig, MultimodalProjectorConfig, NevaModel +from nemo.collections.vlm.neva.model.llava import HFLlavaImporter +from nemo.lightning import OptimizerModule, io, teardown +from nemo.lightning.pytorch.optim import MegatronOptimizerModule, OptimizerModule + + +@dataclass +class LlavaNextConfig7B(LlavaNextConfig): + """ + Configuration class for the 7B parameter variant of the LLaVA 16 model. + + Inherits all attributes and methods from Llava15Config7B without modification. + """ + + from transformers import PretrainedConfig + + language_transformer_config: TransformerConfig = field(default_factory=lambda: Llama2Config7B()) + vision_transformer_config: Union[TransformerConfig, PretrainedConfig] = field( + default_factory=lambda: HFCLIPVisionConfig(pretrained_model_name_or_path="openai/clip-vit-large-patch14-336") + ) + vision_projection_config: TransformerConfig = field( + default_factory=lambda: MultimodalProjectorConfig(input_size=1024, hidden_size=4096, ffn_hidden_size=4096) + ) + + +@dataclass +class LlavaNextConfig13B(LlavaNextConfig): + """ + Configuration class for the 13B parameter variant of the LLaVA 16 model. + + Inherits all attributes and methods from Llava15Config13B without modification. + """ + + from transformers import PretrainedConfig + + language_transformer_config: TransformerConfig = field(default_factory=lambda: Llama2Config13B()) + vision_transformer_config: Union[TransformerConfig, PretrainedConfig] = field( + default_factory=lambda: HFCLIPVisionConfig(pretrained_model_name_or_path="openai/clip-vit-large-patch14-336") + ) + vision_projection_config: TransformerConfig = field( + default_factory=lambda: MultimodalProjectorConfig(input_size=1024, hidden_size=5120, ffn_hidden_size=5120) + ) + + +class LlavaNextModel(NevaModel): + """ + The LLaVA Next model class, extending NevaModel. + + Attributes: + config (LlavaNextConfig): Configuration object for the model. + optim (Optional[OptimizerModule]): Optimizer module. Defaults to a Megatron optimizer. + tokenizer (Optional[TokenizerSpec]): Tokenizer specification for processing text inputs. + model_transform (Optional[Callable[[torch.nn.Module], torch.nn.Module]]): + Optional transformation applied to the model after initialization. + """ + + def __init__( + self, + config: LlavaNextConfig, + optim: Optional[OptimizerModule] = None, + tokenizer: Optional["TokenizerSpec"] = None, + model_transform: Optional[Callable[[torch.nn.Module], torch.nn.Module]] = None, + ): + """ + Initializes the LlavaNextModel. + + Args: + config (LlavaNextConfig): Configuration object for the model. + optim (Optional[OptimizerModule]): optimizer module. Defaults to Megatron optimizer. + tokenizer (Optional[TokenizerSpec]): Optional tokenizer specification for processing text inputs. 
+ model_transform (Optional[Callable[[torch.nn.Module], torch.nn.Module]]): + Optional transformation function applied to the model after initialization. + """ + super().__init__( + config=config, + optim=optim or MegatronOptimizerModule(config=OptimizerConfig(lr=1e-4, use_distributed_optimizer=True)), + tokenizer=tokenizer, + model_transform=model_transform, + ) + + def configure_model(self) -> MCoreLlavaNextModel: + """ + Configures the underlying model instance if it has not been initialized. + + Returns: + MCoreLlavaNextModel: The configured model instance. + """ + if not hasattr(self, "module"): + self.module = self.config.configure_model(self.tokenizer) + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + image_sizes: torch.Tensor, + loss_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + media: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + inference_params: InferenceParams = None, + num_media_tiles: Optional[List[int]] = None, + ) -> torch.Tensor: + """ + Performs the forward pass of the LLaVA Next model. + + Args: + input_ids (torch.Tensor): Input token IDs of shape [batch, text_seq_len]. + position_ids (torch.Tensor): Position IDs of shape [batch, text_seq_len]. + image_sizes (torch.Tensor): Raw image sizes before tiling, of shape [batch, 2]. + loss_mask (Optional[torch.Tensor]): Text loss mask of shape [batch, text_seq_len]. + attention_mask (Optional[torch.Tensor]): Attention mask shape [batch, text_seq_len]. + media (Optional[torch.Tensor]): Input media tensor. + labels (Optional[torch.Tensor]): Target labels of shape [batch, combined_seq_len]. + inference_params (InferenceParams): Inference-time parameters. + num_media_tiles (Optional[List[int]]): Number of tiles per image. Default assumes 1 tile per image. + + Returns: + torch.Tensor: The model output. Shape depends on whether labels are provided. + - If `labels` is provided: Loss tensor of shape [batch, seq_len]. + - If `labels` is not provided: Logits tensor of shape [batch, seq_len, vocab_size]. + """ + output_tensor = self.module( + media=media, + input_ids=input_ids, + position_ids=position_ids, + image_sizes=image_sizes, + loss_mask=loss_mask, + attention_mask=attention_mask, + labels=labels, + inference_params=inference_params, + num_media_tiles=num_media_tiles, + ) + + return output_tensor + + +@io.model_importer(LlavaNextModel, "hf") +class HFLlavaNextImporter( + HFLlavaImporter, + io.ModelConnector["LlavaNextForConditionalGeneration", LlavaNextModel], +): + """ + Importer class for converting HuggingFace LLaVA Next checkpoint to NeMo format. + + Inherits: + HFLlavaImporter: Base class for HuggingFace LLaVA model importers. + io.ModelConnector: Connector interface to handle setup, save, and load using the Lightning framework. + + Methods: + init: Initializes a new LlavaNextModel instance. + apply: Converts the HuggingFace model to NeMo format and saves it. + config: Generates and returns the LlavaNextConfig for the model. + """ + + def init(self) -> LlavaNextModel: + """ + Initializes the LlavaNextModel. + + Returns: + LlavaNextModel: An instance of the LLaVA Next model initialized with the configuration. + """ + return LlavaNextModel(self.config, tokenizer=self.tokenizer) + + def apply(self, output_path: Path) -> Path: + """ + Converts the HuggingFace LLaVA Next model to NeMo format and saves it to the specified path. + + Args: + output_path (Path): The path where the converted NeMo model will be saved. 
+ + Returns: + Path: The output path where the NeMo model was saved. + """ + + source = LlavaNextForConditionalGeneration.from_pretrained(str(self)) + target = self.init() + trainer = self.nemo_setup(target) + self.convert_state(source, target, image_newline=True) + print(f"Converted LLaVA Next model to NeMo, saving to {output_path}") + + self.nemo_save(output_path, trainer) + + print(f"Converted LLaVA Next model saved to {output_path}") + + teardown(trainer, target) + del trainer, target + + return output_path + + @property + def config(self) -> LlavaNextConfig: + """ + Generates the configuration for the LLaVA Next model based on the HuggingFace model. + + Returns: + LlavaNextConfig: A configuration object for the LLaVA Next model. + """ + from transformers import LlavaConfig as HFLlavaConfig + + source = HFLlavaConfig.from_pretrained(str(self)) + text_config = source.text_config + + def make_vocab_size_divisible_by(vocab_size): + base = 128 + while vocab_size % base != 0: + base //= 2 + return base + + language_transformer_config = LlamaConfig( + num_layers=text_config.num_hidden_layers, + hidden_size=text_config.hidden_size, + ffn_hidden_size=text_config.intermediate_size, + num_attention_heads=text_config.num_attention_heads, + init_method_std=text_config.initializer_range, + layernorm_epsilon=text_config.rms_norm_eps, + num_query_groups=text_config.num_key_value_heads, + rotary_base=text_config.rope_theta, + gated_linear_unit=True, + make_vocab_size_divisible_by=make_vocab_size_divisible_by(text_config.vocab_size), + share_embeddings_and_output_weights=False, + ) + vision_transformer_config = HFCLIPVisionConfig( + pretrained_model_name_or_path="openai/clip-vit-large-patch14-336" + ) + vision_projection_config = MultimodalProjectorConfig(input_size=1024, hidden_size=4096, ffn_hidden_size=4096) + + output = LlavaNextConfig( + language_transformer_config=language_transformer_config, + vision_transformer_config=vision_transformer_config, + vision_projection_config=vision_projection_config, + vision_feature_layer=source.vision_feature_layer, + ) + + return output + + +__all__ = [ + "LlavaNextModel", +] diff --git a/nemo/collections/vlm/llava_next/model/utils.py b/nemo/collections/vlm/llava_next/model/utils.py new file mode 100644 index 0000000000000..2996bc277983b --- /dev/null +++ b/nemo/collections/vlm/llava_next/model/utils.py @@ -0,0 +1,436 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
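# Editor's illustration (not part of the patch): a minimal sketch of the
# make_vocab_size_divisible_by() helper defined in HFLlavaNextImporter.config above.
# It returns the largest power-of-two divisor of the HF vocabulary size, capped at 128;
# the example vocab sizes below are illustrative only.
def make_vocab_size_divisible_by(vocab_size):
    base = 128
    while vocab_size % base != 0:
        base //= 2
    return base

assert make_vocab_size_divisible_by(32000) == 128  # 32000 is divisible by 128
assert make_vocab_size_divisible_by(32064) == 64   # 32064 is divisible by 64 but not 128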
+ +import torch + +# These functions' implementations are adapted from +# https://github.com/huggingface/transformers/blob/ +# 53fad641cfdb5105e2470bcf3ef17ea8e25cc300/src/transformers/models/llava_next/modeling_llava_next.py + + +def get_image_sequence_length(img_h, img_w, patch_dim, add_class_token, class_token_len): + """Get image sequence length given image size, patch size, and class token.""" + num_patches_per_dim_h = img_h // patch_dim + num_patches_per_dim_w = img_w // patch_dim + num_patches = num_patches_per_dim_h * num_patches_per_dim_w + return num_patches + (class_token_len if add_class_token else 0) + + +def merge_input_ids_with_image_features( + image_features, + feature_lens, + inputs_embeds, + input_ids, + attention_mask, + position_ids=None, + labels=None, + image_token_index=-200, + ignore_index=-100, +): + """ + Merge input_ids with image features into final embeddings. + Args: + image_features (`torch.Tensor` of shape `(all_feature_lens, embed_dim)`): + All vision vectors of all images in the batch + feature_lens (`torch.LongTensor` of shape `(num_images)`): + The length of visual embeddings of each image as stacked in `image_features` + inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, embed_dim)`): + Token embeddings before merging with visual embeddings + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Input_ids of tokens, possibly filled with image tokens + attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Mask to avoid performing attention on padding token indices. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + labels (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*) + Labels need to be recalculated to support training (if provided) + image_token_index (`int`, *optional*) + Token id used to indicate the special "image" token. Defaults to `config.image_token_index` + ignore_index (`int`, *optional*) + Value that is used to pad `labels` and will be ignored when calculating the loss. Default: -100.
+ Returns: + final_embedding, final_attention_mask, position_ids, final_labels + Explanation: + each image has variable length embeddings, with length specified by feature_lens + image_features is concatenation of all visual embed vectors + task: fill each with the correct number of visual embeddings + Example: + X (5 patches), Y (3 patches), Z (8) + X, Y are in the same sequence (in-context learning) + if right padding + input_ids: [ + a b c d e f X g h i j k Y l m + o p q r Z s t u v _ _ _ _ _ _ + ] + input_ids should be: [ + a b c d e f X X X X X g h i j k Y Y Y l m + o p q r Z Z Z Z Z Z Z Z s t u v _ _ _ _ _ + ] + labels should be: [ + a b c d e f _ _ _ _ _ g h i j k _ _ _ l m + o p q r _ _ _ _ _ _ _ _ s t u v _ _ _ _ _ + ] + elif left padding + input_ids: [ + a b c d e f X g h i j k Y l m + _ _ _ _ _ _ o p q r Z s t u v + ] + input_ids should be: [ + a b c d e f X X X X X g h i j k Y Y Y l m + _ _ _ _ _ o p q r Z Z Z Z Z Z Z Z s t u v + ] + labels should be: [ + a b c d e f _ _ _ _ _ g h i j k _ _ _ l m + _ _ _ _ _ o p q r _ _ _ _ _ _ _ _ s t u v + ] + Edge cases: + * If tokens are same but image token sizes are different, then cannot infer left or right padding + ```python + cat_img = Image.open(requests.get( + "http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw) + chart_img = Image.open(requests.get( + "https://github.com/haotian-liu/LLaVA/blob/" + "1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true" + , stream=True).raw) + prompts = [ + "[INST] \nWhat is shown in this image? [/INST]", + "[INST] \nWhat is shown in this image? [/INST]", + ] + inputs = processor(prompts, [chart_img, cat_img], return_tensors='pt', padding=True).to("cuda") + chart_img has 2634 tokens, while cat_img has 2340 tokens + ``` + input_ids: [ + a b c d X g h + i j Y k l m n + ] + where X is 3 tokens while Y is 5, this mean after merge + if left-padding (batched generation) + input_ids should be: [ + _ _ a b c d X X X g h + i j Y Y Y Y Y k l m n + ] + elif (right padding) (training) + input_ids should be: [ + a b c d X X X g h _ _ + i j Y Y Y Y Y k l m n + ] + """ + + padding_side = 'right' + pad_token_id = 0 + with torch.no_grad(): + # ! in llava 1.6, number of patches is variable + num_images = feature_lens.size(0) + num_image_features, embed_dim = image_features.shape + if feature_lens.sum() != num_image_features: + raise ValueError(f"{feature_lens=} / {feature_lens.sum()} != {image_features.shape=}") + batch_size = input_ids.shape[0] + _left_padding = torch.any(attention_mask[:, 0] == 0) + _right_padding = torch.any(attention_mask[:, -1] == 0) + + left_padding = padding_side == "left" + if batch_size > 1: + if _left_padding and _right_padding: + raise ValueError(f"both side of attention_mask has zero, invalid. {attention_mask}") + elif _right_padding and left_padding: + left_padding = False + elif _left_padding and not left_padding: + left_padding = True + # Whether to turn off right padding + # 1. Create a mask to know where special image tokens are + special_image_token_mask = input_ids == image_token_index + # special_image_token_mask: [bsz, seqlen] + num_special_image_tokens = torch.sum(special_image_token_mask, dim=-1) + # num_special_image_tokens: [bsz] + # Reserve for padding of num_images + total_num_special_image_tokens = torch.sum(special_image_token_mask) + if total_num_special_image_tokens != num_images: + raise ValueError( + f"Number of image tokens in input_ids ({total_num_special_image_tokens}) " + f"different from num_images ({num_images})." 
+ ) + # Compute the maximum embed dimension + # max_image_feature_lens is max_feature_lens per batch + feature_lens = feature_lens.to(input_ids.device) + feature_lens_batch = feature_lens.split(num_special_image_tokens.tolist(), dim=0) + feature_lens_batch_sum = torch.tensor([x.sum() for x in feature_lens_batch], device=input_ids.device) + embed_sequence_lengths = ( + (attention_mask == 1).long().sum(-1) - num_special_image_tokens + feature_lens_batch_sum + ) + max_embed_dim = embed_sequence_lengths.max() + batch_indices, non_image_indices = torch.where((input_ids != image_token_index) & (attention_mask == 1)) + # batch_indices, non_image_indices = torch.where((input_ids != image_token_index) ) + # 2. Compute the positions where text should be written + # Calculate new positions for text tokens in merged image-text sequence. + # `special_image_token_mask` identifies image tokens. + # Each image token will be replaced by `nb_text_tokens_per_images` text tokens. + # `torch.cumsum` computes how each image token shifts subsequent text token positions. + # - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one. + # ! instead of special_image_token_mask * (num_image_patches - 1) + # special_image_token_mask * (num_feature_len - 1) + special_image_token_mask = special_image_token_mask.long() + special_image_token_mask[special_image_token_mask == 1] = feature_lens - 1 + new_token_positions = torch.cumsum((special_image_token_mask + 1), -1) - 1 + if left_padding: + # shift right token positions so that they are ending at the same number + # the below here was incorrect? + # new_token_positions += new_token_positions[:, -1].max() - new_token_positions[:, -1:] + new_token_positions += max_embed_dim - 1 - new_token_positions[:, -1:] + + text_to_overwrite = new_token_positions[batch_indices, non_image_indices] + # 3. Create the full embedding, already padded to the maximum position + final_embedding = torch.zeros( + batch_size, max_embed_dim, embed_dim, dtype=inputs_embeds.dtype, device=inputs_embeds.device + ) + final_attention_mask = torch.zeros( + batch_size, max_embed_dim, dtype=attention_mask.dtype, device=inputs_embeds.device + ) + final_input_ids = torch.full( + (batch_size, max_embed_dim), pad_token_id, dtype=input_ids.dtype, device=inputs_embeds.device + ) + # In case the Vision model or the Language model has been offloaded to CPU, we need to manually + # set the corresponding tensors into their correct target device. + target_device = inputs_embeds.device + batch_indices, non_image_indices, text_to_overwrite = ( + batch_indices.to(target_device), + non_image_indices.to(target_device), + text_to_overwrite.to(target_device), + ) + attention_mask = attention_mask.to(target_device) + input_ids = input_ids.to(target_device) + + # 4. Fill the embeddings based on the mask. If we have ["hey" "", "how", "are"] + # we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the image features + final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[batch_indices, non_image_indices] + final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[batch_indices, non_image_indices] + final_input_ids[batch_indices, text_to_overwrite] = input_ids[batch_indices, non_image_indices] + final_labels = None + if labels is not None: + labels = labels.to(target_device) + final_labels = torch.full_like(final_attention_mask, ignore_index).to(torch.long) + final_labels[batch_indices, text_to_overwrite] = labels[batch_indices, non_image_indices] + + # 5. 
Fill the embeddings corresponding to the images. Anything that is not `text_positions` needs filling (#29835) + with torch.no_grad(): + image_to_overwrite = torch.full( + (batch_size, max_embed_dim), True, dtype=torch.bool, device=inputs_embeds.device + ) + image_to_overwrite[batch_indices, text_to_overwrite] = False + embed_indices = torch.arange(max_embed_dim).unsqueeze(0).to(target_device) + embed_indices = embed_indices.expand(batch_size, max_embed_dim) + embed_seq_lens = embed_sequence_lengths[:, None].to(target_device) + + if left_padding: + # exclude padding on the left + max_embed_dim = max_embed_dim.to(target_device) + val = (max_embed_dim - embed_indices) <= embed_seq_lens + else: + # exclude padding on the right + val = embed_indices < embed_seq_lens + image_to_overwrite &= val + + if image_to_overwrite.sum() != num_image_features: + raise ValueError( + f"{image_to_overwrite.sum()=} != {num_image_features=} The input provided to the model are wrong. " + f"The number of image tokens is {torch.sum(special_image_token_mask)} while" + f" the number of image given to the model is {num_images}. " + f"This prevents correct indexing and breaks batch generation." + ) + final_embedding[image_to_overwrite] = image_features.contiguous().reshape(-1, embed_dim).to(target_device) + final_attention_mask |= image_to_overwrite + position_ids = (final_attention_mask.cumsum(-1) - 1).masked_fill_((final_attention_mask == 0), 1) + final_loss_mask = None + if final_labels is not None: + final_loss_mask = (final_labels != ignore_index).long() + + return final_embedding, final_attention_mask, position_ids, final_labels, final_input_ids, final_loss_mask + + +def unpad_image(tensor, original_size): + """ + Unpads a PyTorch tensor of a padded and resized image. + Args: + tensor (`torch.Tensor`): + The image tensor, assumed to be of shape (num_channels, height, width). + original_size (`tuple`): + The original size of the image (height, width). + Returns: + `torch.Tensor`: The unpadded image tensor. + """ + import numpy as np + + if not isinstance(original_size, (list, tuple)): + if not isinstance(original_size, (torch.Tensor, np.ndarray)): + raise TypeError( + f"image_size invalid type: {type(original_size)} not valid ", + "should be either list, tuple, np.ndarray or tensor", + ) + original_size = original_size.tolist() + original_height, original_width = original_size + current_height, current_width = tensor.shape[1:] + + original_aspect_ratio = original_width / original_height + current_aspect_ratio = current_width / current_height + + if original_aspect_ratio > current_aspect_ratio: + scale_factor = current_width / original_width + new_height = int(original_height * scale_factor) + padding = (current_height - new_height) // 2 + unpadded_tensor = tensor[:, padding : current_height - padding, :] + else: + scale_factor = current_height / original_height + new_width = int(original_width * scale_factor) + padding = (current_width - new_width) // 2 + unpadded_tensor = tensor[:, :, padding : current_width - padding] + + return unpadded_tensor + + +def select_best_resolution(original_size: tuple, possible_resolutions: list) -> tuple: + """ + Selects the best resolution from a list of possible resolutions based on the original size. + This is done by calculating the effective and wasted resolution for each possible resolution. + The best fit resolution is the one that maximizes the effective resolution and minimizes the wasted resolution. 
+ Args: + original_size (tuple): + The original size of the image in the format (height, width). + possible_resolutions (list): + A list of possible resolutions in the format [(height1, width1), (height2, width2), ...]. + Returns: + tuple: The best fit resolution in the format (height, width). + """ + original_height, original_width = original_size + best_fit = None + max_effective_resolution = 0 + min_wasted_resolution = float("inf") + + for height, width in possible_resolutions: + scale = min(width / original_width, height / original_height) + downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale) + effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height) + wasted_resolution = (width * height) - effective_resolution + + if effective_resolution > max_effective_resolution or ( + effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution + ): + max_effective_resolution = effective_resolution + min_wasted_resolution = wasted_resolution + best_fit = (height, width) + + return best_fit + + +def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size): + """ + Calculate the shape of the image patch grid after the preprocessing for images of any resolution. + Args: + image_size (`tuple`): + The size of the input image in the format (width, height). + grid_pinpoints (`List`): + A list containing possible resolutions. Each item in the list should be a tuple or list + of the form `(height, width)`. + patch_size (`int`): + The size of each image patch. + Returns: + tuple: The shape of the image patch grid in the format (width, height). + """ + import numpy as np + + if not isinstance(grid_pinpoints, list): + raise TypeError("grid_pinpoints should be a list of tuples or lists") + + # ! VERY IMPORTANT: if image_size is a tensor, it must be converted into a tuple, otherwise the calculation will be wrong + if not isinstance(image_size, (list, tuple)): + if not isinstance(image_size, (torch.Tensor, np.ndarray)): + raise TypeError( + f"image_size invalid type: {type(image_size)} not valid, " + "should be either list, tuple, np.ndarray or tensor" + ) + image_size = image_size.tolist() + + height, width = select_best_resolution(image_size, grid_pinpoints) + return height // patch_size, width // patch_size + + +# These functions' implementations are adapted from +# https://github.com/huggingface/transformers/blob/ +# 53fad641cfdb5105e2470bcf3ef17ea8e25cc300/src/transformers/models/llava_next/modeling_llava_next.py#L655 + + +def pack_image_features(image_features, image_sizes, vision_feature_select_strategy, image_newline=None): + """ + Reshape, unpad and then pack each image_feature into a single image_features tensor containing all visual vectors. + Args: + image_features (`List[torch.Tensor]` of length num_images, + each of shape `(num_patches, image_length, embed_dim)`) + List of image feature tensors, each containing all the visual features of all patches. + image_sizes (`torch.Tensor` of shape `(num_images, 2)`) + Actual image size of each image (H, W). + vision_feature_select_strategy (`str`) + The feature selection strategy used to select the vision feature from the vision backbone. + image_newline (`torch.Tensor` of shape `(embed_dim)`) + New line embedding vector.
+ Returns: + image_features (`torch.Tensor` of shape `(all_feat_len, embed_dim)`) + feature_lens (`List[int]`) + token length of each image in image_features + """ + from transformers import LlavaNextConfig + + config = LlavaNextConfig() + new_image_features = [] + feature_lens = [] + + for image_idx, image_feature in enumerate(image_features): + if image_feature.shape[0] > 1: + base_image_feature = image_feature[0] + image_feature = image_feature[1:] + height = width = config.vision_config.image_size // config.vision_config.patch_size + + if vision_feature_select_strategy == "default": + expected_num_patches = height * width + elif vision_feature_select_strategy == "full": + expected_num_patches = height * width + 1 + if expected_num_patches != base_image_feature.shape[0]: + raise ValueError("The number of patches is not consistent with the image size.") + + num_patch_height, num_patch_width = get_anyres_image_grid_shape( + image_sizes[image_idx], + config.image_grid_pinpoints, + config.vision_config.image_size, + ) + image_feature = image_feature.view(num_patch_height, num_patch_width, height, width, -1) + image_feature = image_feature.permute(4, 0, 2, 1, 3).contiguous() + image_feature = image_feature.flatten(1, 2).flatten(2, 3) + image_feature = unpad_image(image_feature, image_sizes[image_idx]) + if image_newline is not None: + image_feature = torch.cat( + ( + image_feature, + image_newline[:, None, None].expand(*image_feature.shape[:-1], 1).to(image_feature.dtype), + ), + dim=-1, + ) + image_feature = image_feature.flatten(1, 2).transpose(0, 1) + image_feature = torch.cat((base_image_feature, image_feature), dim=0) + else: + image_feature = image_feature[0] + if image_newline is not None: + image_feature = torch.cat((image_feature, image_newline[None].to(image_feature)), dim=0) + new_image_features.append(image_feature) + feature_lens.append(image_feature.size(0)) + image_features = torch.cat(new_image_features, dim=0) + feature_lens = torch.tensor(feature_lens, dtype=torch.long, device=image_features.device) + return image_features, feature_lens diff --git a/nemo/collections/vlm/neva/data/__init__.py b/nemo/collections/vlm/neva/data/__init__.py index df9716fe56105..f210d01a06fda 100644 --- a/nemo/collections/vlm/neva/data/__init__.py +++ b/nemo/collections/vlm/neva/data/__init__.py @@ -14,7 +14,6 @@ from nemo.collections.vlm.neva.data.config import DataConfig, ImageDataConfig, VideoDataConfig from nemo.collections.vlm.neva.data.lazy import NevaLazyDataModule -from nemo.collections.vlm.neva.data.llava_next_energon import LlavaNextTaskEncoder from nemo.collections.vlm.neva.data.mock import MockDataModule as NevaMockDataModule from nemo.collections.vlm.neva.data.multimodal_tokens import ImageToken, MultiModalToken, VideoToken @@ -27,5 +26,4 @@ "MultiModalToken", "ImageToken", "VideoToken", - "LlavaNextTaskEncoder", ] diff --git a/nemo/collections/vlm/neva/model/llava.py b/nemo/collections/vlm/neva/model/llava.py index 5e02b4f9e9d73..7f5f46380b299 100644 --- a/nemo/collections/vlm/neva/model/llava.py +++ b/nemo/collections/vlm/neva/model/llava.py @@ -102,7 +102,7 @@ def apply(self, output_path: Path) -> Path: return output_path - def convert_state(self, source, target): + def convert_state(self, source, target, image_newline=False): mapping = { "language_model.model.embed_tokens.weight": "language_model.embedding.word_embeddings.weight", "language_model.model.layers.*.self_attn.o_proj.weight": "language_model.decoder.layers.*.self_attention.linear_proj.weight", @@ -133,6 +133,9 @@ def 
convert_state(self, source, target): else: raise KeyError("Unable to map vision projection keys.") + if image_newline: + mapping.update({"image_newline": "image_newline"}) + if "vision_model.vision_model.embeddings.class_embedding" in target.module.state_dict().keys(): mapping.update( { diff --git a/nemo/collections/vlm/recipes/__init__.py b/nemo/collections/vlm/recipes/__init__.py index ba8706437c560..e3225dec8c4f6 100644 --- a/nemo/collections/vlm/recipes/__init__.py +++ b/nemo/collections/vlm/recipes/__init__.py @@ -13,11 +13,12 @@ # limitations under the License. -from nemo.collections.vlm.recipes import llava15_7b, llava15_13b, mllama_11b, mllama_90b +from nemo.collections.vlm.recipes import llava15_7b, llava15_13b, llava_next_7b, mllama_11b, mllama_90b __all__ = [ "llava15_7b", "llava15_13b", "mllama_11b", "mllama_90b", + "llava_next_7b", ] diff --git a/nemo/collections/vlm/recipes/llava_next_7b.py b/nemo/collections/vlm/recipes/llava_next_7b.py new file mode 100644 index 0000000000000..c483ff788f265 --- /dev/null +++ b/nemo/collections/vlm/recipes/llava_next_7b.py @@ -0,0 +1,242 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Optional + +import nemo_run as run +import pytorch_lightning as pl +import torch + +from nemo import lightning as nl +from nemo.collections import llm, vlm +from nemo.collections.llm.recipes.finetune_default import nemo_resume +from nemo.collections.llm.recipes.log.default import tensorboard_logger +from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing +from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed +from nemo.collections.vlm import LlavaNextMockDataModule +from nemo.utils.exp_manager import TimingCallback + +NAME = "llava_next_7b" + + +@run.cli.factory(name=NAME) +def model(config=run.Config(vlm.LlavaNextConfig7B)) -> run.Config[pl.LightningModule]: + """ + Factory function to create a LlavaNext 7B model configuration. + + Returns: + run.Config[pl.LightningModule]: Configuration for the Llava Next 7B model. + + Examples: + CLI usage: + $ nemo llm pretrain model=llava_next_7b ... + + Python API usage: + >>> model_config = model() + >>> print(model_config) + """ + return run.Config(vlm.LlavaNextModel, config=config) + + +@run.cli.factory(target=llm.finetune, name=NAME) +def finetune_recipe( + dir: Optional[str] = None, + name: str = "default", + num_nodes: int = 1, + num_gpus_per_node: int = 8, + peft_scheme: Optional[str] = 'none', +) -> run.Partial: + """ + Create a fine-tuning recipe for LlavaNext 7B model. + + This function sets up a complete configuration for fine-tuning, including + model, trainer, data, logging, optimization, and resumption settings. + The recipe uses LoRA (Low-Rank Adaptation) for efficient fine-tuning, unless peft_scheme is set to None. + + Args: + dir (Optional[str]): Directory for saving logs and checkpoints. + name (str): Name of the fine-tuning run. 
+ num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + + Returns: + run.Partial: Partial configuration for fine-tuning. + + Examples: + CLI usage: + $ nemo llm finetune --factory llava_next_7b + + Python API usage: + >>> recipe = finetune_recipe(name="llava_next_7b_finetune", num_nodes=1) + >>> print(recipe) + """ + + strategy = run.Config( + nl.MegatronStrategy, + tensor_model_parallel_size=4, + pipeline_model_parallel_size=1, + encoder_pipeline_model_parallel_size=0, + pipeline_dtype=torch.bfloat16, + ) + + trainer = run.Config( + nl.Trainer, + accelerator="gpu", + accumulate_grad_batches=1, + devices=num_gpus_per_node, + limit_val_batches=10, + log_every_n_steps=1, + max_steps=5190, + num_nodes=num_nodes, + plugins=bf16_mixed(), + strategy=strategy, + val_check_interval=1000, + callbacks=[run.Config(TimingCallback)], + ) + from transformers import AutoProcessor + + from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer + + recipe = run.Partial( + llm.finetune, + model=model( + config=run.Config( + vlm.LlavaNextConfig7B, + freeze_language_model=False, + freeze_vision_model=True, + freeze_vision_projection=False, + ) + ), + trainer=trainer, + data=run.Config( + LlavaNextMockDataModule, + seq_length=4096, + global_batch_size=8, + micro_batch_size=2, + tokenizer=None, + image_processor=None, + num_workers=4, + ), + log=llm.default_log(dir=dir, name=name, tensorboard_logger=tensorboard_logger(name=name)), + optim=distributed_fused_adam_with_cosine_annealing(max_lr=2.0e-05, min_lr=2.0e-07, warmup_steps=150), + resume=nemo_resume("llava-hf/llava-v1.6-vicuna-7b-hf"), + ) + + if peft_scheme is None or peft_scheme.lower() == 'none': + recipe.trainer.strategy.tensor_model_parallel_size = 2 + recipe.optim.config.lr = 2e-05 + elif peft_scheme.lower() == 'lora': + recipe.peft = run.Config( + vlm.LoRA, + freeze_vision_model=False, + target_modules=[ + "*.language_model.*.linear_qkv", + "*.language_model.*.linear_q", + "*.language_model.*.linear_kv", + "*.language_model.*.linear_proj", + "*.language_model.*.linear_fc1", + "*.language_model.*.linear_fc2", + ], + ) + recipe.optim.config.lr = 1e-4 + else: + raise ValueError(f"Unrecognized peft scheme: {peft_scheme}") + + return recipe + + +@run.cli.factory(target=llm.pretrain, name=NAME) +def pretrain_recipe( + dir: Optional[str] = None, + name: str = "default", + num_nodes: int = 1, + num_gpus_per_node: int = 8, + peft_scheme: Optional[str] = 'none', +) -> run.Partial: + """ + Create a Pre-training recipe for Llava1.6 7B model. + + This function sets up a complete configuration for pre-training, including + model, trainer, data, logging, optimization, and resumption settings. + + Args: + dir (Optional[str]): Directory for saving logs and checkpoints. + name (str): Name of the fine-tuning run. + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + + Returns: + run.Partial: Partial configuration for fine-tuning. 
+ + Examples: + CLI usage: + $ nemo llm pretrain --factory llava_next_7b + + Python API usage: + >>> recipe = finetune_recipe(name="llava_next_7b_pretrain", num_nodes=1) + >>> print(recipe) + """ + + strategy = run.Config( + nl.MegatronStrategy, + tensor_model_parallel_size=4, + pipeline_model_parallel_size=1, + encoder_pipeline_model_parallel_size=0, + pipeline_dtype=torch.bfloat16, + ) + + trainer = run.Config( + nl.Trainer, + accelerator="gpu", + accumulate_grad_batches=1, + devices=num_gpus_per_node, + limit_val_batches=10, + log_every_n_steps=1, + max_steps=5190, + num_nodes=num_nodes, + plugins=bf16_mixed(), + strategy=strategy, + val_check_interval=1000, + callbacks=[run.Config(TimingCallback)], + ) + from transformers import AutoProcessor + + from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer + + recipe = run.Partial( + llm.pretrain, + model=model( + config=run.Config( + vlm.LlavaNextConfig7B, + freeze_language_model=True, + freeze_vision_model=True, + freeze_vision_projection=False, + ) + ), + trainer=trainer, + data=run.Config( + LlavaNextMockDataModule, + seq_length=4096, + global_batch_size=8, + micro_batch_size=2, + tokenizer=None, + image_processor=None, + num_workers=4, + ), + log=llm.default_log(dir=dir, name=name, tensorboard_logger=tensorboard_logger(name=name)), + optim=distributed_fused_adam_with_cosine_annealing(max_lr=0.001, min_lr=2.0e-05, warmup_steps=150), + ) + + return recipe diff --git a/nemo/deploy/nlp/__init__.py b/nemo/deploy/nlp/__init__.py index 5ebbe68166649..633544e300ed0 100644 --- a/nemo/deploy/nlp/__init__.py +++ b/nemo/deploy/nlp/__init__.py @@ -12,15 +12,5 @@ # See the License for the specific language governing permissions and # limitations under the License. - -use_query_llm = True -try: - from nemo.deploy.nlp.query_llm import NemoQueryLLM, NemoQueryLLMPyTorch -except Exception: - use_query_llm = False - -use_megatron_llm = True -try: - from nemo.deploy.nlp.megatronllm_deployable import MegatronLLMDeployable -except Exception: - use_megatron_llm = False +from nemo.deploy.nlp.megatronllm_deployable import MegatronLLMDeployable +from nemo.deploy.nlp.query_llm import NemoQueryLLM, NemoQueryLLMPyTorch diff --git a/nemo/export/__init__.py b/nemo/export/__init__.py index d9155f923f186..6b1f8c90aa8f1 100644 --- a/nemo/export/__init__.py +++ b/nemo/export/__init__.py @@ -11,3 +11,5 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from nemo.export.tensorrt_lazy_compiler import trt_compile diff --git a/nemo/export/tensorrt_lazy_compiler.py b/nemo/export/tensorrt_lazy_compiler.py new file mode 100644 index 0000000000000..ab40278efa947 --- /dev/null +++ b/nemo/export/tensorrt_lazy_compiler.py @@ -0,0 +1,714 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import inspect +import os +import tempfile +import threading +from collections import OrderedDict +from logging import getLogger +from pathlib import Path +from types import MethodType +from typing import Any, Dict, List, Sequence, Tuple, Union + +import torch + +from nemo.utils.export_utils import add_casts_around_norms, replace_for_export +from nemo.utils.import_utils import safe_import + +polygraphy, polygraphy_imported = safe_import("polygraphy") +if polygraphy_imported: + from polygraphy.backend.common import bytes_from_path + from polygraphy.backend.trt import ( + CreateConfig, + Profile, + engine_bytes_from_network, + engine_from_bytes, + network_from_onnx_path, + ) + +trt, trt_imported = safe_import("tensorrt") +torch_tensorrt, _ = safe_import("torch_tensorrt") +cudart, _ = safe_import("cuda.cudart") + +lock_sm = threading.Lock() + + +def trt_to_torch_dtype_dict(): + """ + Map of TRT dtype -> Torch dtype + """ + return { + trt.int32: torch.int32, + trt.float32: torch.float32, + trt.float16: torch.float16, + trt.bfloat16: torch.float16, + trt.int64: torch.int64, + trt.int8: torch.int8, + trt.bool: torch.bool, + } + + +def get_profile_shapes(input_shape: Sequence[int], dynamic_batchsize: Sequence[int] | None): + """ + Given a sample input shape, calculate min/opt/max shapes according to dynamic_batchsize. + """ + + def scale_batch_size(input_shape: Sequence[int], scale_num: int): + scale_shape = [*input_shape] + scale_shape[0] = scale_num + return scale_shape + + # Use the dynamic batchsize range to generate the min, opt and max model input shape + if dynamic_batchsize: + min_input_shape = scale_batch_size(input_shape, dynamic_batchsize[0]) + opt_input_shape = scale_batch_size(input_shape, dynamic_batchsize[1]) + max_input_shape = scale_batch_size(input_shape, dynamic_batchsize[2]) + else: + min_input_shape = opt_input_shape = max_input_shape = input_shape + return min_input_shape, opt_input_shape, max_input_shape + + +def get_dynamic_axes(profiles): + """ + This method calculates dynamic_axes to use in onnx.export(). + Args: + profiles: [[min,opt,max],...] list of profile dimensions + """ + dynamic_axes: dict[str, list[int]] = {} + if not profiles: + return dynamic_axes + for profile in profiles: + for key in profile: + axes = [] + vals = profile[key] + for i in range(len(vals[0])): + if vals[0][i] != vals[2][i]: + axes.append(i) + if len(axes) > 0: + dynamic_axes[key] = axes + return dynamic_axes + + +def cuassert(cuda_ret): + """ + Error reporting method for CUDA calls. + Args: + cuda_ret: CUDA return code. + """ + err = cuda_ret[0] + if err != 0: + raise RuntimeError(f"CUDA ERROR: {err}") + if len(cuda_ret) > 1: + return cuda_ret[1] + return None + + +class ShapeError(Exception): + """ + Exception class to report errors from setting TRT plan input shapes + """ + + pass + + +class TRTEngine: + """ + An auxiliary class to implement running of TRT optimized engines + + """ + + def __init__(self, plan_path, logger=None): + """ + Loads serialized engine, creates execution context and activates it + Args: + plan_path: path to serialized TRT engine. 
+ logger: optional logger object + """ + self.plan_path = plan_path + self.logger = logger or getLogger("trt_compile") + self.logger.info(f"Loading TensorRT engine: {self.plan_path}") + self.engine = engine_from_bytes(bytes_from_path(self.plan_path)) + self.tensors = OrderedDict() + self.cuda_graph_instance = None # cuda graph + self.context = self.engine.create_execution_context() + self.input_names = [] + self.output_names = [] + self.dtypes = [] + self.cur_profile = 0 + self.input_table = {} + dtype_dict = trt_to_torch_dtype_dict() + for idx in range(self.engine.num_io_tensors): + binding = self.engine[idx] + if self.engine.get_tensor_mode(binding) == trt.TensorIOMode.INPUT: + self.input_names.append(binding) + elif self.engine.get_tensor_mode(binding) == trt.TensorIOMode.OUTPUT: + self.output_names.append(binding) + dtype = dtype_dict[self.engine.get_tensor_dtype(binding)] + self.dtypes.append(dtype) + self.logger.info( + f"Loaded TensorRT engine: {self.plan_path}.\nInputs: {self.input_names}\nOutputs: {self.output_names}" + ) + + def allocate_buffers(self, device): + """ + Allocates outputs to run TRT engine + Args: + device: GPU device to allocate memory on + """ + ctx = self.context + + for i, binding in enumerate(self.output_names): + shape = list(ctx.get_tensor_shape(binding)) + if binding not in self.tensors or list(self.tensors[binding].shape) != shape: + t = torch.empty(shape, dtype=self.dtypes[i], device=device).contiguous() + self.tensors[binding] = t + ctx.set_tensor_address(binding, t.data_ptr()) + + def set_inputs(self, feed_dict, stream): + """ + Sets input bindings for TRT engine according to feed_dict + Args: + feed_dict: a dictionary [str->Tensor] + stream: CUDA stream to use + """ + e = self.engine + ctx = self.context + + last_profile = self.cur_profile + + def try_set_inputs(): + for binding in self.input_names: + t = feed_dict.get(self.input_table[binding], None) + if t is not None: + t = t.contiguous() + shape = t.shape + ctx.set_input_shape(binding, shape) + ctx.set_tensor_address(binding, t.data_ptr()) + + while True: + try: + try_set_inputs() + break + except ShapeError: + next_profile = (self.cur_profile + 1) % e.num_optimization_profiles + if next_profile == last_profile: + raise + self.cur_profile = next_profile + ctx.set_optimization_profile_async(self.cur_profile, stream) + except Exception: + raise + left = ctx.infer_shapes() + assert len(left) == 0 + + def infer(self, stream, use_cuda_graph=False): + """ + Runs TRT engine. + Args: + stream: CUDA stream to run on + use_cuda_graph: use CUDA graph. Note: requires all inputs to be the same GPU memory between calls. 
+ """ + if use_cuda_graph: + if self.cuda_graph_instance is not None: + cuassert(cudart.cudaGraphLaunch(self.cuda_graph_instance, stream)) + cuassert(cudart.cudaStreamSynchronize(stream)) + else: + # do inference before CUDA graph capture + noerror = self.context.execute_async_v3(stream) + if not noerror: + raise ValueError("ERROR: inference failed.") + # capture cuda graph + cuassert( + cudart.cudaStreamBeginCapture( + stream, cudart.cudaStreamCaptureMode.cudaStreamCaptureModeThreadLocal + ) + ) + self.context.execute_async_v3(stream) + graph = cuassert(cudart.cudaStreamEndCapture(stream)) + self.cuda_graph_instance = cuassert(cudart.cudaGraphInstantiate(graph, 0)) + self.logger.info("CUDA Graph captured!") + else: + noerror = self.context.execute_async_v3(stream) + cuassert(cudart.cudaStreamSynchronize(stream)) + if not noerror: + raise ValueError("ERROR: inference failed.") + + return self.tensors + + +def make_tensor(d): + """ + Creates a new tensor from d, returns d if d is already a tensor + """ + return d if isinstance(d, torch.Tensor) else torch.tensor(d).cuda() + + +def unroll_input(input_names, input_example): + """ + Simulates list/tuple unrolling during ONNX export + """ + unrolled_input = {} + for name in input_names: + val = input_example[name] + if val is not None: + if isinstance(val, list) or isinstance(val, tuple): + for i in range(len(val)): + unrolled_input[f"{name}_{i}"] = make_tensor(val[i]) + else: + unrolled_input[name] = make_tensor(val) + return unrolled_input + + +def parse_groups( + ret: List[torch.Tensor], output_lists: List[List[int]] +) -> Tuple[Union[torch.Tensor, List[torch.Tensor]], ...]: + """ + Implements parsing of 'output_lists' arg of trt_compile(). + + Args: + ret: plain list of Tensors + + output_lists: list of output group sizes: to form some Lists/Tuples out of 'ret' List, this will be a list + of group dimensions, like [[], [5], [-1]] for returning Tensor, list of 5 items and dynamic list. + Format: [[group_n] | [], ...] + [] or group_n == 0 : next output from ret is a scalar + group_n > 0 : next output from ret is a list of group_n length + group_n == -1: next output is a dynamic list. This entry can be at any + position in output_lists, but can appear only once. + Returns: + Tuple of Union[torch.Tensor, List[torch.Tensor]], according to the grouping in output_lists + + """ + groups: Tuple[Union[torch.Tensor, List[torch.Tensor]], ...] = tuple() + cur = 0 + for l in range(len(output_lists)): + gl = output_lists[l] + assert len(gl) == 0 or len(gl) == 1 + if len(gl) == 0 or gl[0] == 0: + groups = (*groups, ret[cur]) + cur = cur + 1 + elif gl[0] > 0: + groups = (*groups, ret[cur : cur + gl[0]]) + cur = cur + gl[0] + elif gl[0] == -1: + rev_groups: Tuple[Union[torch.Tensor, List[torch.Tensor]], ...] 
= tuple() + rcur = len(ret) + for rl in range(len(output_lists) - 1, l, -1): + rgl = output_lists[rl] + assert len(rgl) == 0 or len(rgl) == 1 + if len(rgl) == 0 or rgl[0] == 0: + rcur = rcur - 1 + rev_groups = (*rev_groups, ret[rcur]) + elif rgl[0] > 0: + rcur = rcur - rgl[0] + rev_groups = (*rev_groups, ret[rcur : rcur + rgl[0]]) + else: + raise ValueError("Two -1 lists in output") + groups = (*groups, ret[cur:rcur], *rev_groups[::-1]) + break + return groups + + +class TrtCompiler: + """ + This class implements: + - TRT lazy persistent export + - Running TRT with optional fallback to Torch + (for TRT engines with limited profiles) + """ + + def __init__( + self, + model, + plan_path, + precision="fp16", + method="onnx", + input_names=None, + output_names=None, + output_lists=None, + export_args=None, + build_args=None, + input_profiles=None, + dynamic_batchsize=None, + use_cuda_graph=False, + timestamp=None, + fallback=False, + forward_override=None, + logger=None, + ): + """ + Initialization method: + Tries to load persistent serialized TRT engine + Saves its arguments for lazy TRT build on first forward() call + Args: + model: Model to "wrap". + plan_path : Path where to save persistent serialized TRT engine. + precision: TRT builder precision o engine model. Should be 'fp32'|'tf32'|'fp16'|'bf16'. + method: One of 'onnx'|'torch_trt'. + Default is 'onnx' (torch.onnx.export()->TRT). This is the most stable and efficient option. + 'torch_trt' may not work for some nets. Also AMP must be turned off for it to work. + input_names: Optional list of input names. If None, will be read from the function signature. + output_names: Optional list of output names. Note: If not None, patched forward() will return a dictionary. + output_lists: Optional list of output group sizes: when forward() returns Lists/Tuples, this will be a list + of their dimensions, like [[], [5], [-1]] for Tensor, list of 5 items and dynamic list. + export_args: Optional args to pass to export method. See onnx.export() and Torch-TensorRT docs for details. + build_args: Optional args to pass to TRT builder. See polygraphy.Config for details. + input_profiles: Optional list of profiles for TRT builder and ONNX export. + Each profile is a map of the form : {"input id" : [min_shape, opt_shape, max_shape], ...}. + dynamic_batchsize: A sequence with three elements to define the input batch size range for the model to be + converted. Should be a sequence like [MIN_BATCH, OPT_BATCH, MAX_BATCH]. + [note]: If neither input_profiles nor dynamic_batchsize specified, static shapes will be used. + use_cuda_graph: Use CUDA Graph for inference. Note: inputs have to be the same GPU memory between calls! + timestamp: Optional timestamp to rebuild TRT engine (e.g. if config file changes). + fallback: Allow to fall back to Pytorch when TRT inference fails (e.g, shapes exceed max profile). 
+ """ + + method_vals = ["onnx", "torch_trt"] + if method not in method_vals: + raise ValueError(f"trt_compile(): 'method' should be one of {method_vals}, got: {method}.") + precision_vals = ["fp32", "tf32", "fp16", "bf16"] + if precision not in precision_vals: + raise ValueError(f"trt_compile(): 'precision' should be one of {precision_vals}, got: {precision}.") + + self.plan_path = plan_path + self.precision = precision + self.method = method + self.return_dict = output_names is not None + self.output_names = output_names or [] + self.output_lists = output_lists or [] + self.profiles = input_profiles or [] + self.dynamic_batchsize = dynamic_batchsize + self.export_args = export_args or {} + self.build_args = build_args or {} + self.engine: TRTEngine | None = None + self.use_cuda_graph = use_cuda_graph + self.fallback = fallback + self.disabled = False + + self.logger = logger or getLogger("trt_compile") + self.argspec = inspect.getfullargspec(model.forward) + # Normally we read input_names from forward() but can be overridden + if input_names is None: + input_names = self.argspec.args[1:] + self.defaults = {} + if self.argspec.defaults is not None: + for i in range(len(self.argspec.defaults)): + d = self.argspec.defaults[-i - 1] + if d is not None: + d = make_tensor(d) + self.defaults[self.argspec.args[-i - 1]] = d + + self.input_names = input_names + self.old_forward = model.forward + + # Force engine rebuild if older than the timestamp + if timestamp is not None and os.path.exists(self.plan_path) and os.path.getmtime(self.plan_path) < timestamp: + os.remove(self.plan_path) + + def _inputs_to_dict(self, input_example): + trt_inputs = {} + for i, inp in enumerate(input_example): + input_name = self.input_names[i] + trt_inputs[input_name] = inp + return trt_inputs + + def _load_engine(self): + """ + Loads TRT plan from disk and activates its execution context. + """ + try: + self.engine = TRTEngine(self.plan_path, self.logger) + # Make sure we have names correct + input_table = {} + for name in self.engine.input_names: + if name.startswith("__") and name not in self.input_names: + orig_name = name[2:] + else: + orig_name = name + input_table[name] = orig_name + self.engine.input_table = input_table + self.logger.info(f"Engine loaded, inputs:{self.engine.input_table}") + except Exception as e: + self.logger.info(f"Exception while loading the engine:\n{e}") + + def forward(self, model, argv, kwargs): + """ + Main forward method: + Builds TRT engine if not available yet. 
+ Tries to run TRT engine + If exception thrown and self.callback==True: falls back to original Pytorch + + Args: Passing through whatever args wrapped module's forward() has + Returns: Passing through wrapped module's forward() return value(s) + + """ + args = self.defaults + args.update(kwargs) + if len(argv) > 0: + args.update(self._inputs_to_dict(argv)) + + if self.engine is None and not self.disabled: + # Restore original forward for export + new_forward = model.forward + model.forward = self.old_forward + try: + self._load_engine() + if self.engine is None: + build_args = args.copy() + with torch.no_grad(): + self._build_and_save(model, build_args) + # This will reassign input_names from the engine + self._load_engine() + assert self.engine is not None + except Exception as e: + if self.fallback: + self.logger.info(f"Failed to build engine: {e}") + self.disabled = True + else: + raise e + if not self.disabled and not self.fallback: + # Delete all parameters + for param in model.parameters(): + del param + # Call empty_cache to release GPU memory + torch.cuda.empty_cache() + # restore TRT hook + model.forward = new_forward + # Run the engine + try: + if self.engine is not None: + # forward_trt is not thread safe as we do not use per-thread execution contexts + with lock_sm: + device = torch.cuda.current_device() + stream = torch.cuda.Stream(device=device) + self.engine.set_inputs(unroll_input(self.input_names, args), stream.cuda_stream) + self.engine.allocate_buffers(device=device) + # Need this to synchronize with Torch stream + stream.wait_stream(torch.cuda.current_stream()) + ret = self.engine.infer(stream.cuda_stream, use_cuda_graph=self.use_cuda_graph) + # if output_names is not None, return dictionary + if not self.return_dict: + ret = list(ret.values()) + if self.output_lists: + ret = parse_groups(ret, self.output_lists) + elif len(ret) == 1: + ret = ret[0] + return ret + except Exception as e: + if self.fallback: + self.logger.info(f"Exception: {e}\nFalling back to Pytorch ...") + else: + raise e + return self.old_forward(*argv, **kwargs) + + def _onnx_to_trt(self, onnx_path): + """ + Builds TRT engine from ONNX file at onnx_path and saves to self.plan_path + """ + + profiles = [] + for profile in self.profiles: + p = Profile() + for id, val in profile.items(): + p.add(id, min=val[0], opt=val[1], max=val[2]) + profiles.append(p) + + build_args = self.build_args.copy() + build_args["tf32"] = self.precision != "fp32" + if self.precision == "fp16": + build_args["fp16"] = True + elif self.precision == "bf16": + build_args["bf16"] = True + + self.logger.info(f"Building TensorRT engine for {onnx_path}: {self.plan_path}") + network = network_from_onnx_path(onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM]) + return engine_bytes_from_network(network, config=CreateConfig(profiles=profiles, **build_args)) + + def _build_and_save(self, model, input_example): + """ + If TRT engine is not ready, exports model to ONNX, + builds TRT engine and saves serialized TRT engine to the disk. 
+ Args: + input_example: passed to onnx.export() + """ + + if self.engine is not None: + return + + export_args = self.export_args + engine_bytes = None + + add_casts_around_norms(model) + replace_for_export(model) + + if self.method == "torch_trt": + enabled_precisions = [torch.float32] + if self.precision == "fp16": + enabled_precisions.append(torch.float16) + elif self.precision == "bf16": + enabled_precisions.append(torch.bfloat16) + inputs = list(input_example.values()) + + def get_torch_trt_input(input_shape, dynamic_batchsize): + min_input_shape, opt_input_shape, max_input_shape = get_profile_shapes(input_shape, dynamic_batchsize) + return torch_tensorrt.Input( + min_shape=min_input_shape, opt_shape=opt_input_shape, max_shape=max_input_shape + ) + + tt_inputs = [get_torch_trt_input(i.shape, self.dynamic_batchsize) for i in inputs] + engine_bytes = torch_tensorrt.convert_method_to_trt_engine( + model, + "forward", + arg_inputs=tt_inputs, + enabled_precisions=enabled_precisions, + **export_args, + ) + else: + dbs = self.dynamic_batchsize + if dbs: + if len(self.profiles) > 0: + raise ValueError("ERROR: Both dynamic_batchsize and input_profiles set for TrtCompiler!") + if len(dbs) != 3: + raise ValueError("dynamic_batchsize has to have len ==3 ") + profile = {} + for id, val in input_example.items(): + + def add_profile(id, val): + sh = val.shape + if len(sh) > 0: + sh = sh[1:] + profile[id] = [[dbs[0], *sh], [dbs[1], *sh], [dbs[2], *sh]] + + if isinstance(val, list) or isinstance(val, tuple): + for i in range(len(val)): + add_profile(f"{id}_{i}", val[i]) + elif isinstance(val, torch.Tensor): + add_profile(id, val) + self.profiles = [profile] + + self.dynamic_axes = get_dynamic_axes(self.profiles) + + if len(self.dynamic_axes) > 0: + export_args.update({"dynamic_axes": self.dynamic_axes}) + + # Use temporary directory for easy cleanup in case of external weights + with tempfile.TemporaryDirectory() as tmpdir: + if export_args.get("dynamo", False): + input_names = None + else: + input_names = list(unroll_input(self.input_names, input_example).keys()) + onnx_path = str(Path(tmpdir) / "model.onnx") + self.logger.info( + f"Exporting to {onnx_path}:\n" + + f"output_names={self.output_names}\ninput_names={self.input_names}\nexport args: {export_args}" + ) + torch.onnx.export( + model, + (input_example,), + onnx_path, + input_names=input_names, + output_names=self.output_names, + **export_args, + ) + if polygraphy_imported: + from polygraphy.backend.onnx.loader import fold_constants, onnx_from_path, save_onnx + + onnx_model = fold_constants(onnx_from_path(onnx_path), size_threshold=16 * 1000 * 1000) + save_onnx(onnx_model, onnx_path) + self.logger.info("Export to ONNX successful.") + engine_bytes = self._onnx_to_trt(onnx_path) + if engine_bytes: + open(self.plan_path, "wb").write(engine_bytes) + + +def trt_forward(self, *argv, **kwargs): + """ + Patch function to replace original model's forward() with. + Redirects to TrtCompiler.forward() + """ + return self._trt_compiler.forward(self, argv, kwargs) + + +def trt_compile( + model: torch.nn.Module, + base_path: str, + args: Dict[str, Any] | None = None, + submodule: Union[str, List[str]] | None = None, + logger: Any | None = None, +) -> torch.nn.Module: + """ + Instruments model or submodule(s) with TrtCompiler and replaces its forward() with TRT hook. + Note: TRT 10.3 is recommended for best performance. Some nets may even fail to work with TRT 8.x + Args: + model: module to patch with TrtCompiler object. 
+ base_path: TRT plan(s) saved to f"{base_path}[.{submodule}].plan" path. + dirname(base_path) must exist, base_path does not have to. + If base_path does point to existing file (e.g. associated checkpoint), + that file becomes a dependency - its mtime is added to args["timestamp"]. + args: Optional dict : unpacked and passed to TrtCompiler() - see TrtCompiler above for details. + submodule: Optional hierarchical id(s) of submodule to patch, e.g. ['image_decoder.decoder'] + If None, TrtCompiler patch is applied to the whole model. + Otherwise, submodule (or list of) is being patched. + logger: Optional logger for diagnostics. + Returns: + Always returns same model passed in as argument. This is for ease of use in configs. + """ + + default_args: Dict[str, Any] = { + "method": "onnx", + "precision": "fp16", + "build_args": {"builder_optimization_level": 5, "precision_constraints": "obey"}, + } + + default_args.update(args or {}) + args = default_args + + if trt_imported and polygraphy_imported and torch.cuda.is_available(): + # if "path" filename point to existing file (e.g. checkpoint) + # it's also treated as dependency + if os.path.exists(base_path): + timestamp = int(os.path.getmtime(base_path)) + if "timestamp" in args: + timestamp = max(int(args["timestamp"]), timestamp) + args["timestamp"] = timestamp + + def wrap(model, path): + if not hasattr(model, "_trt_compiler"): + model.orig_forward = model.forward + wrapper = TrtCompiler(model, path + ".plan", logger=logger, **args) + model._trt_compiler = wrapper + model.forward = MethodType(trt_forward, model) + + def find_sub(parent, submodule): + idx = submodule.find(".") + # if there is "." in name, call recursively + if idx != -1: + parent_name = submodule[:idx] + parent = getattr(parent, parent_name) + submodule = submodule[idx + 1 :] + return find_sub(parent, submodule) + return parent, submodule + + if submodule is not None: + if isinstance(submodule, str): + submodule = [submodule] + for s in submodule: + parent, sub = find_sub(model, s) + wrap(getattr(parent, sub), base_path + "." + s) + else: + wrap(model, base_path) + else: + logger = logger or getLogger("trt_compile") + logger.warning("TensorRT and/or polygraphy packages are not available! 
trt_compile() has no effect.") + + return model diff --git a/nemo/lightning/_strategy_lib.py b/nemo/lightning/_strategy_lib.py index 182454012d795..1fb4b4e0a757d 100644 --- a/nemo/lightning/_strategy_lib.py +++ b/nemo/lightning/_strategy_lib.py @@ -515,6 +515,17 @@ def get_safe(param_id): def load_model_state_dict(megatron_parallel, checkpoint: Mapping[str, Any], strict: bool = True) -> None: from megatron.core import parallel_state + from megatron.core.dist_checkpointing.validation import StrictHandling, parse_strict_flag + + ## convert from StrictHandling to bool for PTL + if strict is not None and not isinstance(strict, bool): + strict = parse_strict_flag(strict) + strict_options = [ + StrictHandling.ASSUME_OK_UNEXPECTED, + StrictHandling.RAISE_UNEXPECTED, + StrictHandling.RAISE_ALL, + ] + strict = strict in strict_options for index, module in enumerate(megatron_parallel): if parallel_state.get_virtual_pipeline_model_parallel_world_size() is not None: diff --git a/nemo/lightning/io/pl.py b/nemo/lightning/io/pl.py index 788697887e390..f2c70034fd503 100644 --- a/nemo/lightning/io/pl.py +++ b/nemo/lightning/io/pl.py @@ -155,9 +155,6 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_optio checkpoint_dir = ckpt_to_weights_subdir(path, is_saving=True) fs = get_filesystem(checkpoint_dir) - if fs.isdir(checkpoint_dir) and dist_checkpointing.check_is_distributed_checkpoint(checkpoint_dir): - logging.info(f'Distributed checkpoint at path {checkpoint_dir} already exists, skipping saving') - return fs.makedirs(checkpoint_dir, exist_ok=True) validate_sharding_integrity = not (self.validated_consistency and self.assume_constant_structure) @@ -173,7 +170,11 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_optio @override def load_checkpoint( - self, path: _PATH, sharded_state_dict=None, map_location: Optional[Callable] = None + self, + path: _PATH, + sharded_state_dict=None, + map_location: Optional[Callable] = None, + strict: Optional['StrictHandling'] | bool = None, ) -> Dict[str, Any]: """Loads checkpoint using :func:`torch.load`, with additional handling for ``fsspec`` remote loading of files. @@ -190,6 +191,7 @@ def load_checkpoint( """ from megatron.core import dist_checkpointing + from megatron.core.dist_checkpointing.validation import StrictHandling if map_location is not None: raise ValueError("`map_location` argument is not supported for `MegatronCheckpointIO.load_checkpoint`.") @@ -223,8 +225,21 @@ def load_checkpoint( if sharded_strategy is not None: logging.info(f'Using {sharded_strategy} dist-ckpt load strategy.') + if isinstance(strict, bool): + # For backward-compatibility reasons and a bug in MCore (strict check not applied to factories) + # we must apply a simple strict check here. 
+ if not strict: + sharded_state_dict = self.adjust_non_strict_load(path, sharded_state_dict) + strict = StrictHandling.ASSUME_OK_UNEXPECTED if strict else StrictHandling.LOG_ALL + if strict is None: + # Default behavior + strict = StrictHandling.ASSUME_OK_UNEXPECTED + checkpoint = dist_checkpointing.load( - sharded_state_dict=sharded_state_dict, checkpoint_dir=str(path), sharded_strategy=sharded_strategy + sharded_state_dict=sharded_state_dict, + checkpoint_dir=str(path), + sharded_strategy=sharded_strategy, + strict=strict, ) checkpoint = _fix_tensors_device(checkpoint) @@ -287,6 +302,34 @@ def save_sharded_strategy(self) -> 'SaveShardedStrategy': self._save_sharded_strategy = self._determine_dist_ckpt_save_strategy() return self._save_sharded_strategy + def adjust_non_strict_load(self, path: _PATH, sharded_state_dict: Dict[str, Any]): + from megatron.core import dist_checkpointing + from megatron.core.dist_checkpointing.dict_utils import extract_matching_values + from megatron.core.dist_checkpointing.mapping import ShardedBase + + ckpt_sharded_metadata = dist_checkpointing.load_tensors_metadata(path) + loaded_keys = [] + missing_keys = [] + unexpected_keys = [] + + def should_remove_missing_sharded_base(x: Any): + if isinstance(x, ShardedBase): + if x.key in ckpt_sharded_metadata: + loaded_keys.append(x.key) + return False + else: + unexpected_keys.append(x.key) + return True + return False + + _, sharded_state_dict = extract_matching_values(sharded_state_dict, should_remove_missing_sharded_base) + logging.info(f'The following keys are not in the checkpoint and will not be loaded: {unexpected_keys}') + + # TODO: compute missing_keys by: + # 1. all_gather_object of loaded_keys + # 2. missing_keys = ckpt_sharded_metadata.keys() - loaded_keys + return sharded_state_dict + def _fix_tensors_device(ckpt: Dict) -> Dict: """Ensure checkpoint tensors are on the correct device.""" diff --git a/nemo/lightning/pytorch/callbacks/nsys.py b/nemo/lightning/pytorch/callbacks/nsys.py index f350eae407305..2a6bc3668b946 100644 --- a/nemo/lightning/pytorch/callbacks/nsys.py +++ b/nemo/lightning/pytorch/callbacks/nsys.py @@ -74,10 +74,14 @@ def on_train_batch_start(self, trainer, pl_module, batch, batch_idx: int) -> Opt """ device = trainer.strategy.root_device - current_step = trainer.strategy.current_epoch_step + try: + # Not all strategies have this. 
e.g.: + # AttributeError: 'SingleDeviceStrategy' object has no attribute 'current_epoch_step' + current_step = trainer.strategy.current_epoch_step + except AttributeError: + current_step = self._nsys_profile_start_step if device.type == 'cuda': if current_step == self._nsys_profile_start_step and get_rank() in self._nsys_profile_ranks: - logging.info("====== Start nsys profiling ======") torch.cuda.cudart().cudaProfilerStart() if self._nsys_profile_gen_shape: torch.autograd.profiler.emit_nvtx(record_shapes=True).__enter__() @@ -91,9 +95,11 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx: int) """ device = trainer.strategy.root_device - current_step = trainer.strategy.current_epoch_step + try: + current_step = trainer.strategy.current_epoch_step + except AttributeError: + current_step = self._nsys_profile_end_step if device.type == 'cuda': if current_step == self._nsys_profile_end_step and get_rank() in self._nsys_profile_ranks: - logging.info("====== End nsys profiling ======") torch.cuda.cudart().cudaProfilerStop() torch.autograd.profiler.emit_nvtx().__exit__(None, None, None) diff --git a/nemo/lightning/pytorch/callbacks/peft.py b/nemo/lightning/pytorch/callbacks/peft.py index fb846043c8aa6..c94e1f8e003e8 100644 --- a/nemo/lightning/pytorch/callbacks/peft.py +++ b/nemo/lightning/pytorch/callbacks/peft.py @@ -405,7 +405,11 @@ def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_optio @override def load_checkpoint( - self, path: _PATH, sharded_state_dict=None, map_location: Optional[Callable] = None + self, + path: _PATH, + sharded_state_dict=None, + map_location: Optional[Callable] = None, + strict: Optional['StrictHandling'] | bool = None, ) -> Dict[str, Any]: """ ===================== @@ -452,7 +456,7 @@ def load_checkpoint( self.model_ckpt_path = path # Note: this will include the Trainer-state of the model-checkpoint - model_ckpt = self.checkpoint_io.load_checkpoint(path, sharded_state_dict, map_location) + model_ckpt = self.checkpoint_io.load_checkpoint(path, sharded_state_dict, map_location, strict) if adapter_ckpt is not None: ## PEFT Resume, FIRST TIME adapter_ckpt['state_dict'].update(model_ckpt['state_dict']) diff --git a/nemo/lightning/pytorch/strategies/fsdp_strategy.py b/nemo/lightning/pytorch/strategies/fsdp_strategy.py index 4c5a165c2d8df..05d0f3d629f37 100644 --- a/nemo/lightning/pytorch/strategies/fsdp_strategy.py +++ b/nemo/lightning/pytorch/strategies/fsdp_strategy.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os import shutil from collections import OrderedDict from pathlib import Path @@ -212,17 +213,17 @@ def save_checkpoint( checkpoint["sharded_state_dict"] = pyt_to_mcore_state_dict(checkpoint.pop("state_dict")) checkpoint["state_dict"] = OrderedDict([]) - ## replace unsharded optimizer_states with sharded dict. - ## note that if trainer.save_checkpoint(path, save_weights_only=True) is called, - ## the checkpoint will contain only model weights. Optimizer states will be omitted. - if ( - "optimizer_states" in checkpoint - and self.trainer.state.fn == TrainerFn.FITTING - and self.ckpt_save_optimizer - ): + if "optimizer_states" in checkpoint and self.trainer.state.fn == TrainerFn.FITTING: + # Clear the optimizer states. 
This handles the case where ckpt_save_optimizer=False + # Ideally, the optimizer state dicts should not be generated in this case checkpoint["optimizer_states"] = {} - checkpoint['optimizer'] = get_optimizer_state_dict(self.model, self.optimizers) - pyt_to_mcore_state_dict(checkpoint['optimizer']['state'], prefix="optimizer.state.") + + ## replace unsharded optimizer_states with sharded dict. + ## note that if trainer.save_checkpoint(path, save_weights_only=True) is called, + ## the checkpoint will contain only model weights. Optimizer states will be omitted. + if self.ckpt_save_optimizer: + checkpoint['optimizer'] = get_optimizer_state_dict(self.model, self.optimizers) + pyt_to_mcore_state_dict(checkpoint['optimizer']['state'], prefix="optimizer.state.") self.checkpoint_io.save_checkpoint(checkpoint, filepath, storage_options=storage_options) diff --git a/nemo/lightning/pytorch/strategies/megatron_strategy.py b/nemo/lightning/pytorch/strategies/megatron_strategy.py index 2b9caf24bce63..b74677b01b09b 100644 --- a/nemo/lightning/pytorch/strategies/megatron_strategy.py +++ b/nemo/lightning/pytorch/strategies/megatron_strategy.py @@ -156,6 +156,9 @@ class MegatronStrategy(DDPStrategy, io.IOMixin): ckpt_load_directly_on_device (bool): if True, loads the weights directly on GPU. Has effect only for `zarr` based checkpoints (PyT Distributed always loads on device). Defaults to True. + ckpt_load_strictness (StrictHandling, optional): defines loading strictness. + If not None, overwrites the `strict` flag passed to `load_checkpoint`. + Defaults to None. setup_optimizers (bool): Whether to call the trainer's setup_optimizers function to perform any necessary conversions of optimizer parameters and move optimizer parameters to the correct device. Defaults to True. @@ -204,6 +207,7 @@ def __init__( ckpt_parallel_load: bool = True, ckpt_parallel_save_optim: bool = True, ckpt_load_directly_on_device: bool = True, + ckpt_load_strictness: Optional['StrictHandling'] = None, setup_optimizers: bool = True, init_model_parallel: bool = True, replace_progress_bar: bool = True, @@ -238,6 +242,7 @@ def __init__( self.lazy_init = lazy_init self.ckpt_load_optimizer = ckpt_load_optimizer self.ckpt_save_optimizer = ckpt_save_optimizer + self.ckpt_load_strictness = ckpt_load_strictness self.pipeline_dtype = pipeline_dtype self._setup_optimizers = setup_optimizers self._init_model_parallel = init_model_parallel @@ -278,7 +283,7 @@ def connect(self, model: pl.LightningModule) -> None: """Attaches a model to strategy.""" super().connect(model) - assert not 'is_hf_model' in model.__dict__, "Cannot use HfAutoModelForCausalLM with MegatronParallel" + assert not 'is_hf_model' in model.__dict__, "Cannot use HFAutoModelForCausalLM with MegatronParallel" dtype_config = getattr(self._precision_plugin, "dtype_config", None) if self.pipeline_dtype is None and dtype_config: @@ -693,16 +698,16 @@ def save_checkpoint( if "sharded_state_dict" not in checkpoint: checkpoint["sharded_state_dict"] = self.megatron_parallel.sharded_state_dict() - ## replace unsharded optimizer_states with sharded dict. - ## note that if trainer.save_checkpoint(path, save_weights_only=True) is called, - ## the checkpoint will contain only model weights. Optimizer states will be omitted. - if ( - "optimizer_states" in checkpoint - and self.trainer.state.fn == TrainerFn.FITTING - and self.ckpt_save_optimizer - ): + if "optimizer_states" in checkpoint and self.trainer.state.fn == TrainerFn.FITTING: + # Clear the optimizer states. 
This handles the case where ckpt_save_optimizer=False + # Ideally, the optimizer state dicts should not be generated in this case checkpoint["optimizer_states"] = {} - checkpoint["optimizer"] = [self.optimizer_sharded_state_dict()] + + ## replace unsharded optimizer_states with sharded dict. + ## note that if trainer.save_checkpoint(path, save_weights_only=True) is called, + ## the checkpoint will contain only model weights. Optimizer states will be omitted. + if self.ckpt_save_optimizer: + checkpoint["optimizer"] = [self.optimizer_sharded_state_dict()] self.checkpoint_io.save_checkpoint(checkpoint, filepath, storage_options=storage_options) @@ -733,7 +738,12 @@ def load_checkpoint(self, checkpoint_path: Union[str, Path], selective_restore: if self.lightning_module.optimizers(use_pl_optimizer=False): sharded_state_dict["optimizer"] = [self.optimizer_sharded_state_dict(is_loading=True)] - checkpoint = self.checkpoint_io.load_checkpoint(checkpoint_path, sharded_state_dict=sharded_state_dict) + strict = ( + self.lightning_module.strict_loading if self.ckpt_load_strictness is None else self.ckpt_load_strictness + ) + checkpoint = self.checkpoint_io.load_checkpoint( + checkpoint_path, sharded_state_dict=sharded_state_dict, strict=strict + ) if selective_restore: final_checkpoint = {} @@ -755,7 +765,8 @@ def selective_restore(self) -> None: if self.restore_config.load_model_state: logging.info(f"Restoring model weights from {self.restore_config}") - self.load_model_state_dict(checkpoint=checkpoint) + strict = True if self.ckpt_load_strictness is None else self.ckpt_load_strictness + self.load_model_state_dict(checkpoint=checkpoint, strict=strict) if self.restore_config.load_optim_state: logging.info(f"Restoring optimizer states from {self.restore_config}") @@ -790,6 +801,7 @@ def load_model_state_dict(self, checkpoint: Mapping[str, Any], strict: bool = Tr """loads model state dict""" assert self.megatron_parallel is not None + strict = strict if self.ckpt_load_strictness is None else self.ckpt_load_strictness _strategy_lib.load_model_state_dict(self.megatron_parallel, checkpoint, strict=strict) if not 'optimizer' in checkpoint: diff --git a/nemo/lightning/resume.py b/nemo/lightning/resume.py index 7b534646731c1..6d6ddda1fd80d 100644 --- a/nemo/lightning/resume.py +++ b/nemo/lightning/resume.py @@ -37,17 +37,25 @@ def _try_restore_tokenizer(model, ckpt_path): + from nemo.collections.common.tokenizers import TokenizerSpec from nemo.lightning.io import load_context try: tokenizer = load_context(ckpt_path, "model.tokenizer") + except ValueError as e: + logging.warning( + f"Encountered error while trying to restore tokenizer. Tokenizer is not restored. " f"Original error: {e}" + ) + return model + + if isinstance(tokenizer, TokenizerSpec): model.tokenizer = tokenizer model.__io__.tokenizer = tokenizer.__io__ - except: - # Ignore if the ckpt doesn't have a tokenizer. - pass - finally: - return model + else: + # Ignore if the ckpt doesn't have a tokenizer. type(tokenizer)==TrainerContext in this case. + logging.warning("Checkpoint does not have model.tokenizer field. Tokenizer is not restored.") + + return model @dataclass(kw_only=True) @@ -56,8 +64,10 @@ class AutoResume: checkpoints in NeMo. Attributes: - restore_config (Optional[RestoreConfig]): Optional config for selectively restoring specific parts like model weights, optimizer states, etc. - If the config contains a path from HF or another non-NeMo checkpoint format, the checkpoint will be automatically converted to a NeMo compatible format. 
+ restore_config (Optional[RestoreConfig]): Optional config for selectively restoring specific parts like model + weights, optimizer states, etc. + If the config contains a path from HF or another non-NeMo checkpoint format, the checkpoint will be + automatically converted to a NeMo compatible format. resume_from_folder or the run's log_dir takes precedence over restore_config. resume_from_directory (str): Path to the checkpointing directory to restore from. resume_from_path (str): Path to a specific checkpoint to restore from. @@ -209,17 +219,22 @@ def _find_trainer_ckpt_path(self) -> Optional[Path]: if not checkpoint_dir.exists() or (not len(end_checkpoints) > 0 and not len(last_checkpoints) > 0): if self.resume_ignore_no_checkpoint: - warn = f"There were no checkpoints found in checkpoint_dir or no checkpoint folder at checkpoint_dir :{checkpoint_dir}. " + warn = ( + f"There were no checkpoints found in checkpoint_dir or no checkpoint folder at checkpoint_dir " + f":{checkpoint_dir}. " + ) if checkpoint is None: warn += "Training from scratch." logging.warning(warn) else: if self.restore_config: - # resume_if_exists is True but run is not resumable. Do not fail and try to do selective restore later instead. + # resume_if_exists is True but run is not resumable. Do not fail and try to do selective restore + # later instead. return None else: raise NotFoundError( - f"There were no checkpoints found in checkpoint_dir or no checkpoint folder at checkpoint_dir :{checkpoint_dir}. Cannot resume." + f"There were no checkpoints found in checkpoint_dir or no checkpoint folder at checkpoint_dir " + f":{checkpoint_dir}. Cannot resume." ) elif len(end_checkpoints) > 0: if not self.resume_past_end: @@ -240,7 +255,8 @@ def _find_trainer_ckpt_path(self) -> Optional[Path]: # Select the checkpoint with the latest modified time checkpoint = sorted(last_checkpoints, key=lambda pth: pth.lstat().st_mtime, reverse=True)[0] logging.warning( - f"Multiple checkpoints {last_checkpoints} matches *last.ckpt. Selecting one with the latest modified time." + f"Multiple checkpoints {last_checkpoints} matches *last.ckpt. Selecting one with the latest " + f"modified time." 
) else: checkpoint = last_checkpoints[0] diff --git a/nemo/utils/sequence_packing_utils.py b/nemo/utils/sequence_packing_utils.py index cee2be248f733..2ca03ce44b671 100644 --- a/nemo/utils/sequence_packing_utils.py +++ b/nemo/utils/sequence_packing_utils.py @@ -115,7 +115,7 @@ def create_hist(dataset: np.array, truncate_seq_len: int): logging.info("Creating histogram from tokenized dataset...") sequences = collections.defaultdict(list) - counts = [0] * truncate_seq_len + counts = [0] * (truncate_seq_len + 1) for item_dict in dataset: # Minus 1 here to account for the fact that transformer input and label have one less token than the full sequence @@ -129,7 +129,7 @@ def create_hist(dataset: np.array, truncate_seq_len: int): logging.debug(counts) histogram = [] - for seq_len in range(truncate_seq_len): + for seq_len in range(truncate_seq_len + 1): histogram.append(len(sequences[seq_len])) return sequences, histogram diff --git a/requirements/requirements_asr.txt b/requirements/requirements_asr.txt index d28b3f7980a7b..783f7a483dc5a 100644 --- a/requirements/requirements_asr.txt +++ b/requirements/requirements_asr.txt @@ -8,6 +8,7 @@ kaldiio lhotse>=1.26.0 librosa>=0.10.2 marshmallow +optuna packaging pyannote.core pyannote.metrics diff --git a/scripts/deploy/nlp/query_inframework.py b/scripts/deploy/nlp/query_inframework.py index e77ab72a1f04b..a62e09fa071d8 100644 --- a/scripts/deploy/nlp/query_inframework.py +++ b/scripts/deploy/nlp/query_inframework.py @@ -15,7 +15,7 @@ import argparse import sys -from nemo.deploy.nlp.query_llm import NemoQueryLLMPyTorch +from nemo.deploy.nlp import NemoQueryLLMPyTorch def get_args(argv): diff --git a/scripts/nlp_language_modeling/prepare_packed_ft_dataset.py b/scripts/nlp_language_modeling/prepare_packed_ft_dataset.py index 7ff2342e4087e..19a3e6a78228e 100644 --- a/scripts/nlp_language_modeling/prepare_packed_ft_dataset.py +++ b/scripts/nlp_language_modeling/prepare_packed_ft_dataset.py @@ -88,6 +88,14 @@ def tokenize_dataset(cfg: 'DictConfig'): # using the same template as SFT/PEFT script. 
This may be overkill but guarantees the preprocess settings # are identical to normal SFT training data_cfg = cfg.model.data.train_ds + pad_seq_length_to_mult = 16 + cp_size = cfg.model.get("context_parallel_size", 1) + + # if context parallel is used, each individual data length in one packed dataset sample + # needs to be a multiple of (cp_size * 2): https://github.com/NVIDIA/TransformerEngine/pull/641 + if cp_size > 1: + pad_seq_length_to_mult = max(pad_seq_length_to_mult, cp_size * 2) + if os.path.isdir(cfg.tokenizer_path): # pass in a Hugging Face folder which contains tokenizer.json tokenizer = get_nmt_tokenizer(library="huggingface", model_name=cfg.tokenizer_path, use_fast=True) @@ -99,7 +107,7 @@ def tokenize_dataset(cfg: 'DictConfig'): tokenizer=tokenizer, max_seq_length=data_cfg.max_seq_length, min_seq_length=data_cfg.min_seq_length, - pad_seq_length_to_mult=16, # adds padding in collate_fn so this value is irrelevant here + pad_seq_length_to_mult=pad_seq_length_to_mult, add_bos=data_cfg.get('add_bos', False), add_eos=data_cfg.get('add_eos', True), add_sep=data_cfg.get('add_sep', False), @@ -121,7 +129,40 @@ def tokenize_dataset(cfg: 'DictConfig'): is_test=True, ) - return np.array([dataset[i] for i in range(len(dataset))]) + max_seq_length = dataset.max_seq_length + pad_id = dataset.tokenizer.eos_id + pad_seq_length_to_mult = dataset.pad_seq_length_to_mult + dataset = np.array([dataset[i] for i in range(len(dataset))]) + if cp_size > 1: + + def pre_pad_dataset(data, max_seq_length, max_length_to_pad, pad_id): + ''' + pad each individual data point to the length of max_length + ''' + assert max_seq_length >= max_length_to_pad + for key, val in data.items(): + if key in {'input_ids', 'context_ids'}: + if len(val) <= max_length_to_pad: + # because input_ids are truncated by 1 for inputs and labels, + # we add 1 extra padding here to make sure padded inputs and labels + # are is a multiple of (cp_size * 2) + val = val + [pad_id] * (max_length_to_pad - len(val) + 1) + data[key] = val + elif len(val) > max_seq_length: + logging.info( + f"""The current sequence length {len(val)} for packing is + larger than the max_seq_length specified ({max_seq_length}). + The current seqquence length is truncated to the size of max_seq_length. + Please consider increase the sequence packing size""" + ) + data[key] = val[:max_seq_length] + return + + ceil_to_nearest = lambda n, m: (n + m - 1) // m * m + for data in dataset: + max_length_to_pad = min(max_seq_length, ceil_to_nearest(len(data['input_ids']), pad_seq_length_to_mult)) + pre_pad_dataset(data, max_seq_length, max_length_to_pad, pad_id) + return dataset @dataclass diff --git a/scripts/vlm/llava_next_finetune.py b/scripts/vlm/llava_next_finetune.py new file mode 100644 index 0000000000000..334b360d7c704 --- /dev/null +++ b/scripts/vlm/llava_next_finetune.py @@ -0,0 +1,236 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" +Example: + torchrun --nproc_per_node=8 scripts/vlm/llava_next_finetune.py \ + --devices=8 --tp=4 --data_type=mock + + torchrun --nproc_per_node=8 scripts/vlm/llava_next_finetune.py \ + --devices=8 --tp=4 --data_type=energon --data_path='' \ + --language_model_path=/root/.cache/nemo/models/lmsys/vicuna-7b-v1.5 +""" + +import argparse + +import torch +from megatron.core.optimizer import OptimizerConfig +from pytorch_lightning.loggers import WandbLogger + +from nemo import lightning as nl +from nemo.collections import llm, vlm +from nemo.lightning.pytorch.optim import CosineAnnealingScheduler +from nemo.lightning.pytorch.optim.megatron import MegatronOptimizerModule +from nemo.utils.exp_manager import TimingCallback + + +def main(args): + # pylint: disable=C0115,C0116 + + # Global and micro batch sizes + gbs = args.gbs + mbs = args.mbs + max_steps = args.max_steps + + decoder_seq_length = 4096 + + if args.data_type == "energon": + from transformers import AutoProcessor + + from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer + from nemo.collections.multimodal.data.energon import SimpleMultiModalDataModule + from nemo.collections.multimodal.data.energon.config import MultiModalSampleConfig + from nemo.collections.vlm import LlavaNextTaskEncoder + + data_path = args.data_path + model_id = "llava-hf/llava-v1.6-vicuna-7b-hf" + processor = AutoProcessor.from_pretrained(model_id) + tokenizer = AutoTokenizer(model_id) + + multimodal_sample_config = MultiModalSampleConfig() + + task_encoder = LlavaNextTaskEncoder( + tokenizer=tokenizer.tokenizer, + image_processor=processor.image_processor, + multimodal_sample_config=multimodal_sample_config, + ) + data = SimpleMultiModalDataModule( + path=data_path, + tokenizer=tokenizer, + image_processor=processor.image_processor, + num_workers=32, + micro_batch_size=mbs, + global_batch_size=gbs, + multimodal_sample_config=multimodal_sample_config, + task_encoder=task_encoder, + ) + + elif args.data_type == "mock": + data = vlm.LlavaNextMockDataModule( + seq_length=decoder_seq_length, + global_batch_size=gbs, + micro_batch_size=mbs, + tokenizer=None, + image_processor=None, + num_workers=4, + ) + else: + raise ValueError(f"Data type {args.data_type} not supported") + + # Submodules configurations + language_transformer_config = llm.Llama2Config7B( + seq_length=decoder_seq_length, + ) + vision_transformer_config = vlm.HFCLIPVisionConfig( + pretrained_model_name_or_path="openai/clip-vit-large-patch14-336" + ) + vision_projection_config = vlm.MultimodalProjectorConfig( + projector_type=args.projector_type, + input_size=vision_transformer_config.hidden_size, + hidden_size=language_transformer_config.hidden_size, + ffn_hidden_size=language_transformer_config.hidden_size, + ) + + # Llava Next model configuration + llava_next_config = vlm.LlavaNextConfig( + language_transformer_config=language_transformer_config, + vision_transformer_config=vision_transformer_config, + vision_projection_config=vision_projection_config, + language_model_from_pretrained=args.language_model_path, + freeze_language_model=False, + freeze_vision_model=True, + ) + + model = vlm.LlavaNextModel(llava_next_config, tokenizer=data.tokenizer) + + # Training strategy setup + strategy = nl.MegatronStrategy( + tensor_model_parallel_size=args.tp_size, + pipeline_model_parallel_size=args.pp_size, + encoder_pipeline_model_parallel_size=args.encoder_pp_size, + pipeline_dtype=torch.bfloat16, + sequence_parallel=False, + ) + + # Checkpoint callback setup + 
checkpoint_callback = nl.ModelCheckpoint( + save_last=True, + monitor="reduced_train_loss", + save_top_k=2, + every_n_train_steps=1000, + dirpath=args.log_dir, + ) + + # Trainer setup + trainer = nl.Trainer( + num_nodes=args.num_nodes, + devices=args.devices, + max_steps=max_steps, + accelerator="gpu", + strategy=strategy, + plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"), + callbacks=[checkpoint_callback, TimingCallback()], + val_check_interval=500, + limit_val_batches=gbs, + log_every_n_steps=1, + num_sanity_val_steps=0, + ) + + # Logger setup + nemo_logger = nl.NeMoLogger( + log_dir=args.log_dir, + name=args.name, + wandb=WandbLogger(project=args.wandb_project, name=args.name) if args.wandb_project is not None else None, + ) + + # Auto resume setup + resume = nl.AutoResume( + resume_if_exists=True, + resume_ignore_no_checkpoint=True, + resume_from_directory=args.log_dir, + restore_config=nl.RestoreConfig(path=args.restore_path) if args.restore_path is not None else None, + ) + + # Optimizer and scheduler setup + opt_config = OptimizerConfig( + optimizer='adam', + lr=args.lr, + adam_beta1=0.9, + adam_beta2=0.95, + use_distributed_optimizer=True, + bf16=True, + ) + sched = CosineAnnealingScheduler( + max_steps=trainer.max_steps, + warmup_steps=150, + constant_steps=0, + min_lr=2.0e-07, + ) + opt = MegatronOptimizerModule(opt_config, sched) + + # PEFT setup + if args.peft == 'lora': + peft = vlm.peft.LoRA( + target_modules=[ + "linear_qkv", + "linear_proj", + "linear_fc1", + "linear_fc2", + ] + ) + else: + peft = None + + llm.finetune( + model=model, + data=data, + trainer=trainer, + peft=peft, + log=nemo_logger, + optim=opt, + resume=resume, + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Llava Next Finetuning Script") + + # Argument parsing + parser.add_argument("--data_type", type=str, required=False, default="mock", help="mock | energon") + parser.add_argument("--data_path", type=str, required=False, default=None, help="Path to the dataset JSON file") + parser.add_argument( + "--log_dir", type=str, required=False, default="/results", help="Directory for logging and checkpoints" + ) + parser.add_argument( + "--language_model_path", type=str, required=False, default=None, help="Path to the pretrained language model" + ) + parser.add_argument( + "--restore_path", type=str, required=False, default=None, help="Path to restore model from checkpoint" + ) + parser.add_argument("--devices", type=int, required=False, default=1) + parser.add_argument("--num_nodes", type=int, required=False, default=1) + parser.add_argument("--max_steps", type=int, required=False, default=5190) + parser.add_argument("--tp_size", type=int, required=False, default=4) + parser.add_argument("--pp_size", type=int, required=False, default=1) + parser.add_argument("--encoder_pp_size", type=int, required=False, default=0) + parser.add_argument("--projector_type", type=str, required=False, default="mlp2x_gelu") + parser.add_argument("--name", type=str, required=False, default="llava_next_finetune") + parser.add_argument("--peft", type=str, default='none', help="none | lora") + parser.add_argument("--wandb_project", type=str, required=False, default=None) + parser.add_argument("--gbs", type=int, required=False, default=64, help="Global batch size") + parser.add_argument("--mbs", type=int, required=False, default=4, help="Micro batch size") + parser.add_argument("--lr", type=float, required=False, default=2.0e-05, help="Learning rate") + + args = parser.parse_args() + main(args) 
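[Editor's sketch, not part of the patch] The VLM scripts in this diff construct nl.MegatronStrategy directly, and the strategy/checkpoint-IO changes earlier in the patch add a ckpt_load_strictness option plus a strict argument on MegatronCheckpointIO.load_checkpoint. A minimal sketch of opting into relaxed distributed-checkpoint loading follows; the parallelism sizes are placeholder values, while ckpt_load_strictness, StrictHandling.LOG_ALL, and the import path are taken from the hunks above.

# Sketch: relaxed dist-checkpoint loading via the new ckpt_load_strictness flag.
# StrictHandling.LOG_ALL logs missing/unexpected keys instead of raising.
from megatron.core.dist_checkpointing.validation import StrictHandling

from nemo import lightning as nl

strategy = nl.MegatronStrategy(
    tensor_model_parallel_size=2,      # placeholder value
    pipeline_model_parallel_size=1,    # placeholder value
    ckpt_load_optimizer=False,
    # Overrides the `strict` flag that would otherwise be passed to
    # load_checkpoint / load_model_state_dict (see megatron_strategy.py above).
    ckpt_load_strictness=StrictHandling.LOG_ALL,
)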
diff --git a/scripts/vlm/llava_next_generation.py b/scripts/vlm/llava_next_generation.py new file mode 100644 index 0000000000000..1b3d8a46b9550 --- /dev/null +++ b/scripts/vlm/llava_next_generation.py @@ -0,0 +1,166 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse + +import requests +import torch +from PIL import Image +from transformers import AutoProcessor + +from nemo import lightning as nl +from nemo.collections import vlm +from nemo.utils import logging + + +def load_image(image_url: str) -> Image.Image: + # pylint: disable=C0115,C0116 + try: + response = requests.get(image_url, stream=True) + response.raise_for_status() + image = Image.open(response.raw) + return image + except requests.exceptions.RequestException as e: + print(f"Error loading image from {image_url}: {e}") + return None + + +def generate(model, processor, raw_image, text): + # pylint: disable=C0115,C0116 + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "What are these?"}, + {"type": "image"}, + ], + } + ] + + input_text = processor.apply_chat_template(messages, add_generation_prompt=True) + + input_ids = processor.tokenizer(input_text, return_tensors='pt').input_ids.cuda() + inputs = processor(input_text, raw_image, return_tensors='pt').to(0, torch.float32) + + input_ids[input_ids == 32000] = -200 + media = inputs['pixel_values'].cuda() + media = media.reshape(media.size(1), 3, 336, 336) + position_ids = ( + torch.arange(input_ids.size(1), dtype=torch.long, device=input_ids.device).unsqueeze(0).expand_as(input_ids) + ) + + generated_ids = input_ids.clone() + width, height = raw_image.size + image_sizes = torch.tensor([[height, width]], dtype=torch.long).cuda() + + for _ in range(20): + with torch.no_grad(): + attention_mask = (input_ids != 0).long().cuda() + output = model( + media=media, + input_ids=input_ids, + position_ids=position_ids, + image_sizes=image_sizes, + num_media_tiles=[media.size(0)], + attention_mask=attention_mask, + ) + next_token_ids = torch.argmax(output[:, -1], dim=-1, keepdim=True) + + generated_ids = torch.cat([generated_ids, next_token_ids], dim=-1) + + input_ids = generated_ids + position_ids = ( + torch.arange(input_ids.size(1), dtype=torch.long, device=input_ids.device) + .unsqueeze(0) + .expand_as(input_ids) + ) + print(f"next_token_ids {next_token_ids}") + + # If the generated token is the end of sequence token, stop generating + if next_token_ids.item() == processor.tokenizer.eos_token_id: + print(f"breaking") + break + generated_ids[generated_ids == -200] = 0 + generated_texts = processor.tokenizer.batch_decode(generated_ids, skip_special_tokens=False) + logging.info("======== GENERATED TEXT OUTPUT ========") + logging.info(f"{generated_texts}") + logging.info("=======================================") + + +def main(args) -> None: + # pylint: disable=C0115,C0116 + model_id = 'llava-hf/llava-v1.6-vicuna-7b-hf' + strategy = nl.MegatronStrategy( + 
tensor_model_parallel_size=args.tp_size, + ckpt_load_optimizer=False, + ckpt_save_optimizer=False, + ) + trainer = nl.Trainer( + devices=args.tp_size, + max_steps=1000, + accelerator="gpu", + strategy=strategy, + plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"), + val_check_interval=1000, + limit_val_batches=50, + ) + + processor = AutoProcessor.from_pretrained(model_id) + tokenizer = processor.tokenizer + + fabric = trainer.to_fabric() + + if args.load_from_hf: + model = fabric.import_model("hf://llava-hf/llava-v1.6-vicuna-7b-hf", vlm.LlavaNextModel) + else: + model = vlm.LlavaNextModel(vlm.LlavaNextConfig7B(), tokenizer=tokenizer) + model = fabric.load_model(args.local_model_path, model) + + model = model.module.cuda() + model.eval() + model = model.to(torch.bfloat16) + + # Load the image + raw_image = load_image(args.image_url) + if raw_image is None: + return # Exit if the image can't be loaded + + generate(model, processor, raw_image=raw_image, text="What are these?") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Llava Next Generation example") + parser.add_argument( + "--load_from_hf", + action="store_true", + help="Flag to indicate whether to load the model from Hugging Face hub.", + ) + parser.add_argument( + "--local_model_path", + type=str, + default=None, + help="Local path to the model if not loading from Hugging Face.", + ) + parser.add_argument( + "--image_url", + type=str, + # pylint: disable=line-too-long + default="http://images.cocodataset.org/val2017/000000039769.jpg", + help="URL of the image to use for inference.", + ) + parser.add_argument("--devices", type=int, required=False, default=1) + parser.add_argument("--tp_size", type=int, required=False, default=1) + + args = parser.parse_args() + main(args) diff --git a/scripts/vlm/llava_next_nemo_run.py b/scripts/vlm/llava_next_nemo_run.py new file mode 100644 index 0000000000000..3193b05e10fc7 --- /dev/null +++ b/scripts/vlm/llava_next_nemo_run.py @@ -0,0 +1,71 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import nemo_run as run + +from nemo.collections import vlm + + +def configure_recipe(nodes: int = 1, gpus_per_node: int = 8, pretrain=False): + """Configure the recipe""" + if pretrain: + recipe = vlm.llava_next_7b.pretrain_recipe( + dir="./outputs/checkpoints/llava", # Path to store checkpoints + name="llava_pretrain", + num_nodes=nodes, + num_gpus_per_node=gpus_per_node, + ) + else: + recipe = vlm.llava_next_7b.finetune_recipe( + dir="./outputs/checkpoints/llava", # Path to store checkpoints + name="llava_finetune", + num_nodes=nodes, + num_gpus_per_node=gpus_per_node, + ) + recipe.trainer.max_steps = 100 + recipe.trainer.val_check_interval = 100 + recipe.model.config.freeze_vision_model = True + return recipe + + +def local_executor_torchrun(nodes: int = 1, devices: int = 8) -> run.LocalExecutor: + # pylint: disable=C0115,C0116 + # Env vars for jobs are configured here + env_vars = {} + + executor = run.LocalExecutor(ntasks_per_node=devices, launcher="torchrun", env_vars=env_vars) + + return executor + + +def run_pretraining(): + # pylint: disable=C0115,C0116 + recipe = configure_recipe(pretrain=True) + executor = local_executor_torchrun(nodes=recipe.trainer.num_nodes, devices=recipe.trainer.devices) + + run.run(recipe, executor=executor) + + +def run_finetuning(): + # pylint: disable=C0115,C0116 + recipe = configure_recipe(pretrain=False) + executor = local_executor_torchrun(nodes=recipe.trainer.num_nodes, devices=recipe.trainer.devices) + + run.run(recipe, executor=executor) + + +# This condition is necessary for the script to be compatible with Python's multiprocessing module. +if __name__ == "__main__": + run_pretraining() + # run_finetuning() diff --git a/scripts/vlm/llava_next_pretrain.py b/scripts/vlm/llava_next_pretrain.py new file mode 100644 index 0000000000000..bb84e3dae1e5b --- /dev/null +++ b/scripts/vlm/llava_next_pretrain.py @@ -0,0 +1,223 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" +Example: + torchrun --nproc_per_node=8 scripts/vlm/llava_next_pretrain.py \ + --devices=8 --tp=4 --data_type=mock + + torchrun --nproc_per_node=8 scripts/vlm/llava_next_pretrain.py \ + --devices=8 --tp=4 --data_type=energon --data_path='' \ + --language_model_path=/root/.cache/nemo/models/lmsys/vicuna-7b-v1.5 +""" + +import argparse + +import torch +from megatron.core.optimizer import OptimizerConfig +from pytorch_lightning.loggers import WandbLogger + +from nemo import lightning as nl +from nemo.collections import llm, vlm +from nemo.lightning.pytorch.optim import CosineAnnealingScheduler +from nemo.lightning.pytorch.optim.megatron import MegatronOptimizerModule +from nemo.utils.exp_manager import TimingCallback + + +def main(args): + # pylint: disable=C0115,C0116 + + # Global and micro batch sizes + gbs = args.gbs + mbs = args.mbs + max_steps = args.max_steps + + decoder_seq_length = 4096 + + if args.data_type == "energon": + from transformers import AutoProcessor + + from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer + from nemo.collections.multimodal.data.energon import SimpleMultiModalDataModule + from nemo.collections.multimodal.data.energon.config import MultiModalSampleConfig + from nemo.collections.vlm import LlavaNextTaskEncoder + + data_path = args.data_path + model_id = "llava-hf/llava-v1.6-vicuna-7b-hf" + processor = AutoProcessor.from_pretrained(model_id) + tokenizer = AutoTokenizer(model_id) + + multimodal_sample_config = MultiModalSampleConfig() + # Setting system prompt to empty string + multimodal_sample_config.conversation_template_config.system = '' + + task_encoder = LlavaNextTaskEncoder( + tokenizer=tokenizer.tokenizer, + image_processor=processor.image_processor, + multimodal_sample_config=multimodal_sample_config, + ) + data = SimpleMultiModalDataModule( + path=data_path, + tokenizer=tokenizer, + image_processor=processor.image_processor, + num_workers=32, + micro_batch_size=mbs, + global_batch_size=gbs, + multimodal_sample_config=multimodal_sample_config, + task_encoder=task_encoder, + ) + + elif args.data_type == "mock": + data = vlm.LlavaNextMockDataModule( + seq_length=decoder_seq_length, + global_batch_size=gbs, + micro_batch_size=mbs, + tokenizer=None, + image_processor=None, + num_workers=4, + ) + else: + raise ValueError(f"Data type {args.data_type} not supported") + + # Submodules configurations + language_transformer_config = llm.Llama2Config7B( + seq_length=decoder_seq_length, + ) + vision_transformer_config = vlm.HFCLIPVisionConfig( + pretrained_model_name_or_path="openai/clip-vit-large-patch14-336" + ) + vision_projection_config = vlm.MultimodalProjectorConfig( + projector_type=args.projector_type, + input_size=vision_transformer_config.hidden_size, + hidden_size=language_transformer_config.hidden_size, + ffn_hidden_size=language_transformer_config.hidden_size, + ) + + # Llava Next model configuration + llava_next_config = vlm.LlavaNextConfig( + language_transformer_config=language_transformer_config, + vision_transformer_config=vision_transformer_config, + vision_projection_config=vision_projection_config, + language_model_from_pretrained=args.language_model_path, + freeze_language_model=True, + freeze_vision_model=True, + ) + + model = vlm.LlavaNextModel(llava_next_config, tokenizer=data.tokenizer) + + # Training strategy setup + strategy = nl.MegatronStrategy( + tensor_model_parallel_size=args.tp_size, + pipeline_model_parallel_size=args.pp_size, + encoder_pipeline_model_parallel_size=args.encoder_pp_size, 
+ pipeline_dtype=torch.bfloat16, + sequence_parallel=False, + ) + + # Checkpoint callback setup + checkpoint_callback = nl.ModelCheckpoint( + save_last=True, + monitor="reduced_train_loss", + save_top_k=2, + every_n_train_steps=1000, + dirpath=args.log_dir, + ) + + # Trainer setup + trainer = nl.Trainer( + num_nodes=args.num_nodes, + devices=args.devices, + max_steps=max_steps, + accelerator="gpu", + strategy=strategy, + plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"), + callbacks=[checkpoint_callback, TimingCallback()], + val_check_interval=500, + limit_val_batches=gbs, + log_every_n_steps=1, + num_sanity_val_steps=0, + ) + + # Logger setup + nemo_logger = nl.NeMoLogger( + log_dir=args.log_dir, + name=args.name, + wandb=WandbLogger(project=args.wandb_project, name=args.name) if args.wandb_project is not None else None, + ) + + # Auto resume setup + resume = nl.AutoResume( + resume_if_exists=True, + resume_ignore_no_checkpoint=True, + resume_from_directory=args.log_dir, + restore_config=nl.RestoreConfig(path=args.restore_path) if args.restore_path is not None else None, + ) + + # Optimizer and scheduler setup + opt_config = OptimizerConfig( + optimizer='adam', + lr=args.lr, + adam_beta1=0.9, + adam_beta2=0.95, + use_distributed_optimizer=True, + bf16=True, + ) + sched = CosineAnnealingScheduler( + max_steps=trainer.max_steps, + warmup_steps=150, + constant_steps=0, + min_lr=2.0e-05, + ) + opt = MegatronOptimizerModule(opt_config, sched) + + llm.pretrain( + model=model, + data=data, + trainer=trainer, + log=nemo_logger, + optim=opt, + resume=resume, + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Llava Next Pretraining Script") + + # Argument parsing + parser.add_argument("--data_type", type=str, required=False, default="mock", help="mock | energon") + parser.add_argument("--data_path", type=str, required=False, default=None, help="Path to the dataset JSON file") + parser.add_argument( + "--log_dir", type=str, required=False, default="/results", help="Directory for logging and checkpoints" + ) + parser.add_argument( + "--language_model_path", type=str, required=False, default=None, help="Path to the pretrained language model" + ) + parser.add_argument( + "--restore_path", type=str, required=False, default=None, help="Path to restore model from checkpoint" + ) + parser.add_argument("--devices", type=int, required=False, default=1) + parser.add_argument("--num_nodes", type=int, required=False, default=1) + parser.add_argument("--max_steps", type=int, required=False, default=2100) + parser.add_argument("--tp_size", type=int, required=False, default=2) + parser.add_argument("--pp_size", type=int, required=False, default=1) + parser.add_argument("--encoder_pp_size", type=int, required=False, default=0) + parser.add_argument("--projector_type", type=str, required=False, default="mlp2x_gelu") + parser.add_argument("--name", type=str, required=False, default="llava_next_pretrain") + parser.add_argument("--wandb_project", type=str, required=False, default=None) + parser.add_argument("--gbs", type=int, required=False, default=32, help="Global batch size") + parser.add_argument("--mbs", type=int, required=False, default=4, help="Micro batch size") + parser.add_argument("--lr", type=float, required=False, default=0.001, help="Learning rate") + + args = parser.parse_args() + main(args) diff --git a/tests/collections/asr/test_diar_utils.py b/tests/collections/asr/test_diar_utils.py index f48292d27981f..cb364675fcf4f 100644 --- 
a/tests/collections/asr/test_diar_utils.py +++ b/tests/collections/asr/test_diar_utils.py @@ -48,7 +48,7 @@ get_online_subsegments_from_buffer, get_speech_labels_for_update, get_sub_range_list, - get_subsegments, + get_subsegments_scriptable, get_target_sig, int2fl, is_overlap, @@ -82,8 +82,7 @@ def matrix(mat, use_tensor=True, dtype=torch.long): def generate_orthogonal_embs(total_spks, perturb_sigma, emb_dim): - """Generate a set of artificial orthogonal embedding vectors from random numbers - """ + """Generate a set of artificial orthogonal embedding vectors from random numbers""" gaus = torch.randn(emb_dim, emb_dim) _svd = torch.linalg.svd(gaus) orth = _svd[0] @ _svd[2] @@ -110,7 +109,7 @@ def generate_toy_data( random_orthogonal_embs = generate_orthogonal_embs(n_spks, perturb_sigma, emb_dim) for scale_idx, (window, shift) in enumerate(zip(ms_window, ms_shift)): for spk_idx, (offset, dur) in enumerate(spk_timestamps): - segments_stt_dur = get_subsegments(offset=offset, window=window, shift=shift, duration=dur) + segments_stt_dur = get_subsegments_scriptable(offset=offset, window=window, shift=shift, duration=dur) segments = [[x[0], x[0] + x[1]] for x in segments_stt_dur] emb_cent = random_orthogonal_embs[spk_idx, :] emb = emb_cent.tile((len(segments), 1)) + 0.1 * torch.rand(len(segments), emb_dim) @@ -130,8 +129,7 @@ def generate_toy_data( class TestDiarizationSequneceUtilFunctions: - """Tests diarization and speaker-task related utils. - """ + """Tests diarization and speaker-task related utils.""" @pytest.mark.unit @pytest.mark.parametrize("Y", [[3, 3, 3, 4, 4, 5], [100, 100, 100, 104, 104, 1005]]) @@ -278,7 +276,10 @@ def test_embedding_reducer(self, n_spks, target_speaker_index, merge_quantity): em, ts, mc, mw, spk_ts, gt = generate_toy_data(n_spks=n_spks, spk_dur=10) em_s, ts_s = split_input_data(em, ts, mc) merged_embs, merged_clus_labels, _ = run_reducer( - pre_embs=em_s[-1], target_spk_idx=target_speaker_index, merge_quantity=merge_quantity, pre_clus_labels=gt, + pre_embs=em_s[-1], + target_spk_idx=target_speaker_index, + merge_quantity=merge_quantity, + pre_clus_labels=gt, ) assert (torch.sum(gt == target_speaker_index).item() - merge_quantity) == merged_clus_labels.shape[0] @@ -287,7 +288,11 @@ def test_embedding_reducer(self, n_spks, target_speaker_index, merge_quantity): @pytest.mark.parametrize("pcl", [torch.tensor([0] * 70 + [1] * 32)]) @pytest.mark.parametrize("mspb", [25]) def test_merge_scheduler_2clus(self, ntbr, pcl, mspb): - class_target_vol = get_merge_quantity(num_to_be_removed=ntbr, pre_clus_labels=pcl, min_count_per_cluster=mspb,) + class_target_vol = get_merge_quantity( + num_to_be_removed=ntbr, + pre_clus_labels=pcl, + min_count_per_cluster=mspb, + ) assert all(class_target_vol == torch.tensor([3, 0])) @pytest.mark.unit @@ -295,7 +300,11 @@ def test_merge_scheduler_2clus(self, ntbr, pcl, mspb): @pytest.mark.parametrize("pcl", [torch.tensor([0] * 80 + [1] * 35 + [2] * 32)]) @pytest.mark.parametrize("mspb", [0, 25]) def test_merge_scheduler_3clus(self, ntbr, pcl, mspb): - class_target_vol = get_merge_quantity(num_to_be_removed=ntbr, pre_clus_labels=pcl, min_count_per_cluster=mspb,) + class_target_vol = get_merge_quantity( + num_to_be_removed=ntbr, + pre_clus_labels=pcl, + min_count_per_cluster=mspb, + ) assert all(class_target_vol == torch.tensor([3, 0, 0])) @pytest.mark.unit @@ -303,7 +312,11 @@ def test_merge_scheduler_3clus(self, ntbr, pcl, mspb): @pytest.mark.parametrize("pcl", [torch.tensor([2] * 70 + [0] * 32 + [1] * 27 + [3] * 3)]) 
@pytest.mark.parametrize("mspb", [3, 10]) def test_merge_scheduler_4clus_shuff(self, ntbr, pcl, mspb): - class_target_vol = get_merge_quantity(num_to_be_removed=ntbr, pre_clus_labels=pcl, min_count_per_cluster=mspb,) + class_target_vol = get_merge_quantity( + num_to_be_removed=ntbr, + pre_clus_labels=pcl, + min_count_per_cluster=mspb, + ) assert all(class_target_vol == torch.tensor([18, 13, 56, 0])) @pytest.mark.unit @@ -311,7 +324,11 @@ def test_merge_scheduler_4clus_shuff(self, ntbr, pcl, mspb): @pytest.mark.parametrize("pcl", [torch.tensor([0] * 5 + [1] * 4 + [2] * 3)]) @pytest.mark.parametrize("mspb", [0, 2]) def test_merge_scheduler_3clus(self, ntbr, pcl, mspb): - class_target_vol = get_merge_quantity(num_to_be_removed=ntbr, pre_clus_labels=pcl, min_count_per_cluster=mspb,) + class_target_vol = get_merge_quantity( + num_to_be_removed=ntbr, + pre_clus_labels=pcl, + min_count_per_cluster=mspb, + ) assert all(class_target_vol == torch.tensor([2, 1, 0])) @pytest.mark.unit @@ -319,7 +336,11 @@ def test_merge_scheduler_3clus(self, ntbr, pcl, mspb): @pytest.mark.parametrize("pcl", [torch.tensor([0] * 7 + [1] * 5 + [2] * 3 + [3] * 5)]) @pytest.mark.parametrize("mspb", [2]) def test_merge_scheduler_3clus_repeat(self, ntbr, pcl, mspb): - class_target_vol = get_merge_quantity(num_to_be_removed=ntbr, pre_clus_labels=pcl, min_count_per_cluster=mspb,) + class_target_vol = get_merge_quantity( + num_to_be_removed=ntbr, + pre_clus_labels=pcl, + min_count_per_cluster=mspb, + ) assert all(class_target_vol == torch.tensor([2, 0, 0, 0])) @@ -414,13 +435,21 @@ def test_is_overlap_false(self, rangeA, rangeB): @pytest.mark.parametrize("x", [1.0, 2.3456]) @pytest.mark.parametrize("decimals", [1, 2, 3, 4]) def test_fl2int(self, x, decimals): - assert fl2int(x, decimals) == round(x * 10 ** decimals, 0) + assert fl2int(x, decimals) == round(x * 10**decimals, 0) @pytest.mark.unit @pytest.mark.parametrize("x", [1234]) - @pytest.mark.parametrize("decimals", [1, 2, 3, 4,]) + @pytest.mark.parametrize( + "decimals", + [ + 1, + 2, + 3, + 4, + ], + ) def test_int2fl(self, x, decimals): - assert abs(int2fl(x, decimals) - round(x / (10 ** decimals), decimals)) < (10 ** -(decimals + 1)) + assert abs(int2fl(x, decimals) - round(x / (10**decimals), decimals)) < (10 ** -(decimals + 1)) @pytest.mark.unit def test_merge_float_intervals_edge_margin_test(self): @@ -462,7 +491,11 @@ def test_get_speech_labels_for_update(self): vad_timestamps = torch.tensor([[0.9600, 4.8400]]) cursor_for_old_segments = 1.0 speech_labels_for_update, cumulative_speech_labels = get_speech_labels_for_update( - frame_start, buffer_end, cumulative_speech_labels, vad_timestamps, cursor_for_old_segments, + frame_start, + buffer_end, + cumulative_speech_labels, + vad_timestamps, + cursor_for_old_segments, ) assert (speech_labels_for_update - torch.tensor([[1.0000, 3.7600]])).sum() < 1e-8 assert (cumulative_speech_labels - torch.tensor([[0.9600, 4.8400]])).sum() < 1e-8 @@ -532,7 +565,10 @@ def test_tensor_to_list(self, source_range_list): @pytest.mark.unit @pytest.mark.parametrize( "buffer_start, buffer_end, subsegments, ind_offset, window, sample_rate", - [(0.0, 2.0, [[0.5, 1.0], [1.5, 2.0]], 0, 0.1, 16000), (0.0, 5.0, [[0.5, 2.5], [2.7, 5.0]], 0, 1.0, 16000),], + [ + (0.0, 2.0, [[0.5, 1.0], [1.5, 2.0]], 0, 0.1, 16000), + (0.0, 5.0, [[0.5, 2.5], [2.7, 5.0]], 0, 1.0, 16000), + ], ) def test_get_online_segments_from_slices( self, buffer_start, buffer_end, subsegments, ind_offset, window, sample_rate @@ -665,7 +701,13 @@ def 
test_offline_speaker_clustering_cpu(self, n_spks, total_sec, SSV, perturb_si @pytest.mark.parametrize("SSV, enhanced_count_thres, min_samples_for_nmesc", [(5, 40, 6)]) @pytest.mark.parametrize("seed", [0]) def test_offline_speaker_clustering_very_short_cpu( - self, n_spks, spk_dur, SSV, enhanced_count_thres, min_samples_for_nmesc, seed, + self, + n_spks, + spk_dur, + SSV, + enhanced_count_thres, + min_samples_for_nmesc, + seed, ): em, ts, mc, mw, spk_ts, gt = generate_toy_data( n_spks=n_spks, spk_dur=spk_dur, perturb_sigma=0.1, torch_seed=seed @@ -697,7 +739,13 @@ def test_offline_speaker_clustering_very_short_cpu( @pytest.mark.parametrize("n_spks, SSV, enhanced_count_thres, min_samples_for_nmesc", [(1, 5, 40, 6)]) @pytest.mark.parametrize("seed", [0]) def test_offline_speaker_clustering_very_short_gpu( - self, n_spks, spk_dur, SSV, enhanced_count_thres, min_samples_for_nmesc, seed, + self, + n_spks, + spk_dur, + SSV, + enhanced_count_thres, + min_samples_for_nmesc, + seed, ): em, ts, mc, mw, spk_ts, gt = generate_toy_data( n_spks=n_spks, spk_dur=spk_dur, perturb_sigma=0.1, torch_seed=seed @@ -908,7 +956,7 @@ def test_linear_sum_assignment_algorithm_cost_matrix(self, cost_matrix): Test the linear sum assignment algorithm with a cost matrix Compare with the scipy implementation and make sure the final cost is the same. - NOTE: There could be multiple solutions with the same cost in linear sum assignment problem. + NOTE: There could be multiple solutions with the same cost in linear sum assignment problem. This test only checks if the cost is the same. """ row_ind_nm, col_ind_nm = nemo_linear_sum_assignment(cost_matrix) diff --git a/tests/collections/llm/test_mnist_model_nemo2_fsdp.py b/tests/collections/llm/test_mnist_model_nemo2_fsdp.py index da45d0e1fc385..9418ee7e5e902 100644 --- a/tests/collections/llm/test_mnist_model_nemo2_fsdp.py +++ b/tests/collections/llm/test_mnist_model_nemo2_fsdp.py @@ -525,6 +525,7 @@ def run_train_mnist_litautoencoder_with_fsdp_strategy_single_gpu(): every_n_train_steps=5, # Enables the .nemo file-like checkpointing where all IOMixins are under SerDe always_save_context=True, + filename="{model_name}--{val_loss:.2f}-{step}-{consumed_samples}", ) root_dir = tmpdir save_dir = root_dir / name @@ -572,6 +573,7 @@ def run_train_mnist_litautoencoder_with_fsdp_strategy_single_gpu(): global_batch_size=2, output_log=False, # Disable logs to support predict_step ), + ckpt_load_optimizer=False, ) predict_trainer = nl.Trainer( accelerator="gpu", diff --git a/tests/deploy/nemo_deploy.py b/tests/deploy/nemo_deploy.py index 23db7c4f01f3d..45f2bae3425ec 100644 --- a/tests/deploy/nemo_deploy.py +++ b/tests/deploy/nemo_deploy.py @@ -21,7 +21,7 @@ import torch -from nemo.deploy.nlp.megatronllm_deployable import MegatronLLMDeployable +from nemo.deploy.nlp import MegatronLLMDeployable from tests.infer_data_path import get_infer_test_data run_export_tests = True diff --git a/tests/export/nemo_export.py b/tests/export/nemo_export.py index 16aca9ccea4bd..cb2b3619e4d36 100644 --- a/tests/export/nemo_export.py +++ b/tests/export/nemo_export.py @@ -43,7 +43,8 @@ from nemo.deploy.nlp import MegatronLLMDeployable, NemoQueryLLMPyTorch except Exception as e: LOGGER.warning( - f"Cannot import MegatronLLMDeployable, in-framework inference will not be available. {type(e).__name__}: {e}" + "Cannot import MegatronLLMDeployable or NemoQueryLLMPyTorch," + f" in-framework inference will not be available. 
{type(e).__name__}: {e}" ) in_framework_supported = False @@ -103,7 +104,8 @@ def get_accuracy_with_lambada(model, nq, task_ids, lora_uids, test_data_path): expected_output = record["last_word"].strip().lower() all_expected_outputs.append(expected_output) if model is not None: - if isinstance(model, MegatronLLMDeployable): + + if in_framework_supported and isinstance(model, MegatronLLMDeployable): model_output = model.generate( inputs=[prompt], length_params={"min_length": 1, "max_length": 1}, @@ -147,7 +149,7 @@ def get_accuracy_with_lambada(model, nq, task_ids, lora_uids, test_data_path): correct_answers_relaxed += 1 if nq is not None: - if isinstance(nq, NemoQueryLLMPyTorch): + if in_framework_supported and isinstance(nq, NemoQueryLLMPyTorch): deployed_output = nq.query_llm( prompts=[prompt], max_length=1, diff --git a/tests/export/test_trt_compile.py b/tests/export/test_trt_compile.py new file mode 100644 index 0000000000000..09a77004678fb --- /dev/null +++ b/tests/export/test_trt_compile.py @@ -0,0 +1,139 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import tempfile +import unittest +from typing import List + +import torch + +TEST_CASE_1 = ["fp32"] +TEST_CASE_2 = ["fp16"] + + +class ListAdd(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x: List[torch.Tensor], y: torch.Tensor, z: torch.Tensor, bs: float = 0.1): + y1 = y.clone() + x1 = x.copy() + z1 = z + y + for xi in x: + y1 = y1 + xi + bs + return x1, [y1, z1], y1 + z1 + + +@unittest.skip +class TestTRTCompile(unittest.TestCase): + + def setUp(self): + self.gpu_device = torch.cuda.current_device() + + def tearDown(self): + current_device = torch.cuda.current_device() + if current_device != self.gpu_device: + torch.cuda.set_device(self.gpu_device) + + def test_torch_trt(self): + + model = torch.nn.Sequential(*[torch.nn.PReLU(), torch.nn.PReLU()]) + data1 = model.state_dict() + data1["0.weight"] = torch.tensor([0.1]) + data1["1.weight"] = torch.tensor([0.2]) + model.load_state_dict(data1) + model.cuda() + x = torch.randn(1, 16).to("cuda") + + with tempfile.TemporaryDirectory() as tempdir: + args = { + "method": "torch_trt", + "dynamic_batchsize": [1, 4, 8], + } + input_example = (x,) + output_example = model(*input_example) + trt_compile( + model, + f"{tempdir}/test_lists", + args=args, + ) + self.assertIsNone(model._trt_compiler.engine) + trt_output = model(*input_example) + # Check that lazy TRT build succeeded + self.assertIsNotNone(model._trt_compiler.engine) + torch.testing.assert_close(trt_output, output_example, rtol=0.01, atol=0.01) + + def test_profiles(self): + model = ListAdd().cuda() + + with torch.no_grad(), tempfile.TemporaryDirectory() as tmpdir: + args = { + "export_args": { + "dynamo": False, + }, + "input_profiles": [ + { + "x_0": [[1, 8], [2, 16], [2, 32]], + "x_1": [[1, 8], [2, 16], [2, 32]], + "x_2": [[1, 8], [2, 16], [2, 32]], + "y": [[1, 8], [2, 16], [2, 32]], + "z": [[1, 8], [1, 16], [1, 32]], + } + ], + 
"output_lists": [[-1], [2], []], + } + x = torch.randn(1, 16).to("cuda") + y = torch.randn(1, 16).to("cuda") + z = torch.randn(1, 16).to("cuda") + input_example = ([x, y, z], y.clone(), z.clone()) + output_example = model(*input_example) + trt_compile( + model, + f"{tmpdir}/test_dynamo_trt", + args=args, + ) + self.assertIsNone(model._trt_compiler.engine) + trt_output = model(*input_example) + # Check that lazy TRT build succeeded + self.assertIsNotNone(model._trt_compiler.engine) + torch.testing.assert_close(trt_output, output_example, rtol=0.01, atol=0.01) + + def test_lists(self): + model = ListAdd().cuda() + + with torch.no_grad(), tempfile.TemporaryDirectory() as tmpdir: + args = { + "export_args": { + "dynamo": True, + }, + "output_lists": [[-1], [2], []], + } + x = torch.randn(1, 16).to("cuda") + y = torch.randn(1, 16).to("cuda") + z = torch.randn(1, 16).to("cuda") + input_example = ([x, y, z], y.clone(), z.clone()) + output_example = model(*input_example) + trt_compile( + model, + f"{tmpdir}/test_lists", + args=args, + ) + self.assertIsNone(model._trt_compiler.engine) + trt_output = model(*input_example) + # Check that lazy TRT build succeeded + self.assertIsNotNone(model._trt_compiler.engine) + torch.testing.assert_close(trt_output, output_example, rtol=0.01, atol=0.01) + + +if __name__ == "__main__": + unittest.main()