From b46d958eb24d254fe24b9811e58b3e0ed5e66be1 Mon Sep 17 00:00:00 2001 From: "He Huang (Steve)" <105218074+stevehuang52@users.noreply.github.com> Date: Fri, 4 Oct 2024 17:25:36 -0400 Subject: [PATCH] Add NEST SSL to main (#10319) * add files Signed-off-by: stevehuang52 * fix typecheck Signed-off-by: stevehuang52 * update typecheck Signed-off-by: stevehuang52 * update Signed-off-by: stevehuang52 * add README Signed-off-by: stevehuang52 * update Signed-off-by: stevehuang52 * update Signed-off-by: stevehuang52 * update Signed-off-by: stevehuang52 * update Signed-off-by: stevehuang52 * update masking Signed-off-by: stevehuang52 * update Signed-off-by: stevehuang52 * update Signed-off-by: stevehuang52 * update Signed-off-by: stevehuang52 * update Signed-off-by: stevehuang52 * update Signed-off-by: stevehuang52 * update Signed-off-by: stevehuang52 * fix quantizer Signed-off-by: stevehuang52 * fix quantizer Signed-off-by: stevehuang52 * update Signed-off-by: stevehuang52 * update Signed-off-by: stevehuang52 * update Signed-off-by: stevehuang52 * update Signed-off-by: stevehuang52 * update Signed-off-by: stevehuang52 * update Signed-off-by: stevehuang52 * update Signed-off-by: stevehuang52 * update Signed-off-by: stevehuang52 * update config Signed-off-by: stevehuang52 * refactor Signed-off-by: stevehuang52 * update for ptl2 Signed-off-by: stevehuang52 * refactor and update Signed-off-by: stevehuang52 * add WavLM related scripts Signed-off-by: stevehuang52 * update for noise manifest Signed-off-by: stevehuang52 * update to handle empty noise Signed-off-by: stevehuang52 * update Signed-off-by: stevehuang52 * update Signed-off-by: stevehuang52 * update multi-layer feat Signed-off-by: stevehuang52 * add debug multi_validation_epoch_end Signed-off-by: stevehuang52 * add spk id finetune Signed-off-by: stevehuang52 * update config Signed-off-by: stevehuang52 * update sid Signed-off-by: stevehuang52 * update Signed-off-by: stevehuang52 * add more logging Signed-off-by: stevehuang52 * update cfg Signed-off-by: stevehuang52 * add config Signed-off-by: stevehuang52 * update config Signed-off-by: stevehuang52 * update Signed-off-by: stevehuang52 * fix cfg Signed-off-by: stevehuang52 * update label_model Signed-off-by: stevehuang52 * add config Signed-off-by: stevehuang52 * update cfg Signed-off-by: stevehuang52 * update multi-eval-end Signed-off-by: stevehuang52 * update debug code Signed-off-by: stevehuang52 * update debug code Signed-off-by: stevehuang52 * update error handling Signed-off-by: stevehuang52 * update Signed-off-by: stevehuang52 * update Signed-off-by: stevehuang52 * add detailed logging Signed-off-by: stevehuang52 * fix typo Signed-off-by: stevehuang52 * update Signed-off-by: stevehuang52 * update cfg Signed-off-by: stevehuang52 * add pair dataset Signed-off-by: stevehuang52 * add missing log Signed-off-by: stevehuang52 * fix ecapa attn dim Signed-off-by: stevehuang52 * fix val loss and update cfg Signed-off-by: stevehuang52 * add cfg Signed-off-by: stevehuang52 * update perturb Signed-off-by: stevehuang52 * fix cfg Signed-off-by: stevehuang52 * update dev loss and eval Signed-off-by: stevehuang52 * update Signed-off-by: stevehuang52 * revert changes in modelPT Signed-off-by: stevehuang52 * update Signed-off-by: stevehuang52 * update Signed-off-by: stevehuang52 * update Signed-off-by: stevehuang52 * update Signed-off-by: stevehuang52 * update debug Signed-off-by: stevehuang52 * update Signed-off-by: stevehuang52 * add exception handling Signed-off-by: stevehuang52 * update cfg Signed-off-by: 
stevehuang52 * update Signed-off-by: stevehuang52 * update cfg Signed-off-by: stevehuang52 * update augmentation Signed-off-by: stevehuang52 * update Signed-off-by: stevehuang52 * update Signed-off-by: stevehuang52 * update Signed-off-by: stevehuang52 * update Signed-off-by: stevehuang52 * update cfg Signed-off-by: stevehuang52 * update Signed-off-by: stevehuang52 * update dataset Signed-off-by: stevehuang52 * update data Signed-off-by: stevehuang52 * update cfg Signed-off-by: stevehuang52 * update cfg Signed-off-by: stevehuang52 * update Signed-off-by: stevehuang52 * update cfg Signed-off-by: stevehuang52 * update Signed-off-by: stevehuang52 * fix dataset Signed-off-by: stevehuang52 * update masking to skip short audios Signed-off-by: stevehuang52 * clean up Signed-off-by: stevehuang52 * clean up Signed-off-by: stevehuang52 * update dataset for exception handling Signed-off-by: stevehuang52 * update for exception handling Signed-off-by: stevehuang52 * update Signed-off-by: stevehuang52 * update Signed-off-by: stevehuang52 * update Signed-off-by: stevehuang52 * update Signed-off-by: stevehuang52 * update Signed-off-by: stevehuang52 * update misc Signed-off-by: stevehuang52 * update dataset Signed-off-by: stevehuang52 * update data Signed-off-by: stevehuang52 * update Signed-off-by: stevehuang52 * update Signed-off-by: stevehuang52 * add lhotse support Signed-off-by: monica-sekoyan * fix lhotse dataloader Signed-off-by: Monica-Sekoyan * modified get_item Signed-off-by: Monica-Sekoyan * fix lhotse dataloader Signed-off-by: Monica Sekoyan * small fix Signed-off-by: Monica-Sekoyan * update model utils Signed-off-by: stevehuang52 * change binomial for overlap masking Signed-off-by: Monica-Sekoyan * fix Signed-off-by: Monica-Sekoyan * reverse debug changes Signed-off-by: Monica-Sekoyan * clean up Signed-off-by: stevehuang52 * update masking Signed-off-by: stevehuang52 * fix lhotse Signed-off-by: stevehuang52 * update Signed-off-by: stevehuang52 * update logging Signed-off-by: stevehuang52 * add logging Signed-off-by: stevehuang52 * update multispeaker aug Signed-off-by: stevehuang52 * update multispeaker aug Signed-off-by: stevehuang52 * update cfg Signed-off-by: stevehuang52 * update Signed-off-by: stevehuang52 * update cfg Signed-off-by: stevehuang52 * upload Signed-off-by: stevehuang52 * clean up Signed-off-by: stevehuang52 * clean up Signed-off-by: stevehuang52 * fix for codeQL Signed-off-by: stevehuang52 * refactor Signed-off-by: stevehuang52 * refactor Signed-off-by: stevehuang52 * refactor Signed-off-by: stevehuang52 * refactor and add docs Signed-off-by: stevehuang52 * refactor and clean up Signed-off-by: stevehuang52 * Apply isort and black reformatting Signed-off-by: stevehuang52 * refactor Signed-off-by: stevehuang52 * refactor and update Signed-off-by: stevehuang52 * Apply isort and black reformatting Signed-off-by: stevehuang52 * refactor and clean up Signed-off-by: stevehuang52 * clean up Signed-off-by: stevehuang52 * add tests and metrics Signed-off-by: stevehuang52 * update docstring Signed-off-by: stevehuang52 * fix for tests Signed-off-by: stevehuang52 --------- Signed-off-by: stevehuang52 Signed-off-by: monica-sekoyan Signed-off-by: Monica-Sekoyan Signed-off-by: Monica Sekoyan Signed-off-by: stevehuang52 Signed-off-by: He Huang (Steve) <105218074+stevehuang52@users.noreply.github.com> Co-authored-by: monica-sekoyan Co-authored-by: stevehuang52 --- .../fast-conformer_transformer.yaml | 4 +- .../nest_ecapa_tdnn_small.yaml | 185 +++++ .../multi_layer_feat/nest_titanet_small.yaml 
| 238 ++++++ .../conf/ssl/nest/nest_fast-conformer.yaml | 254 +++++++ examples/asr/speech_pretraining/README.md | 30 +- .../speech_classification_mfa_train.py | 108 +++ .../masked_token_pred_pretrain.py | 63 ++ .../speech_translation/translate_speech.py | 7 +- examples/slu/speech_intent_slot/README.md | 4 +- .../fastconformer_transformer_large_bpe.yaml | 221 ++++++ ...lot_eval.py => speech_intent_slot_eval.py} | 5 +- ...t_train.py => speech_intent_slot_train.py} | 16 +- .../recognition/conf/titanet-large.yaml | 1 + nemo/collections/asr/data/audio_to_label.py | 100 +++ nemo/collections/asr/data/audio_to_text.py | 20 + nemo/collections/asr/data/ssl_dataset.py | 704 ++++++++++++++++++ nemo/collections/asr/losses/__init__.py | 2 +- nemo/collections/asr/losses/ssl_losses/mlm.py | 79 +- nemo/collections/asr/models/__init__.py | 6 +- .../asr/models/configs/asr_models_config.py | 19 +- .../asr/models/hybrid_rnnt_ctc_models.py | 3 + nemo/collections/asr/models/label_models.py | 160 +++- nemo/collections/asr/models/ssl_models.py | 442 ++++++++++- nemo/collections/asr/modules/__init__.py | 6 + nemo/collections/asr/modules/conv_asr.py | 50 +- .../asr/modules/ssl_modules/__init__.py | 21 + .../asr/modules/ssl_modules/augmentation.py | 290 ++++++++ .../asr/modules/ssl_modules/masking.py | 199 +++++ .../modules/ssl_modules/multi_layer_feat.py | 206 +++++ .../ssl_modules/multi_softmax_decoder.py | 84 +++ .../asr/modules/ssl_modules/quantizers.py | 166 +++++ .../asr/parts/mixins/transcription.py | 22 +- .../asr/parts/preprocessing/perturb.py | 182 +++-- .../asr/parts/preprocessing/segment.py | 29 +- .../asr/parts/submodules/ssl_quantizers.py | 10 +- nemo/collections/common/data/lhotse/cutset.py | 14 +- nemo/collections/common/data/utils.py | 37 + .../common/parts/multi_layer_perceptron.py | 27 +- .../common/parts/preprocessing/collections.py | 8 +- .../speech_llm/models/modular_models.py | 4 +- .../speech_llm/models/modular_t5_models.py | 6 +- nemo/core/classes/dataset.py | 4 - nemo/core/classes/modelPT.py | 1 - nemo/core/neural_types/elements.py | 27 +- nemo/utils/model_utils.py | 45 +- tests/collections/asr/test_ssl_models.py | 213 +++++- .../test_data_utils.py} | 6 +- 47 files changed, 4082 insertions(+), 246 deletions(-) create mode 100644 examples/asr/conf/ssl/nest/multi_layer_feat/nest_ecapa_tdnn_small.yaml create mode 100644 examples/asr/conf/ssl/nest/multi_layer_feat/nest_titanet_small.yaml create mode 100644 examples/asr/conf/ssl/nest/nest_fast-conformer.yaml create mode 100644 examples/asr/speech_pretraining/downstream/speech_classification_mfa_train.py create mode 100644 examples/asr/speech_pretraining/masked_token_pred_pretrain.py create mode 100644 examples/slu/speech_intent_slot/configs/fastconformer_transformer_large_bpe.yaml rename examples/slu/speech_intent_slot/{run_speech_intent_slot_eval.py => speech_intent_slot_eval.py} (96%) rename examples/slu/speech_intent_slot/{run_speech_intent_slot_train.py => speech_intent_slot_train.py} (86%) create mode 100644 nemo/collections/asr/data/ssl_dataset.py create mode 100644 nemo/collections/asr/modules/ssl_modules/__init__.py create mode 100644 nemo/collections/asr/modules/ssl_modules/augmentation.py create mode 100644 nemo/collections/asr/modules/ssl_modules/masking.py create mode 100644 nemo/collections/asr/modules/ssl_modules/multi_layer_feat.py create mode 100644 nemo/collections/asr/modules/ssl_modules/multi_softmax_decoder.py create mode 100644 nemo/collections/asr/modules/ssl_modules/quantizers.py create mode 100644 
nemo/collections/common/data/utils.py rename tests/collections/{asr/utils/test_transcription_move_to_device.py => common/test_data_utils.py} (82%) diff --git a/examples/asr/conf/speech_translation/fast-conformer_transformer.yaml b/examples/asr/conf/speech_translation/fast-conformer_transformer.yaml index 6b6dbf129c54..6289d1c42298 100644 --- a/examples/asr/conf/speech_translation/fast-conformer_transformer.yaml +++ b/examples/asr/conf/speech_translation/fast-conformer_transformer.yaml @@ -21,8 +21,8 @@ model: log_prediction: true # enables logging sample predictions in the output during training train_ds: - is_tarred: true - tarred_audio_filepaths: ??? + is_tarred: false + tarred_audio_filepaths: null manifest_filepath: ??? sample_rate: 16000 shuffle: false diff --git a/examples/asr/conf/ssl/nest/multi_layer_feat/nest_ecapa_tdnn_small.yaml b/examples/asr/conf/ssl/nest/multi_layer_feat/nest_ecapa_tdnn_small.yaml new file mode 100644 index 000000000000..e78588ae5bdf --- /dev/null +++ b/examples/asr/conf/ssl/nest/multi_layer_feat/nest_ecapa_tdnn_small.yaml @@ -0,0 +1,185 @@ +# This is an example config that uses a NEST model as feature extractors, by using multi-layer feature aggregation. +# The major modification is to replace `model.preprocessor` with the one in this config. + +name: "NEST_MFA_Tune_ECAPA_TDNN" + +model: + sample_rate: 16000 + + train_ds: + manifest_filepath: ??? + sample_rate: ${model.sample_rate} + labels: null + batch_size: 64 + num_workers: 8 + shuffle: True + augmentor: + noise: + manifest_path: ??? + prob: 0.5 + min_snr_db: 0 + max_snr_db: 15 + + speed: + prob: 0.5 + sr: ${model.sample_rate} + resample_type: 'kaiser_fast' + min_speed_rate: 0.95 + max_speed_rate: 1.05 + + impulse: + prob: 0.5 + manifest_path: ??? + + validation_ds: + manifest_filepath: ??? 
+ sample_rate: ${model.sample_rate} + labels: null + batch_size: 128 + num_workers: 8 + shuffle: False + + # Overwrite the original 'preprocessor' and replace it with multi-layer feature aggregation from ConformerEncoder + preprocessor: + _target_: nemo.collections.asr.modules.ssl_modules.multi_layer_feat.ConformerMultiLayerFeaturePreprocessor + layer_idx_list: [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16] + freeze_encoder: True + aggregator: + _target_: nemo.collections.asr.modules.ssl_modules.multi_layer_feat.Aggregator + mode: "weighted_sum" + weights: null + layer_idx_list: ${model.preprocessor.layer_idx_list} + + # the actual preprocessor to use + preprocessor: + _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor + normalize: "per_feature" + window_size: 0.025 + sample_rate: ${model.sample_rate} + window_stride: 0.01 + window: "hann" + features: 80 + n_fft: 512 + frame_splicing: 1 + dither: 0.00001 + stft_conv: false + + spec_augment: + _target_: nemo.collections.asr.modules.SpectrogramAugmentation + freq_masks: 3 + freq_width: 4 + time_masks: 5 + time_width: 0.05 + + # this has to match with the ConformerEncoder config in NEST pretrain config + encoder: + _target_: nemo.collections.asr.modules.ConformerEncoder + feat_in: ${model.preprocessor.preprocessor.features} + feat_out: -1 # you may set it if you need different output size other than the default d_model + n_layers: 17 + d_model: 512 + + # Sub-sampling params + subsampling: dw_striding # vggnet, striding, stacking or stacking_norm, dw_striding + subsampling_factor: 8 # must be power of 2 for striding and vggnet + subsampling_conv_channels: 256 # -1 sets it to d_model + causal_downsampling: false + + # Feed forward module's params + ff_expansion_factor: 4 + + # Multi-headed Attention Module's params + self_attention_model: rel_pos # rel_pos or abs_pos + n_heads: 8 # may need to be lower for smaller d_models + # [left, right] specifies the number of steps to be seen from left and right of each step in self-attention + att_context_size: [-1, -1] # -1 means unlimited context + att_context_style: regular # regular or chunked_limited + xscaling: true # scales up the input embeddings by sqrt(d_model) + untie_biases: true # unties the biases of the TransformerXL layers + pos_emb_max_len: 5000 + + # Convolution module's params + conv_kernel_size: 9 + conv_norm_type: 'batch_norm' # batch_norm or layer_norm or groupnormN (N specifies the number of groups) + # conv_context_size can be"causal" or a list of two integers while conv_context_size[0]+conv_context_size[1]+1==conv_kernel_size + # null means [(kernel_size-1)//2, (kernel_size-1)//2], and 'causal' means [(kernel_size-1), 0] + conv_context_size: null + + ### regularization + dropout: 0.1 # The dropout used in most of the Conformer Modules + dropout_pre_encoder: 0.1 # The dropout used before the encoder + dropout_emb: 0.0 # The dropout used for embeddings + dropout_att: 0.1 # The dropout for multi-headed attention modules + + # set to non-zero to enable stochastic depth + stochastic_depth_drop_prob: 0.0 + stochastic_depth_mode: linear # linear or uniform + stochastic_depth_start_layer: 1 + + encoder: + _target_: nemo.collections.asr.modules.ECAPAEncoder + feat_in: ${model.preprocessor.encoder.d_model} + filters: [512,512,512,512,1536] + kernel_sizes: [5,3,3,3,1] + dilations: [1,1,1,1,1] + scale: 8 + + decoder: + _target_: nemo.collections.asr.modules.SpeakerDecoder + feat_in: 1536 + num_classes: 7205 + pool_mode: 'attention' #xvector,tap or attention + 
attention_channels: 128 + emb_sizes: 192 + + loss: + _target_: nemo.collections.asr.losses.angularloss.AngularSoftmaxLoss # you could also use cross-entrophy loss + scale: 30 + margin: 0.2 + + optim: + name: adamw + lr: 0.001 + weight_decay: 0.0002 + + # scheduler setup + sched: + name: CosineAnnealing + warmup_ratio: 0.1 + min_lr: 0.00001 + +trainer: + devices: -1 # number of gpus (trained on four nodes - each node has 8 gpus) + max_epochs: 250 + max_steps: -1 # computed at runtime if not set + num_nodes: 1 + accelerator: gpu + strategy: ddp + accumulate_grad_batches: 1 + deterministic: False + enable_checkpointing: False + logger: False + log_every_n_steps: 1 # Interval of logging. + val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations + gradient_clip_val: 1.0 + +exp_manager: + exp_dir: null + name: ${name} + create_tensorboard_logger: True + create_checkpoint_callback: True + checkpoint_callback_params: + # in case of multiple validation sets, first one is used + monitor: "val_loss" + mode: "min" + save_top_k: 1 + + # you need to set these two to True to continue the training + resume_if_exists: true + resume_ignore_no_checkpoint: true + + # You may use this section to create a W&B logger + create_wandb_logger: false + wandb_logger_kwargs: + name: null + project: null diff --git a/examples/asr/conf/ssl/nest/multi_layer_feat/nest_titanet_small.yaml b/examples/asr/conf/ssl/nest/multi_layer_feat/nest_titanet_small.yaml new file mode 100644 index 000000000000..8dc0f48c5011 --- /dev/null +++ b/examples/asr/conf/ssl/nest/multi_layer_feat/nest_titanet_small.yaml @@ -0,0 +1,238 @@ +# This is an example config that uses a NEST model as feature extractors, by using multi-layer feature aggregation. +# The major modification is to replace `model.preprocessor` with the one in this config. + +name: "NEST_MFA_Tune_TitaNet" + +model: + sample_rate: 16000 + + train_ds: + manifest_filepath: ??? + sample_rate: ${model.sample_rate} + labels: null + batch_size: 64 + num_workers: 8 + shuffle: True + augmentor: + noise: + manifest_path: ??? + prob: 0.5 + min_snr_db: 0 + max_snr_db: 15 + + speed: + prob: 0.5 + sr: ${model.sample_rate} + resample_type: 'kaiser_fast' + min_speed_rate: 0.95 + max_speed_rate: 1.05 + + impulse: + prob: 0.5 + manifest_path: ??? + + validation_ds: + manifest_filepath: ??? 
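The `ConformerMultiLayerFeaturePreprocessor` and `Aggregator` configured in these NEST multi-layer-feature configs collect hidden states from the encoder layers listed in `layer_idx_list` and combine them; with `mode: "weighted_sum"` and `weights: null`, the combination weights are learned. A minimal, hypothetical PyTorch sketch of that weighted-sum idea (illustrative only, not the NeMo `Aggregator` module):

```python
import torch
import torch.nn as nn

class WeightedSumAggregator(nn.Module):
    """Combine per-layer features with learnable softmax weights (sketch)."""

    def __init__(self, num_layers: int):
        super().__init__()
        # one scalar weight per selected layer; uniform after softmax at init
        self.layer_weights = nn.Parameter(torch.zeros(num_layers))

    def forward(self, layer_feats: list) -> torch.Tensor:
        # layer_feats: list of [B, D, T] tensors, one per entry in layer_idx_list
        w = torch.softmax(self.layer_weights, dim=0)
        stacked = torch.stack(layer_feats, dim=0)           # [L, B, D, T]
        return (w.view(-1, 1, 1, 1) * stacked).sum(dim=0)   # [B, D, T]
```

With the encoder frozen (`freeze_encoder: True`), only these per-layer weights and the downstream head are trained in this setup.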
+ sample_rate: ${model.sample_rate} + labels: null + batch_size: 128 + num_workers: 8 + shuffle: False + + # Overwrite the original 'preprocessor' and replace it with multi-layer feature aggregation from ConformerEncoder + preprocessor: + _target_: nemo.collections.asr.modules.ssl_modules.multi_layer_feat.ConformerMultiLayerFeaturePreprocessor + layer_idx_list: [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16] + freeze_encoder: True + aggregator: + _target_: nemo.collections.asr.modules.ssl_modules.multi_layer_feat.Aggregator + mode: "weighted_sum" + weights: null + layer_idx_list: ${model.preprocessor.layer_idx_list} + + # the actual preprocessor to use + preprocessor: + _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor + normalize: "per_feature" + window_size: 0.025 + sample_rate: ${model.sample_rate} + window_stride: 0.01 + window: "hann" + features: 80 + n_fft: 512 + frame_splicing: 1 + dither: 0.00001 + stft_conv: false + + spec_augment: + _target_: nemo.collections.asr.modules.SpectrogramAugmentation + freq_masks: 3 + freq_width: 4 + time_masks: 5 + time_width: 0.05 + + # this has to match with the ConformerEncoder config in NEST pretrain config + encoder: + _target_: nemo.collections.asr.modules.ConformerEncoder + feat_in: ${model.preprocessor.preprocessor.features} + feat_out: -1 # you may set it if you need different output size other than the default d_model + n_layers: 17 + d_model: 512 + + # Sub-sampling params + subsampling: dw_striding # vggnet, striding, stacking or stacking_norm, dw_striding + subsampling_factor: 8 # must be power of 2 for striding and vggnet + subsampling_conv_channels: 256 # -1 sets it to d_model + causal_downsampling: false + + # Feed forward module's params + ff_expansion_factor: 4 + + # Multi-headed Attention Module's params + self_attention_model: rel_pos # rel_pos or abs_pos + n_heads: 8 # may need to be lower for smaller d_models + # [left, right] specifies the number of steps to be seen from left and right of each step in self-attention + att_context_size: [-1, -1] # -1 means unlimited context + att_context_style: regular # regular or chunked_limited + xscaling: true # scales up the input embeddings by sqrt(d_model) + untie_biases: true # unties the biases of the TransformerXL layers + pos_emb_max_len: 5000 + + # Convolution module's params + conv_kernel_size: 9 + conv_norm_type: 'batch_norm' # batch_norm or layer_norm or groupnormN (N specifies the number of groups) + # conv_context_size can be"causal" or a list of two integers while conv_context_size[0]+conv_context_size[1]+1==conv_kernel_size + # null means [(kernel_size-1)//2, (kernel_size-1)//2], and 'causal' means [(kernel_size-1), 0] + conv_context_size: null + + ### regularization + dropout: 0.1 # The dropout used in most of the Conformer Modules + dropout_pre_encoder: 0.1 # The dropout used before the encoder + dropout_emb: 0.0 # The dropout used for embeddings + dropout_att: 0.1 # The dropout for multi-headed attention modules + + # set to non-zero to enable stochastic depth + stochastic_depth_drop_prob: 0.0 + stochastic_depth_mode: linear # linear or uniform + stochastic_depth_start_layer: 1 + + encoder: + _target_: nemo.collections.asr.modules.ConvASREncoder + feat_in: ${model.preprocessor.encoder.d_model} + activation: relu + conv_mask: true + + jasper: + - filters: ${model.model_defaults.filters} + repeat: 1 + kernel: [3] + stride: [1] + dilation: [1] + dropout: 0.0 + residual: false + separable: ${model.model_defaults.separable} + se: ${model.model_defaults.se} + 
se_context_size: ${model.model_defaults.se_context_size} + + - filters: ${model.model_defaults.filters} + repeat: ${model.model_defaults.repeat} + kernel: [7] + stride: [1] + dilation: [1] + dropout: ${model.model_defaults.dropout} + residual: true + separable: ${model.model_defaults.separable} + se: ${model.model_defaults.se} + se_context_size: ${model.model_defaults.se_context_size} + + - filters: ${model.model_defaults.filters} + repeat: ${model.model_defaults.repeat} + kernel: [11] + stride: [1] + dilation: [1] + dropout: ${model.model_defaults.dropout} + residual: true + separable: ${model.model_defaults.separable} + se: ${model.model_defaults.se} + se_context_size: ${model.model_defaults.se_context_size} + + - filters: ${model.model_defaults.filters} + repeat: ${model.model_defaults.repeat} + kernel: [15] + stride: [1] + dilation: [1] + dropout: ${model.model_defaults.dropout} + residual: true + separable: ${model.model_defaults.separable} + se: ${model.model_defaults.se} + se_context_size: ${model.model_defaults.se_context_size} + + - filters: &enc_feat_out 3072 + repeat: 1 + kernel: [1] + stride: [1] + dilation: [1] + dropout: 0.0 + residual: false + separable: ${model.model_defaults.separable} + se: ${model.model_defaults.se} + se_context_size: ${model.model_defaults.se_context_size} + + decoder: + _target_: nemo.collections.asr.modules.SpeakerDecoder + feat_in: *enc_feat_out + num_classes: 7205 + pool_mode: 'attention' + emb_sizes: 192 + + loss: + _target_: nemo.collections.asr.losses.angularloss.AngularSoftmaxLoss # you could also use cross-entrophy loss + scale: 30 + margin: 0.2 + + optim: + name: adamw + lr: 0.005 + weight_decay: 0.0002 + + # scheduler setup + sched: + name: CosineAnnealing + warmup_ratio: 0.1 + min_lr: 1e-6 + +trainer: + devices: -1 # number of gpus (trained on four nodes - each node has 8 gpus) + max_epochs: 250 + max_steps: -1 # computed at runtime if not set + num_nodes: 1 + accelerator: gpu + strategy: ddp + accumulate_grad_batches: 1 + deterministic: False + enable_checkpointing: False + logger: False + log_every_n_steps: 1 # Interval of logging. + val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations + gradient_clip_val: 1.0 + +exp_manager: + exp_dir: null + name: ${name} + create_tensorboard_logger: True + create_checkpoint_callback: True + checkpoint_callback_params: + # in case of multiple validation sets, first one is used + monitor: "val_loss" + mode: "min" + save_top_k: 1 + + # you need to set these two to True to continue the training + resume_if_exists: true + resume_ignore_no_checkpoint: true + + # You may use this section to create a W&B logger + create_wandb_logger: false + wandb_logger_kwargs: + name: null + project: null diff --git a/examples/asr/conf/ssl/nest/nest_fast-conformer.yaml b/examples/asr/conf/ssl/nest/nest_fast-conformer.yaml new file mode 100644 index 000000000000..054c66830d65 --- /dev/null +++ b/examples/asr/conf/ssl/nest/nest_fast-conformer.yaml @@ -0,0 +1,254 @@ +# This config contains the default values for self-supervised pre-training of a FastConformer model +# +# Here are the recommended configs for different variants of FastConformer, other parameters are the same as in this config file. 
+# +# +--------------+---------+---------+----------+----------------+--------------+------------+ +# | Model | d_model | n_heads | n_layers |conv_kernel_size| weight_decay | xscaling | +# +==============+=========+========+===========+================+==============+============+ +# | Small (14M) | 176 | 4 | 16 | 9 | 0.0 | True | +# +--------------+---------+--------+-----------+----------------+--------------+------------+ +# | Medium (32M) | 256 | 4 | 16 | 9 | 1e-3 | True | +# +--------------+---------+--------+-----------+----------------+--------------+------------+ +# | Large (120M) | 512 | 8 | 17 | 9 | 1e-3 | True | +# +--------------+---------+--------+-----------+----------------+--------------+------------+ +# | XLarge (616M)| 1024 | 8 | 24 | 9 | 1e-3 | False | +# +--------------+---------+--------+-----------+----------------+--------------+------------+ +# | XXLarge(1.2B)| 1024 | 8 | 42 | 9 | 1e-3 | False | +# +--------------------------------------------------------------+--------------+------------+ + + +name: "NEST-FastConformer-SSL" + +model: + sample_rate: 16000 + num_classes: 8192 + num_books: 1 + code_dim: 16 + squeeze_single: false + mask_position: pre_conv # position to apply masking, before or after conv subsampling, choices in ['pre_conv', 'post_conv'] + + train_ds: + manifest_filepath: ??? + noise_manifest: null + sample_rate: ${model.sample_rate} + batch_size: 8 # you may increase batch_size if your memory allows + shuffle: true + num_workers: 8 + pin_memory: true + use_start_end_token: false + trim_silence: false + max_duration: 60.0 + min_duration: 1.0 + drop_last: true + is_concat: false + concat_sampling_technique: temperature + concat_sampling_temperature: 1.0 + # tarred datasets + is_tarred: false + tarred_audio_filepaths: null + shuffle_n: 2048 + # bucketing params + bucketing_strategy: "synced_randomized" + bucketing_batch_size: null + # batch augmentation + batch_augmentor: + _target_: nemo.collections.asr.modules.ssl_modules.MultiSpeakerNoiseAugmentation + prob: 0.5 # prob of activating the augmentation + noise_ratio: 0.5 # prob of applying noise aug, otherwise apply speech augmentation + min_r_speech: -5.0 # min energy ratio when applying speech augmentation + max_r_speech: 5.0 # max energy ratio when applying speech augmentation + min_r_noise: -5.0 # min energy ratio when applying noise augmentation + max_r_noise: 20.0 # max energy ratio when applying noise augmentation + min_mix_rate: 0.5 # min ratio of the input audio that would be augmented + max_mix_rate: 0.5 # max ratio of the input audio that would be augmented + min_num_segments: 1 # min num of segments that consititute the noise audio + max_num_segments: 1 # max num of segments that consititute the noise audio + min_num_speakers: 1 # min num of extra speakers to add + max_num_speakers: 1 # max num of extra speakers to add + + validation_ds: + manifest_filepath: ??? 
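The `batch_augmentor` (`MultiSpeakerNoiseAugmentation`) configured in this file mixes noise or extra-speaker speech into a portion of each utterance at a sampled energy ratio in dB. A rough, hypothetical sketch of mixing one secondary signal at a target ratio over a fraction of the input, illustrating the `min/max_r_*` and `min/max_mix_rate` knobs (not the NeMo implementation):

```python
import torch
import torch.nn.functional as F

def mix_at_energy_ratio(speech: torch.Tensor, noise: torch.Tensor,
                        r_db: float, mix_rate: float) -> torch.Tensor:
    """Mix `noise` into the first `mix_rate` fraction of `speech` at `r_db` dB."""
    seg_len = int(mix_rate * speech.shape[0])
    seg = noise[:seg_len]
    if seg.shape[0] < seg_len:
        seg = F.pad(seg, (0, seg_len - seg.shape[0]))
    s_energy = speech[:seg_len].pow(2).mean().clamp_min(1e-10)
    n_energy = seg.pow(2).mean().clamp_min(1e-10)
    # choose a gain so that 10 * log10(s_energy / (gain^2 * n_energy)) == r_db
    gain = torch.sqrt(s_energy / (n_energy * (10.0 ** (r_db / 10.0))))
    mixed = speech.clone()
    mixed[:seg_len] = mixed[:seg_len] + gain * seg
    return mixed
```

For example, `mix_at_energy_ratio(speech, other_speaker, r_db=0.0, mix_rate=0.5)` overlays another speaker on half of the utterance at roughly equal energy.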
+ noise_manifest: null + sample_rate: ${model.sample_rate} + batch_size: 8 # you may increase batch_size if your memory allows + shuffle: false + num_workers: 8 + pin_memory: true + use_start_end_token: false + max_duration: 60.0 + min_duration: 1.0 + # batch augmentation + batch_augmentor: + _target_: nemo.collections.asr.modules.ssl_modules.MultiSpeakerNoiseAugmentation + prob: 0.5 + noise_ratio: 0.5 + min_r_speech: -5.0 + max_r_speech: 5.0 + min_r_noise: -5.0 + max_r_noise: 20.0 + min_mix_rate: 0.5 + max_mix_rate: 0.5 + min_num_segments: 1 # min num of segments that consititute the noise audio + max_num_segments: 1 # max num of segments that consititute the noise audio + min_num_speakers: 1 # min num of extra speakers to add + max_num_speakers: 1 # max num of extra speakers to add + + preprocessor: + _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor + sample_rate: ${model.sample_rate} + normalize: "per_feature" + window_size: 0.025 + window_stride: 0.01 + window: "hann" + features: 80 + n_fft: 512 + log: true + frame_splicing: 1 + dither: 0.00001 + pad_to: 16 + pad_value: 0.0 + + # spec_augment is not actually used, just to avoid init error + spec_augment: + _target_: nemo.collections.asr.modules.SpectrogramAugmentation + freq_masks: 0 # set to zero to disable it + time_masks: 0 # set to zero to disable it + freq_width: 27 + time_width: 0.05 + + masking: + _target_: nemo.collections.asr.modules.RandomBlockMasking + block_size: 40 # for pre_conv masking, 10ms per frame, 400ms per block with block_size=40 + mask_prob: 0.01 # for allow_overlap=True, this means the mask prob for each frame; otherwise it means the overall masked proportion + feat_in: ${model.preprocessor.features} + freeze: true + allow_overlap: true + + quantizer: + _target_: nemo.collections.asr.modules.RandomProjectionVectorQuantizer + feat_in: ${model.preprocessor.features} + code_dim: ${model.code_dim} + num_books: ${model.num_books} + num_classes: ${model.num_classes} + dist_fn: "l2" # choices=["l2", "cosine"] + freeze: true + squeeze_single: ${model.squeeze_single} + combine_time_steps: ${model.encoder.subsampling_factor} # conformer sub-sampling ratio + + encoder: + _target_: nemo.collections.asr.modules.ConformerEncoder + feat_in: ${model.preprocessor.features} + feat_out: -1 # you may set it if you need different output size other than the default d_model + n_layers: 17 + d_model: 512 + use_bias: True # whether to apply bias in the feedforward, MHA and convolution modules + + # Sub-sampling params + subsampling: dw_striding # vggnet, striding, stacking or stacking_norm, dw_striding + subsampling_factor: 8 # must be power of 2 for striding and vggnet + subsampling_conv_channels: 256 # -1 sets it to d_model + causal_downsampling: false + + # Feed forward module's params + ff_expansion_factor: 4 + + # Multi-headed Attention Module's params + self_attention_model: rel_pos # rel_pos or abs_pos + n_heads: 8 # may need to be lower for smaller d_models + # [left, right] specifies the number of steps to be seen from left and right of each step in self-attention + att_context_size: [-1, -1] # -1 means unlimited context + att_context_style: regular # regular or chunked_limited + xscaling: true # scales up the input embeddings by sqrt(d_model) + untie_biases: true # unties the biases of the TransformerXL layers + pos_emb_max_len: 5000 + + # Convolution module's params + conv_kernel_size: 9 + conv_norm_type: 'batch_norm' # batch_norm or layer_norm or groupnormN (N specifies the number of groups) + # 
conv_context_size can be"causal" or a list of two integers while conv_context_size[0]+conv_context_size[1]+1==conv_kernel_size + # null means [(kernel_size-1)//2, (kernel_size-1)//2], and 'causal' means [(kernel_size-1), 0] + conv_context_size: null + + ### regularization + dropout: 0.1 # The dropout used in most of the Conformer Modules + dropout_pre_encoder: 0.1 # The dropout used before the encoder + dropout_emb: 0.0 # The dropout used for embeddings + dropout_att: 0.1 # The dropout for multi-headed attention modules + + # set to non-zero to enable stochastic depth + stochastic_depth_drop_prob: 0.0 + stochastic_depth_mode: linear # linear or uniform + stochastic_depth_start_layer: 1 + + decoder: + _target_: nemo.collections.asr.modules.MultiSoftmaxDecoder + feat_in: ${model.encoder.d_model} + num_classes: ${model.num_classes} + num_decoders: ${model.num_books} + squeeze_single: ${model.squeeze_single} + use_bias: true + + loss: + _target_: nemo.collections.asr.losses.MultiMLMLoss + combine_time_steps: ${model.encoder.subsampling_factor} # conformer sub-sampling ratio for 'pre_conv', 1 for 'post_conv' + mask_threshold: 0.8 + num_decoders: ${model.num_books} + squeeze_single: ${model.squeeze_single} + + optim: + name: adamw + lr: 5.0 + # optimizer arguments + betas: [0.9, 0.98] + weight_decay: 1e-3 + + # scheduler setup + sched: + name: NoamAnnealing + d_model: ${model.encoder.d_model} + # scheduler config override + warmup_steps: 25000 + warmup_ratio: null + min_lr: 1e-6 + +trainer: + devices: -1 # number of GPUs, -1 would use all available GPUs + num_nodes: 1 + max_epochs: 1000 + max_steps: -1 # computed at runtime if not set + val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations + accelerator: auto + strategy: ddp_find_unused_parameters_true + accumulate_grad_batches: 1 + gradient_clip_val: 0.0 + precision: 32 # 16, 32, or bf16 + log_every_n_steps: 10 # Interval of logging. + enable_progress_bar: True + num_sanity_val_steps: 0 # number of steps to perform validation steps for sanity check the validation process before starting the training, setting to 0 disables it + check_val_every_n_epoch: 1 # number of evaluations on validation every n epochs + sync_batchnorm: true + enable_checkpointing: False # Provided by exp_manager + logger: false # Provided by exp_manager + benchmark: false # needs to be false for models with variable-length speech input as it slows down training + +exp_manager: + exp_dir: null + name: ${name} + # filename: '${name}--{${exp_manager.checkpoint_callback_params.monitor}:.3f}-{step}-{epoch}' + create_tensorboard_logger: true + create_checkpoint_callback: true + checkpoint_callback_params: + # in case of multiple validation sets, first one is used + monitor: "val_loss" + mode: "min" + save_top_k: 1 + + # you need to set these two to True to continue the training + resume_if_exists: true + resume_ignore_no_checkpoint: true + + # You may use this section to create a W&B logger + create_wandb_logger: false + wandb_logger_kwargs: + name: null + project: null diff --git a/examples/asr/speech_pretraining/README.md b/examples/asr/speech_pretraining/README.md index 75ae9e58c79b..aeafcf69292b 100644 --- a/examples/asr/speech_pretraining/README.md +++ b/examples/asr/speech_pretraining/README.md @@ -1,27 +1,7 @@ -# Speech Pre-training via Self Supervised Learning +# Speech Self-Supervised Learning -This directory contains example scripts to train ASR models using various Self Supervised Losses. 
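In the NEST pretraining config above, the frozen `RandomProjectionVectorQuantizer` turns input features into discrete prediction targets: frames are stacked by `combine_time_steps`, projected with a fixed random matrix to `code_dim`, and assigned to the nearest of `num_classes` random codes. A hypothetical sketch of that idea with `dist_fn: "l2"` (illustrative only, not the NeMo module):

```python
import torch
import torch.nn as nn

class RandomProjectionQuantizerSketch(nn.Module):
    """Frozen random projection + frozen random codebook -> discrete targets."""

    def __init__(self, feat_in: int = 80, code_dim: int = 16,
                 num_classes: int = 8192, combine_time_steps: int = 8):
        super().__init__()
        self.combine = combine_time_steps
        # both tensors are frozen, so the targets stay fixed for the whole run
        self.proj = nn.Parameter(torch.randn(feat_in * combine_time_steps, code_dim),
                                 requires_grad=False)
        self.codebook = nn.Parameter(torch.randn(num_classes, code_dim),
                                     requires_grad=False)

    @torch.no_grad()
    def forward(self, feats: torch.Tensor) -> torch.Tensor:
        # feats: [B, T, feat_in]; stack `combine` frames to match encoder subsampling
        B, T, F = feats.shape
        T = (T // self.combine) * self.combine
        x = feats[:, :T].reshape(B, T // self.combine, F * self.combine)
        z = x @ self.proj                                    # [B, T', code_dim]
        dists = torch.cdist(z, self.codebook.unsqueeze(0))   # L2 distance to every code
        return dists.argmin(dim=-1)                          # token ids: [B, T']
```

The encoder is then trained to predict these token ids at masked positions via the `MultiSoftmaxDecoder` and `MultiMLMLoss` configured above.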
+This directory contains example scripts to self-supervised speech models. -The model's pretrained here can further be finetuned on specific labeled data in further steps. - -# Model execution overview - -The training scripts in this directory execute in the following order. When preparing your own training-from-scratch / fine-tuning scripts, please follow this order for correct training/inference. - -```mermaid - -graph TD - A[Hydra Overrides + Yaml Config] --> B{Config} - B --> |Init| C[Trainer] - C --> D[ExpManager] - B --> D[ExpManager] - C --> E[Model] - B --> |Init| E[Model] - E --> |Constructor| G(Setup Train + Validation Data loaders) - G --> H(Setup Optimization) - H --> I[Maybe init from pretrained] - I --> J["trainer.fit(model)"] - -``` - -During restoration of the model, you may pass the Trainer to the restore_from / from_pretrained call, or set it after the model has been initialized by using `model.set_trainer(Trainer)`. \ No newline at end of file +There are two main types of supported self-supervised learning methods: +- [Wav2vec-BERT](https://arxiv.org/abs/2108.06209): `speech_pre_training.py` +- [NEST](https://arxiv.org/abs/2408.13106): `masked_token_pred_pretrain.py` diff --git a/examples/asr/speech_pretraining/downstream/speech_classification_mfa_train.py b/examples/asr/speech_pretraining/downstream/speech_classification_mfa_train.py new file mode 100644 index 000000000000..3a256c7ab2d3 --- /dev/null +++ b/examples/asr/speech_pretraining/downstream/speech_classification_mfa_train.py @@ -0,0 +1,108 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import OrderedDict + +import pytorch_lightning as pl +import torch +from omegaconf import OmegaConf + +from nemo.collections.asr.models import EncDecSpeakerLabelModel, SpeechEncDecSelfSupervisedModel +from nemo.core.classes.common import typecheck +from nemo.core.config import hydra_runner +from nemo.utils import logging +from nemo.utils.exp_manager import exp_manager + +typecheck.set_typecheck_enabled(enabled=False) + +""" +Example script for training a speech classification model with a self-supervised pre-trained encoder, and +use the SSL encoder for multi-layer feature extraction. 
+ +# Example of training a speaker classification model with a self-supervised pre-trained encoder +```sh +python speech_classification_mfa_train.py \ + # (Optional: --config-path= --config-name=) \ + ++init_from_nemo_model= \ + # or use ++init_from_pretrained_model= \ + model.train_ds.manifest_filepath= \ + model.validation_ds.manifest_filepath= \ + trainer.devices=-1 \ + trainer.accelerator="gpu" \ + strategy="ddp" \ + trainer.max_epochs=100 \ + model.optim.name="adamw" \ + model.optim.lr=0.001 \ + model.optim.betas=[0.9,0.999] \ + model.optim.weight_decay=0.0001 \ + model.optim.sched.warmup_steps=2000 + exp_manager.create_wandb_logger=True \ + exp_manager.wandb_logger_kwargs.name="" \ + exp_manager.wandb_logger_kwargs.project="" +``` + +""" + + +def load_ssl_encoder(model, cfg): + if cfg.get("init_from_ptl_ckpt", None) is not None: + state_dict = torch.load(cfg.init_from_ptl_ckpt, map_location='cpu')['state_dict'] + logging.info(f"Loading encoder from PyTorch Lightning checkpoint: {cfg.init_from_ptl_ckpt}") + elif cfg.get("init_from_nemo_model", None) is not None: + ssl_model = SpeechEncDecSelfSupervisedModel.restore_from(cfg.init_from_nemo_model, map_location='cpu') + state_dict = ssl_model.state_dict() + logging.info(f"Loading encoder from NeMo model: {cfg.init_from_nemo_model}") + elif cfg.get("init_from_pretrained_model", None) is not None: + ssl_model = SpeechEncDecSelfSupervisedModel.from_pretrained(cfg.init_from_pretrained_model, map_location='cpu') + state_dict = ssl_model.state_dict() + logging.info(f"Loading encoder from pretrained model: {cfg.init_from_pretrained_model}") + else: + logging.info("No model checkpoint or pretrained model specified for encoder initialization.") + return model + + encoder_state_dict = OrderedDict() + for key, value in state_dict.items(): + if key.startswith('encoder.'): + encoder_state_dict[f'preprocessor.feature_extractor.{key}'] = value + + model.load_state_dict(encoder_state_dict, strict=False) + logging.info("Loaded ssl encoder state dict.") + + return model + + +@hydra_runner(config_path="../conf/ssl/nest/multi_layer_feat", config_name="nest_ecapa_tdnn_small") +def main(cfg): + + logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}') + trainer = pl.Trainer(**cfg.trainer) + exp_manager(trainer, cfg.get("exp_manager", None)) + + speaker_model = EncDecSpeakerLabelModel(cfg=cfg.model, trainer=trainer) + + if cfg.model.preprocessor.get("encoder", None) is not None: + # multi-layer feature extractor + speaker_model = load_ssl_encoder(speaker_model, cfg) + else: + speaker_model.maybe_init_from_pretrained_checkpoint(cfg) + + trainer.fit(speaker_model) + + if hasattr(cfg.model, 'test_ds') and cfg.model.test_ds.manifest_filepath is not None: + if speaker_model.prepare_test(trainer): + trainer.test(speaker_model) + + +if __name__ == '__main__': + main() diff --git a/examples/asr/speech_pretraining/masked_token_pred_pretrain.py b/examples/asr/speech_pretraining/masked_token_pred_pretrain.py new file mode 100644 index 000000000000..83729dfd9d67 --- /dev/null +++ b/examples/asr/speech_pretraining/masked_token_pred_pretrain.py @@ -0,0 +1,63 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pytorch_lightning as pl +from omegaconf import OmegaConf + +from nemo.collections.asr.models.ssl_models import EncDecDenoiseMaskedTokenPredModel +from nemo.core.config import hydra_runner +from nemo.utils import logging +from nemo.utils.exp_manager import exp_manager + + +""" +# Example of training a self-supervised denoising masksed token prediction model +```sh +python pretrain_masked_token_pred.py \ + # (Optional: --config-path= --config-name=) \ + model.train_ds.manifest_filepath= \ + model.validation_ds.manifest_filepath= \ + trainer.devices=-1 \ + trainer.accelerator="gpu" \ + strategy="ddp" \ + trainer.max_epochs=100 \ + model.optim.name="adamw" \ + model.optim.lr=0.001 \ + model.optim.betas=[0.9,0.999] \ + model.optim.weight_decay=0.0001 \ + model.optim.sched.warmup_steps=2000 + exp_manager.create_wandb_logger=True \ + exp_manager.wandb_logger_kwargs.name="" \ + exp_manager.wandb_logger_kwargs.project="" +``` +""" + + +@hydra_runner(config_path="../conf/ssl/nest", config_name="nest_fast-conformer") +def main(cfg): + logging.info(f"Hydra config: {OmegaConf.to_yaml(cfg)}") + + trainer = pl.Trainer(**cfg.trainer) + exp_manager(trainer, cfg.get("exp_manager", None)) + asr_model = EncDecDenoiseMaskedTokenPredModel(cfg=cfg.model, trainer=trainer) + + # Initialize the weights of the model from another model, if provided via config + asr_model.maybe_init_from_pretrained_checkpoint(cfg) + + trainer.fit(asr_model) + + +if __name__ == "__main__": + main() diff --git a/examples/asr/speech_translation/translate_speech.py b/examples/asr/speech_translation/translate_speech.py index 42394001255f..47717f562774 100644 --- a/examples/asr/speech_translation/translate_speech.py +++ b/examples/asr/speech_translation/translate_speech.py @@ -100,6 +100,9 @@ class TranslationConfig: # if True, will also skip writing anything to the output file return_translations: bool = False + presort_manifest: bool = False # sort manifest by duration before inference + pred_name_postfix: str = "translation" # postfix to add to the audio filename for the output + @hydra_runner(config_name="TranslationConfig", schema=TranslationConfig) def main(cfg: TranslationConfig) -> Union[TranslationConfig, List[str]]: @@ -176,8 +179,8 @@ def main(cfg: TranslationConfig) -> Union[TranslationConfig, List[str]]: # translate audio with torch.amp.autocast(asr_model.device.type, enabled=cfg.amp): with torch.no_grad(): - translations = asr_model.translate( - paths2audio_files=filepaths, + translations = asr_model.transcribe( + audio=filepaths, batch_size=cfg.batch_size, return_hypotheses=return_hypotheses, ) diff --git a/examples/slu/speech_intent_slot/README.md b/examples/slu/speech_intent_slot/README.md index ac11e43f2ace..a91175431513 100644 --- a/examples/slu/speech_intent_slot/README.md +++ b/examples/slu/speech_intent_slot/README.md @@ -59,7 +59,7 @@ Run with the default config that uses ASR-pretrained encoder on NeMo ASR-set 3.0 ```bash DATA_DIR="./slurp_data" EXP_NAME="slurp_conformer_transformer_large" -CUDA_VISIBLE_DEVICES=0 python run_speech_intent_slot_train.py \ +CUDA_VISIBLE_DEVICES=0 python 
speech_intent_slot_train.py \ --config-path="./configs" --config-name=conformer_transformer_large_bpe \ model.train_ds.manifest_filepath="[${DATA_DIR}/train_slu.json,${DATA_DIR}/train_synthetic_slu.json]" \ model.validation_ds.manifest_filepath="${DATA_DIR}/devel_slu.json" \ @@ -88,7 +88,7 @@ CKPT_AVG_DIR="../../../examples/slu/speech_intent_slot/${CKPT_DIR}" python ../../../scripts/checkpoint_averaging/checkpoint_averaging.py $CKPT_AVG_DIR NEMO_MODEL="${CKPT_DIR}/${EXP_NAME}-averaged.nemo" -CUDA_VISIBLE_DEVICES=0 python run_speech_intent_slot_eval.py \ +CUDA_VISIBLE_DEVICES=0 python speech_intent_slot_eval.py \ dataset_manifest="${DATA_DIR}/test_slu.json" \ model_path=$NEMO_MODEL \ batch_size=32 \ diff --git a/examples/slu/speech_intent_slot/configs/fastconformer_transformer_large_bpe.yaml b/examples/slu/speech_intent_slot/configs/fastconformer_transformer_large_bpe.yaml new file mode 100644 index 000000000000..6b700e9001f7 --- /dev/null +++ b/examples/slu/speech_intent_slot/configs/fastconformer_transformer_large_bpe.yaml @@ -0,0 +1,221 @@ +# Example config for speech intent classification and slot filling with FastConformer-Transformer architecture. + +name: "FastConformer-Transformer-BPE" + +pretrained_encoder: + name: stt_en_fastconformer_ctc_large + freeze: false + +model: + sample_rate: 16000 + log_prediction: true # enables logging sample predictions in the output during training + skip_nan_grad: false + + train_ds: + manifest_filepath: ??? + sample_rate: ${model.sample_rate} + batch_size: 16 # 16 for 32GB GPUs + shuffle: true + num_workers: 8 + pin_memory: true + use_start_end_token: true + trim_silence: false + max_duration: 11.0 + min_duration: 0.0 + # tarred datasets + is_tarred: false + tarred_audio_filepaths: null + shuffle_n: 2048 + # bucketing params + bucketing_strategy: "synced_randomized" + bucketing_batch_size: null + + validation_ds: + manifest_filepath: ??? + sample_rate: ${model.sample_rate} + batch_size: 32 # you may increase batch_size if your memory allows + shuffle: false + num_workers: 8 + pin_memory: true + use_start_end_token: true + + test_ds: + manifest_filepath: null + sample_rate: ${model.sample_rate} + batch_size: 32 # you may increase batch_size if your memory allows + shuffle: false + num_workers: 8 + pin_memory: true + use_start_end_token: true + + # you may find more detail on how to train a tokenizer at: /scripts/tokenizers/process_asr_text_tokenizer.py + tokenizer: + dir: ??? 
# path to directory which contains either tokenizer.model (bpe) or vocab.txt (wpe) + type: bpe # Can be either bpe (SentencePiece tokenizer) or wpe (WordPiece tokenizer) + + preprocessor: + _target_: nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor + sample_rate: ${model.sample_rate} + normalize: "per_feature" + window_size: 0.025 + window_stride: 0.01 + window: "hann" + features: 80 + n_fft: 512 + log: true + frame_splicing: 1 + dither: 0.00001 + pad_to: 0 + pad_value: 0.0 + + spec_augment: + _target_: nemo.collections.asr.modules.SpectrogramAugmentation + freq_masks: 2 # set to zero to disable it + # you may use lower time_masks for smaller models to have a faster convergence + time_masks: 10 # set to zero to disable it + freq_width: 27 + time_width: 0.05 + + encoder: + _target_: nemo.collections.asr.modules.ConformerEncoder + feat_in: ${model.preprocessor.features} + feat_out: -1 # you may set it if you need different output size other than the default d_model + n_layers: 18 + d_model: 512 + + # Sub-sampling params + subsampling: dw_striding # vggnet, striding, stacking or stacking_norm, dw_striding + subsampling_factor: 8 # must be power of 2 for striding and vggnet + subsampling_conv_channels: 256 # -1 sets it to d_model + causal_downsampling: false + + # Feed forward module's params + ff_expansion_factor: 4 + + # Multi-headed Attention Module's params + self_attention_model: rel_pos # rel_pos or abs_pos + n_heads: 8 # may need to be lower for smaller d_models + # [left, right] specifies the number of steps to be seen from left and right of each step in self-attention + att_context_size: [-1, -1] # -1 means unlimited context + att_context_style: regular # regular or chunked_limited + xscaling: true # scales up the input embeddings by sqrt(d_model) + untie_biases: true # unties the biases of the TransformerXL layers + pos_emb_max_len: 5000 + + # Convolution module's params + conv_kernel_size: 9 + conv_norm_type: 'batch_norm' # batch_norm or layer_norm or groupnormN (N specifies the number of groups) + # conv_context_size can be"causal" or a list of two integers while conv_context_size[0]+conv_context_size[1]+1==conv_kernel_size + # null means [(kernel_size-1)//2, (kernel_size-1)//2], and 'causal' means [(kernel_size-1), 0] + conv_context_size: null + + ### regularization + dropout: 0.1 # The dropout used in most of the Conformer Modules + dropout_pre_encoder: 0.1 # The dropout used before the encoder + dropout_emb: 0.0 # The dropout used for embeddings + dropout_att: 0.1 # The dropout for multi-headed attention modules + + # set to non-zero to enable stochastic depth + stochastic_depth_drop_prob: 0.0 + stochastic_depth_mode: linear # linear or uniform + stochastic_depth_start_layer: 1 + + embedding: + _target_: nemo.collections.asr.modules.transformer.TransformerEmbedding + vocab_size: -1 + hidden_size: ${model.encoder.d_model} + max_sequence_length: 512 + num_token_types: 1 + embedding_dropout: 0.0 + learn_positional_encodings: false + + decoder: + _target_: nemo.collections.asr.modules.transformer.TransformerDecoder + num_layers: 3 + hidden_size: ${model.encoder.d_model} + inner_size: 2048 + num_attention_heads: 8 + attn_score_dropout: 0.0 + attn_layer_dropout: 0.0 + ffn_dropout: 0.0 + + classifier: + _target_: nemo.collections.common.parts.MultiLayerPerceptron + hidden_size: ${model.encoder.d_model} + num_classes: -1 + num_layers: 1 + activation: 'relu' + log_softmax: true + + loss: + label_smoothing: 0.0 + + sequence_generator: + type: greedy # choices=[greedy, topk, 
beam] + max_sequence_length: ${model.embedding.max_sequence_length} + temperature: 1.0 # for top-k sampling + beam_size: 1 # K for top-k sampling, N for beam search + len_pen: 0 # for beam-search + + optim_param_groups: + encoder: + lr: 0.0002 + + optim: + name: adamw + lr: 0.0003 + # optimizer arguments + betas: [0.9, 0.98] + # less necessity for weight_decay as we already have large augmentations with SpecAug + # you may need weight_decay for large models, stable AMP training, small datasets, or when lower augmentations are used + weight_decay: 0.0 + + # scheduler setup + sched: + name: CosineAnnealing # WarmupAnnealing + warmup_steps: 2000 + warmup_ratio: null + min_lr: 1e-5 + +trainer: + devices: -1 # number of GPUs, -1 would use all available GPUs + num_nodes: 1 + max_epochs: 100 + max_steps: -1 # computed at runtime if not set + val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations + accelerator: auto + strategy: ddp_find_unused_parameters_true + accumulate_grad_batches: 1 + gradient_clip_val: 0.0 + precision: 32 # Should be set to 16 for O1 and O2 to enable the AMP. + log_every_n_steps: 20 # Interval of logging. + num_sanity_val_steps: 0 # number of steps to perform validation steps for sanity check the validation process before starting the training, setting to 0 disables it + check_val_every_n_epoch: 1 # number of evaluations on validation every n epochs + sync_batchnorm: true + enable_checkpointing: False # Provided by exp_manager + logger: false # Provided by exp_manager + benchmark: false # needs to be false for models with variable-length speech input as it slows down training + +exp_manager: + exp_dir: null + name: ${name} + create_tensorboard_logger: true + create_checkpoint_callback: true + checkpoint_callback_params: + # in case of multiple validation sets, first one is used + monitor: "val_wer" + mode: "min" + save_top_k: 5 + always_save_nemo: true # saves the checkpoints as nemo files instead of PTL checkpoints + save_best_model: false + + resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc. 
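The `optim_param_groups` section above assigns the pretrained encoder a lower learning rate (0.0002) than the rest of the model (0.0003). A minimal sketch of how such per-module learning rates map onto PyTorch parameter groups (an assumed illustration, not NeMo's optimizer builder):

```python
import torch

def build_optimizer(model: torch.nn.Module,
                    base_lr: float = 3e-4,
                    encoder_lr: float = 2e-4,
                    weight_decay: float = 0.0) -> torch.optim.Optimizer:
    # parameters under `encoder.` get their own (lower) learning rate
    encoder_params = [p for n, p in model.named_parameters() if n.startswith("encoder.")]
    other_params = [p for n, p in model.named_parameters() if not n.startswith("encoder.")]
    return torch.optim.AdamW(
        [
            {"params": encoder_params, "lr": encoder_lr},
            {"params": other_params, "lr": base_lr},
        ],
        betas=(0.9, 0.98),
        weight_decay=weight_decay,
    )
```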
+ # you need to set these two to True to continue the training + resume_if_exists: false + resume_ignore_no_checkpoint: false + + # You may use this section to create a W&B logger + create_wandb_logger: false + wandb_logger_kwargs: + name: null + project: null diff --git a/examples/slu/speech_intent_slot/run_speech_intent_slot_eval.py b/examples/slu/speech_intent_slot/speech_intent_slot_eval.py similarity index 96% rename from examples/slu/speech_intent_slot/run_speech_intent_slot_eval.py rename to examples/slu/speech_intent_slot/speech_intent_slot_eval.py index 8ed0abead342..7ce1315cf961 100644 --- a/examples/slu/speech_intent_slot/run_speech_intent_slot_eval.py +++ b/examples/slu/speech_intent_slot/speech_intent_slot_eval.py @@ -31,7 +31,7 @@ @dataclass class EvaluationConfig(InferenceConfig): dataset_manifest: str = MISSING - output_filename: Optional[str] = "evaluation_transcripts.json" + output_filename: Optional[str] = None average: str = "micro" full: bool = False errors: bool = False @@ -43,7 +43,8 @@ class EvaluationConfig(InferenceConfig): def main(cfg: EvaluationConfig): torch.set_grad_enabled(False) - cfg.output_filename = str(Path(Path(cfg.model_path).parent) / Path("predictions.json")) + if not cfg.only_score_manifest and not cfg.output_filename: + cfg.output_filename = str(Path(Path(cfg.model_path).parent) / Path("predictions.json")) if is_dataclass(cfg): cfg = OmegaConf.structured(cfg) diff --git a/examples/slu/speech_intent_slot/run_speech_intent_slot_train.py b/examples/slu/speech_intent_slot/speech_intent_slot_train.py similarity index 86% rename from examples/slu/speech_intent_slot/run_speech_intent_slot_train.py rename to examples/slu/speech_intent_slot/speech_intent_slot_train.py index d8989bf01cdd..a9999d4d4682 100644 --- a/examples/slu/speech_intent_slot/run_speech_intent_slot_train.py +++ b/examples/slu/speech_intent_slot/speech_intent_slot_train.py @@ -88,11 +88,17 @@ def main(cfg): pretrained_encoder_name = cfg.pretrained_encoder.name if pretrained_encoder_name is not None: if Path(pretrained_encoder_name).is_file(): - logging.info(f"Loading pretrained encoder from local: {pretrained_encoder_name}") - pretraind_model = ASRModel.restore_from( - restore_path=pretrained_encoder_name, map_location=torch.device("cpu") - ) - model.encoder.load_state_dict(pretraind_model.encoder.state_dict(), strict=False) + if not pretrained_encoder_name.endswith(".nemo"): + logging.info(f"Loading encoder from PyTorch Lightning checkpoint: {pretrained_encoder_name}") + state_dict = torch.load(pretrained_encoder_name, map_location='cpu')['state_dict'] + pretraind_model = None + else: + logging.info(f"Loading pretrained encoder from NeMo file: {pretrained_encoder_name}") + pretraind_model = ASRModel.restore_from( + restore_path=pretrained_encoder_name, map_location=torch.device("cpu") + ) + state_dict = pretraind_model.state_dict() + model.load_state_dict(state_dict, strict=False) del pretraind_model else: logging.info(f"Loading pretrained encoder from NGC: {pretrained_encoder_name}") diff --git a/examples/speaker_tasks/recognition/conf/titanet-large.yaml b/examples/speaker_tasks/recognition/conf/titanet-large.yaml index e4859678a44e..6d150b1ed5e4 100644 --- a/examples/speaker_tasks/recognition/conf/titanet-large.yaml +++ b/examples/speaker_tasks/recognition/conf/titanet-large.yaml @@ -31,6 +31,7 @@ model: labels: null batch_size: 128 shuffle: False + is_audio_pair: false # set `true` to calculate EER during validation, see nemo/collections/asr/models/label_models.py for details 
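The new `is_audio_pair` option enables EER computation over audio pairs during validation. As a rough illustration (not the code in `label_models.py`), the equal error rate can be estimated from per-pair similarity scores and same/different-speaker labels like this:

```python
import numpy as np
from sklearn.metrics import roc_curve

def compute_eer(scores: np.ndarray, labels: np.ndarray) -> float:
    """EER from pair scores (e.g. cosine similarity of the two embeddings)."""
    # labels: 1 for same-speaker pairs, 0 for different-speaker pairs
    fpr, tpr, _ = roc_curve(labels, scores)
    fnr = 1.0 - tpr
    idx = np.nanargmin(np.abs(fnr - fpr))   # operating point where FAR ~= FRR
    return float((fpr[idx] + fnr[idx]) / 2.0)
```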
model_defaults: filters: 1024 diff --git a/nemo/collections/asr/data/audio_to_label.py b/nemo/collections/asr/data/audio_to_label.py index decd6beaa961..8dd65e3fa17a 100644 --- a/nemo/collections/asr/data/audio_to_label.py +++ b/nemo/collections/asr/data/audio_to_label.py @@ -1318,3 +1318,103 @@ def __len__(self): def _collate_fn(self, batch): return _speech_collate_fn(batch, pad_id=0) + + +class AudioPairToLabelDataset(AudioToSpeechLabelDataset): + """ + Dataset class for audio pairs classification tasks, such as calculating EER for speaker verification. + The input manifest file should contain pairs of audio files and a label. It's format is almost the same as + `AudioToSpeechLabelDataset` except that the `audio_filepath` field should be a list of two audio file paths + instead of one, and that `offset` and `duration` are not used as the dataset class will load the whole audio. + + Example of a line in the manifest file: + { + "audio_filepath": ["/path/to/audio_wav_0.wav", "/path/to/audio_wav_1.wav"], + "duration": null, # not used, will load the whole audio + "offset": 0.0, # not used, will load the whole audio + "label": "0" # label for the pair, can be a string or an integer + } + + """ + + @property + def output_types(self) -> Optional[Dict[str, NeuralType]]: + """Returns definitions of module output ports.""" + + output_types = { + 'audio_signal': NeuralType( + ('B', 'T'), + ( + AudioSignal(freq=self._sample_rate) + if self is not None and hasattr(self, '_sample_rate') + else AudioSignal() + ), + ), + 'a_sig_length': NeuralType(tuple('B'), LengthsType()), + 'audio_signal_2': NeuralType( + ('B', 'T'), + ( + AudioSignal(freq=self._sample_rate) + if self is not None and hasattr(self, '_sample_rate') + else AudioSignal() + ), + ), + 'a_sig_length_2': NeuralType(tuple('B'), LengthsType()), + 'label': NeuralType(tuple('B'), LabelsType()), + 'label_length': NeuralType(tuple('B'), LengthsType()), + } + + return output_types + + def __init__( + self, + *, + manifest_filepath: str | List[str], + labels: List[str], + featurizer, + min_duration: float | None = 0.1, + max_duration: float | None = None, + trim: bool = False, + window_length_in_sec: float | None = 8, + shift_length_in_sec: float | None = 1, + normalize_audio: bool = False, + **kwargs, + ): + super().__init__( + manifest_filepath=manifest_filepath, + labels=labels, + featurizer=featurizer, + min_duration=min_duration, + max_duration=max_duration, + trim=trim, + window_length_in_sec=window_length_in_sec, + shift_length_in_sec=shift_length_in_sec, + normalize_audio=normalize_audio, + is_regression_task=False, + cal_labels_occurrence=False, + ) + + def __getitem__(self, index): + sample = self.collection[index] + + audio_pair = sample.audio_file + + features = self.featurizer.process(audio_pair[0], offset=0, duration=None, trim=self.trim) + f, fl = features, torch.tensor(features.shape[0]).long() + + features2 = self.featurizer.process(audio_pair[1], offset=0, duration=None, trim=self.trim) + f2, fl2 = features2, torch.tensor(features2.shape[0]).long() + + t = torch.tensor(self.label2id[sample.label]).long() + tl = torch.tensor(1).long() # For compatibility with collate_fn used later + + return f, fl, f2, fl2, t, tl + + def fixed_seq_collate_fn(self, batch): + audio1, audio_len1, audio2, audio_len2, label, label_len = zip(*batch) + + batch1 = list(zip(audio1, audio_len1, label, label_len)) + a_sig1, a_sig_len1, pair_label, pair_label_len = _fixed_seq_collate_fn(self, batch1) + batch2 = list(zip(audio2, audio_len2, label, label_len)) + 
a_sig2, a_sig_len2, _, _ = _fixed_seq_collate_fn(self, batch2) + return a_sig1, a_sig_len1, a_sig2, a_sig_len2, pair_label, pair_label_len diff --git a/nemo/collections/asr/data/audio_to_text.py b/nemo/collections/asr/data/audio_to_text.py index 28dc168481ed..d5ece6202da7 100644 --- a/nemo/collections/asr/data/audio_to_text.py +++ b/nemo/collections/asr/data/audio_to_text.py @@ -146,6 +146,7 @@ def __init__( eos_id: Optional[int] = None, pad_id: int = 0, index_by_file_id: bool = False, + manifest_parse_func: Optional[Callable] = None, ): self.parser = parser @@ -156,6 +157,7 @@ def __init__( max_duration=max_duration, max_number=max_utts, index_by_file_id=index_by_file_id, + parse_func=manifest_parse_func, ) self.eos_id = eos_id @@ -423,6 +425,7 @@ class _AudioTextDataset(Dataset): pad_id: Id of pad symbol. Defaults to 0 return_sample_id (bool): whether to return the sample_id as a part of each sample channel_selector (int | Iterable[int] | str): select a single channel or a subset of channels from multi-channel audio. If set to `'average'`, it performs averaging across channels. Disabled if set to `None`. Defaults to `None`. Uses zero-based indexing. + manifest_parse_func: Optional function to parse manifest entries. Defaults to None. """ @property @@ -452,6 +455,7 @@ def __init__( pad_id: int = 0, return_sample_id: bool = False, channel_selector: Optional[ChannelSelectorType] = None, + manifest_parse_func: Optional[Callable] = None, ): if type(manifest_filepath) == str: manifest_filepath = manifest_filepath.split(",") @@ -468,6 +472,7 @@ def __init__( bos_id=bos_id, eos_id=eos_id, pad_id=pad_id, + manifest_parse_func=manifest_parse_func, ) self.featurizer = WaveformFeaturizer(sample_rate=sample_rate, int_values=int_values, augmentor=augmentor) self.trim = trim @@ -547,6 +552,7 @@ class AudioToCharDataset(_AudioTextDataset): eos_id: Id of end of sequence symbol to append if not None return_sample_id (bool): whether to return the sample_id as a part of each sample channel_selector (int | Iterable[int] | str): select a single channel or a subset of channels from multi-channel audio. If set to `'average'`, it performs averaging across channels. Disabled if set to `None`. Defaults to `None`. Uses zero-based indexing. + manifest_parse_func: Optional function to parse manifest entries. Defaults to None. """ @property @@ -580,6 +586,7 @@ def __init__( parser: Union[str, Callable] = 'en', return_sample_id: bool = False, channel_selector: Optional[ChannelSelectorType] = None, + manifest_parse_func: Optional[Callable] = None, ): self.labels = labels @@ -602,6 +609,7 @@ def __init__( pad_id=pad_id, return_sample_id=return_sample_id, channel_selector=channel_selector, + manifest_parse_func=manifest_parse_func, ) @@ -640,6 +648,7 @@ class AudioToBPEDataset(_AudioTextDataset): tokens to beginning and ending of speech respectively. return_sample_id (bool): whether to return the sample_id as a part of each sample channel_selector (int | Iterable[int] | str): select a single channel or a subset of channels from multi-channel audio. If set to `'average'`, it performs averaging across channels. Disabled if set to `None`. Defaults to `None`. Uses zero-based indexing. + manifest_parse_func: Optional function to parse manifest entries. Defaults to None. 
""" @property @@ -667,6 +676,7 @@ def __init__( use_start_end_token: bool = True, return_sample_id: bool = False, channel_selector: Optional[ChannelSelectorType] = None, + manifest_parse_func: Optional[Callable] = None, ): if use_start_end_token and hasattr(tokenizer, "bos_id") and tokenizer.bos_id > 0: bos_id = tokenizer.bos_id @@ -716,6 +726,7 @@ def __call__(self, *args): trim=trim, return_sample_id=return_sample_id, channel_selector=channel_selector, + manifest_parse_func=manifest_parse_func, ) @@ -809,6 +820,7 @@ class _TarredAudioToTextDataset(IterableDataset): global_rank (int): Worker rank, used for partitioning shards. Defaults to 0. world_size (int): Total number of processes, used for partitioning shards. Defaults to 0. return_sample_id (bool): whether to return the sample_id as a part of each sample + manifest_parse_func: Optional function to parse manifest entries. Defaults to None. """ def __init__( @@ -831,6 +843,7 @@ def __init__( global_rank: int = 0, world_size: int = 0, return_sample_id: bool = False, + manifest_parse_func: Optional[Callable] = None, ): self.shard_manifests = shard_manifests @@ -856,6 +869,7 @@ def __init__( eos_id=eos_id, pad_id=pad_id, index_by_file_id=True, # Must set this so the manifest lines can be indexed by file ID + manifest_parse_func=manifest_parse_func, ) self.len = self._compute_len() @@ -1099,6 +1113,7 @@ class TarredAudioToCharDataset(_TarredAudioToTextDataset): global_rank (int): Worker rank, used for partitioning shards. Defaults to 0. world_size (int): Total number of processes, used for partitioning shards. Defaults to 0. return_sample_id (bool): whether to return the sample_id as a part of each sample + manifest_parse_func: Optional function to parse manifest entries. Defaults to None. """ def __init__( @@ -1125,6 +1140,7 @@ def __init__( global_rank: int = 0, world_size: int = 0, return_sample_id: bool = False, + manifest_parse_func: Optional[Callable] = None, ): self.labels = labels @@ -1151,6 +1167,7 @@ def __init__( global_rank=global_rank, world_size=world_size, return_sample_id=return_sample_id, + manifest_parse_func=manifest_parse_func, ) @@ -1232,6 +1249,7 @@ class TarredAudioToBPEDataset(_TarredAudioToTextDataset): global_rank (int): Worker rank, used for partitioning shards. Defaults to 0. world_size (int): Total number of processes, used for partitioning shards. Defaults to 0. return_sample_id (bool): whether to return the sample_id as a part of each sample + manifest_parse_func: Optional function to parse manifest entries. Defaults to None. """ def __init__( @@ -1252,6 +1270,7 @@ def __init__( global_rank: int = 0, world_size: int = 0, return_sample_id: bool = False, + manifest_parse_func: Optional[Callable] = None, ): if use_start_end_token and hasattr(tokenizer, "bos_id") and tokenizer.bos_id > 0: bos_id = tokenizer.bos_id @@ -1305,6 +1324,7 @@ def __call__(self, *args): global_rank=global_rank, world_size=world_size, return_sample_id=return_sample_id, + manifest_parse_func=manifest_parse_func, ) diff --git a/nemo/collections/asr/data/ssl_dataset.py b/nemo/collections/asr/data/ssl_dataset.py new file mode 100644 index 000000000000..a526adb8242f --- /dev/null +++ b/nemo/collections/asr/data/ssl_dataset.py @@ -0,0 +1,704 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import io +import json +import os +from dataclasses import dataclass +from math import isclose +from typing import Any, Dict, List, Optional + +import numpy as np +import torch +from lhotse.dataset import AudioSamples +from omegaconf import DictConfig, ListConfig, open_dict + +from nemo.collections.asr.data import audio_to_text, audio_to_text_dataset +from nemo.collections.asr.parts.preprocessing.perturb import WhiteNoisePerturbation, process_augmentations +from nemo.collections.asr.parts.preprocessing.segment import AudioSegment +from nemo.collections.asr.parts.utils.manifest_utils import read_manifest +from nemo.collections.common.data.dataset import ConcatDataset +from nemo.collections.common.parts.preprocessing.manifest import get_full_path +from nemo.core.classes import Serialization +from nemo.utils import logging + + +@dataclass +class AudioNoiseItem: + sample_id: str | None = None + audio: torch.Tensor | None = None + audio_len: torch.Tensor | None = None + noise: torch.Tensor | None = None + noise_len: torch.Tensor | None = None + noisy_audio: torch.Tensor | None = None + noisy_audio_len: torch.Tensor | None = None + + +@dataclass +class AudioNoiseBatch: + sample_id: list | None = None + audio: torch.Tensor | None = None + audio_len: torch.Tensor | None = None + noise: torch.Tensor | None = None + noise_len: torch.Tensor | None = None + noisy_audio: torch.Tensor | None = None + noisy_audio_len: torch.Tensor | None = None + + +def _parse_manifest_item(line: str, manifest_file: str) -> Dict[str, Any]: + """ + Specialized function to parse the manifest file by ignoring text, + such that nemo dataset can save time on tokenizing text. + """ + item = json.loads(line) + + # Audio file + if 'audio_filename' in item: + item['audio_file'] = item.pop('audio_filename') + elif 'audio_filepath' in item: + item['audio_file'] = item.pop('audio_filepath') + else: + raise KeyError(f"No 'audio_filename' or 'audio_filepath' in manifest item: {item}") + + item['audio_file'] = get_full_path(audio_file=item['audio_file'], manifest_file=manifest_file) + + # Duration. 
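+    # if 'duration' is missing, fall back to None so the featurizer later loads the whole file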
+ if 'duration' not in item: + item['duration'] = None + + # dummy text + item['text'] = "" + + item = dict( + audio_file=item['audio_file'], + duration=item['duration'], + text=item['text'], + offset=item.get('offset', None), + speaker=item.get('speaker', None), + orig_sr=item.get('orig_sample_rate', None), + token_labels=item.get('token_labels', None), + lang=item.get('lang', None), + ) + return item + + +def _audio_noise_collate_fn(batch: List[AudioNoiseItem], batch_augmentor: Any = None) -> AudioNoiseBatch: + audios = [x.audio for x in batch] + audio_lengths = [x.audio_len for x in batch] + max_audio_len = max(audio_lengths).item() + + noises = [x.noise for x in batch] + noise_lengths = [x.noise_len for x in batch] + + noisy_audios = [x.noisy_audio for x in batch] + noisy_audio_lengths = [x.noisy_audio_len for x in batch] + + audio_signal_list = [] + noise_signal_list = [] + noisy_audio_signal_list = [] + for i, audio in enumerate(audios): + audio_len = audio.size(0) + if audio_len < max_audio_len: + pad = (0, max_audio_len - audio_len) + audio = torch.nn.functional.pad(audio, pad) + audio_signal_list.append(audio) + + noise = noises[i] + noise_len = noise.size(0) + if noise_len < max_audio_len: + pad = (0, max_audio_len - noise_len) + noise = torch.nn.functional.pad(noise, pad) + noise_signal_list.append(noise[:max_audio_len]) + + noisy_audio = noisy_audios[i] + noisy_audio_len = noisy_audio.size(0) + if noisy_audio_len < max_audio_len: + pad = (0, max_audio_len - noisy_audio_len) + noisy_audio = torch.nn.functional.pad(noisy_audio, pad) + noisy_audio_signal_list.append(noisy_audio[:max_audio_len]) + + audio_signal = torch.stack(audio_signal_list).float() + audio_lengths = torch.stack(audio_lengths).long() + noise_signal = torch.stack(noise_signal_list).float() + noise_lengths = torch.stack(noise_lengths).long() + noisy_audio_signal = torch.stack(noisy_audio_signal_list).float() + noisy_audio_lengths = torch.stack(noisy_audio_lengths).long() + + output = AudioNoiseBatch( + audio=audio_signal, + audio_len=audio_lengths, + noise=noise_signal, + noise_len=noise_lengths, + noisy_audio=noisy_audio_signal, + noisy_audio_len=noisy_audio_lengths, + ) + + if batch_augmentor is not None: + output = batch_augmentor(output) + + return output + + +def load_noise_manifest(noise_manifest: str | ListConfig | None): + """ + load noise manifest from a single or a list of manifest files + """ + if noise_manifest is None: + return [] + + if isinstance(noise_manifest, str): + noise_manifest = noise_manifest.split(',') + + noise_data = [] + for manifest in noise_manifest: + curr_data = read_manifest(manifest) + for i in range(len(curr_data)): + curr_data[i]['audio_filepath'] = get_full_path(curr_data[i]['audio_filepath'], manifest) + noise_data.extend(curr_data) + return noise_data + + +def load_noise_audio( + sample: Dict[str, Any], + sample_rate: int, + max_audio_len: Optional[int] = None, + pad_to_max: bool = True, + min_white_noise_db: int = -90, + max_white_noise_db: int = -46, + max_trial: int = 100, +): + """ + Load noise audio from the manifest item, and apply white noise if the loaded noise audio is empty. 
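+    Example (a minimal sketch; the noise file path below is hypothetical):
+        >>> sample = {'audio_filepath': '/data/noise/babble.wav', 'offset': 0.0, 'duration': 30.0}
+        >>> noise, noise_len = load_noise_audio(sample, sample_rate=16000, max_audio_len=16000 * 4)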
+ Args: + sample: a sample from the noise manifest + sample_rate: target sample rate to resample the noise audio + max_audio_len: the maximum audio length to load + pad_to_max: whether to pad the audio to max_audio_len + min_white_noise_db: the minimum white noise level in dB + max_white_noise_db: the maximum white noise level in dB + max_trial: the maximum number of trials to load noise audio before giving up + Returns: + noise: the loaded noise audio + noise_len: the length of the loaded noise audio + """ + max_dur = None if max_audio_len is None else max_audio_len / sample_rate + duration = sample.get('duration', None) + offset = sample.get('offset', 0.0) + + if max_dur is not None and duration is not None and duration > max_dur: + cnt = 0 + while cnt < max_trial: + # randomly sample a segment of the noise + offset = np.random.uniform(0, duration - max_dur) + + audio_segment = AudioSegment.from_file( + audio_file=sample['audio_filepath'], + offset=offset, + duration=max_dur, + target_sr=sample_rate, + ) + + if sum(audio_segment.samples) > 0: + # break if the segment is not empty + break + cnt += 1 + else: + audio_segment = AudioSegment.from_file( + audio_file=sample['audio_filepath'], + offset=offset, + duration=duration, + target_sr=sample_rate, + ) + + if sum(audio_segment.samples) == 0: + logging.warning( + f"Loaded noise audio is empty: {sample}, with sampled offset={offset}, duration={max_dur}. Adding white noise." + ) + WhiteNoisePerturbation(min_level=min_white_noise_db, max_level=max_white_noise_db).perturb(audio_segment) + + noise = torch.tensor(audio_segment.samples, dtype=torch.float) + noise_len = torch.tensor(noise.size(0)).long() + # pad to max_audio_len if necessary + if max_audio_len is not None and pad_to_max: + if noise.size(0) < max_audio_len: + pad = (0, max_audio_len - noise.size(0)) + noise = torch.nn.functional.pad(noise, pad) + else: + noise = noise[:max_audio_len] + noise_len = torch.tensor(max_audio_len).long() + return noise, noise_len + + +def sample_noise(noise_data: List[Dict], sample_rate: int, max_audio_len: int | None = None, max_trial: int = 20): + """ + Randomly sample noise audio from the noise manifest. 
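+    Example (a minimal sketch; the manifest path below is hypothetical):
+        >>> noise_data = load_noise_manifest('/data/noise_manifest.json')
+        >>> noise, noise_len = sample_noise(noise_data, sample_rate=16000, max_audio_len=16000 * 4)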
+ Args: + noise_data: the noise manifest data + sample_rate: target sample rate to resample the noise audio + max_audio_len: the maximum audio length to load + max_trial: the maximum number of trials to load noise audio before giving up + Returns: + noise_audio: the sampled noise audio + noise_len: the length of the sampled noise audio + """ + cnt = 0 + noise_audio = torch.zeros(max_audio_len).float() + noise_len = torch.tensor(max_audio_len).long() + while cnt < max_trial and len(noise_data) > 0: + try: + noise_sample = noise_data[np.random.randint(len(noise_data))] + noise_audio, noise_len = load_noise_audio(noise_sample, sample_rate, max_audio_len) + break + except Exception as e: + logging.warning(f"Error loading noise audio with config {noise_sample} and exception: {e}, retrying.") + cnt += 1 + if cnt == max_trial: + logging.warning(f"Failed to load noise audio after {max_trial} attempts, returning zero noise.") + return torch.zeros(max_audio_len).float(), torch.tensor(max_audio_len).long() + return noise_audio, noise_len + + +def pad_audio(audio: torch.Tensor, min_len: int, pad_audio_mode) -> torch.Tensor: + """ + Pad audio to min_len with the specified mode + Args: + audio: the input audio tensor + min_len: the minimum length to pad to + pad_audio_mode: the padding mode, either 'repeat' or 'zero' + Returns: + audio: the padded audio tensor + """ + allowed_mode = ['repeat', 'zero'] + if audio.size(0) < min_len: + if pad_audio_mode == 'repeat' and audio.size(0) > 0: + num_repeats = int(np.ceil(min_len / audio.size(0))) + audio = audio.repeat(num_repeats)[:min_len] + elif pad_audio_mode == 'zero' or audio.size(0) == 0: + audio = torch.nn.functional.pad(audio, (0, min_len - audio.size(0))) + else: + raise ValueError(f"Unsupported pad_audio_mode: {pad_audio_mode}, must be one of {allowed_mode}") + return audio + + +class AudioNoiseDataset(audio_to_text.AudioToCharDataset): + @property + def output_types(self): + # disable type checking for since it doesn't support dataclass + return None + + def __init__( + self, + noise_manifest: str | None = None, + batch_augmentor: Any | None = None, + min_audio_len_secs: float = 1.0, + pad_audio_mode: str = 'repeat', + **kwargs, + ): + # add bos_id=0 to avoid empty text token + super().__init__(bos_id=0, manifest_parse_func=_parse_manifest_item, **kwargs) + self.noise_manifest = noise_manifest + self.batch_augmentor = batch_augmentor + self.noise_data = load_noise_manifest(noise_manifest) + self.min_audio_len_secs = min_audio_len_secs + self.pad_audio_mode = pad_audio_mode + + def __getitem__(self, index) -> AudioNoiseItem: + sample = self.manifest_processor.collection[index] + offset = sample.offset + + if offset is None: + offset = 0 + + audio = self.featurizer.process( + sample.audio_file, + offset=offset, + duration=sample.duration, + trim=self.trim, + orig_sr=sample.orig_sr, + channel_selector=self.channel_selector, + ) + if audio.size(0) == 0: + logging.warning(f"Loaded audio has zero length: {sample}.") + + min_len = int(self.min_audio_len_secs * self.featurizer.sample_rate) + audio = pad_audio(audio, min_len, self.pad_audio_mode) + audio_len = torch.tensor(audio.shape[0]).long() + noise, noise_len = sample_noise(self.noise_data, self.featurizer.sample_rate, audio_len.item()) + + item = AudioNoiseItem( + sample_id=str(index), + audio=audio, + audio_len=audio_len, + noise=noise, + noise_len=noise_len, + noisy_audio=audio + noise, + noisy_audio_len=audio_len, + ) + return item + + def _collate_fn(self, batch: List[AudioNoiseItem]) -> 
AudioNoiseBatch: + return _audio_noise_collate_fn(batch, self.batch_augmentor) + + +class TarredAudioNoiseDataset(audio_to_text.TarredAudioToCharDataset): + @property + def output_types(self): + # disable type checking for since it doesn't support dataclass + return None + + def __init__( + self, + noise_manifest: str | None = None, + batch_augmentor: Any | None = None, + min_audio_len_secs: float = 1.0, + pad_audio_mode: str = 'repeat', + **kwargs, + ): + """ + Args: + noise_manifest: the noise manifest file + batch_augmentor: the batch augmentor + min_audio_len_secs: the minimum audio length in seconds, audios shorter than this will be padded + pad_audio_mode: the padding mode for audios shorter than min_audio_len_secs, either 'repeat' or 'zero' + **kwargs: other arguments for TarredAudioToCharDataset + + """ + super().__init__(bos_id=0, manifest_parse_func=_parse_manifest_item, **kwargs) + self.noise_manifest = noise_manifest + self.batch_augmentor = batch_augmentor + self.noise_data = load_noise_manifest(noise_manifest) + self.min_audio_len_secs = min_audio_len_secs + self.pad_audio_mode = pad_audio_mode + + def _build_sample(self, tup): + """Builds the training sample by combining the data from the WebDataset with the manifest info.""" + audio_bytes, audio_filename, offset_id = tup + + # Grab manifest entry from self.manifest_preprocessor.collection + file_id, _ = os.path.splitext(os.path.basename(audio_filename)) + manifest_idx = self.manifest_processor.collection.mapping[file_id][offset_id] + manifest_entry = self.manifest_processor.collection[manifest_idx] + + offset = manifest_entry.offset + if offset is None: + offset = 0 + + try: + # Convert audio bytes to IO stream for processing (for SoundFile to read) + audio_filestream = io.BytesIO(audio_bytes) + audio = self.featurizer.process( + audio_filestream, + offset=offset, + duration=manifest_entry.duration, + trim=self.trim, + orig_sr=manifest_entry.orig_sr, + ) + audio_filestream.close() + except Exception as e: + raise RuntimeError(f"Error reading audio sample: {manifest_entry}, with exception: {e}.") + + min_len = int(self.min_audio_len_secs * self.featurizer.sample_rate) + audio = pad_audio(audio, min_len, self.pad_audio_mode) + audio_len = torch.tensor(audio.shape[0]).long() + noise, noise_len = sample_noise(self.noise_data, self.featurizer.sample_rate, audio_len.item()) + + item = AudioNoiseItem( + sample_id=str(manifest_idx), + audio=audio, + audio_len=audio_len, + noise=noise, + noise_len=noise_len, + noisy_audio=audio + noise, + noisy_audio_len=audio_len, + ) + return item + + def _pad_audio(self, audio: torch.Tensor) -> torch.Tensor: + min_len = int(self.min_audio_len_secs * self.featurizer.sample_rate) + if audio.size(0) < min_len: + if self.pad_audio_mode == 'repeat': + num_repeats = int(np.ceil(min_len / audio.size(0))) + audio = audio.repeat(num_repeats)[:min_len] + elif self.pad_audio_mode == 'zero': + audio = torch.nn.functional.pad(audio, (0, min_len - audio.size(0))) + else: + raise ValueError( + f"Unsupported pad_audio_mode: {self.pad_audio_mode}, must be one of ['repeat', 'zero']" + ) + return audio + + def _collate_fn(self, batch: List[AudioNoiseItem]) -> AudioNoiseBatch: + return _audio_noise_collate_fn(batch, self.batch_augmentor) + + +class LhotseAudioNoiseDataset(torch.utils.data.Dataset): + def __init__(self, noise_manifest: str | None = None, batch_augmentor_cfg: DictConfig = None): + super().__init__() + + if batch_augmentor_cfg: + batch_augmentor = Serialization.from_config_dict(batch_augmentor_cfg) + 
else: + batch_augmentor = None + + self.batch_augmentor = batch_augmentor + self.noise_data = load_noise_manifest(noise_manifest) + self.load_audio = AudioSamples(fault_tolerant=True) + + def __getitem__(self, cuts): + + audios, audio_lens, cuts = self.load_audio(cuts) + sampled_noises = [sample_noise(self.noise_data, cut.sampling_rate, cut.num_samples) for cut in cuts] + + items = [ + AudioNoiseItem( + sample_id=str(cuts[i].id), + audio=audios[i], + audio_len=audio_lens[i], + noise=sampled_noises[i][0], + noise_len=sampled_noises[i][1], + noisy_audio=audios[i] + sampled_noises[i][0], + noisy_audio_len=audio_lens[i], + ) + for i in range(len(cuts)) + ] + return _audio_noise_collate_fn(items, self.batch_augmentor) + + +def get_audio_noise_dataset( + config: Dict[str, Any], augmentor: Any = None, batch_augmentor: Any = None +) -> AudioNoiseDataset: + dataset = AudioNoiseDataset( + noise_manifest=config.get('noise_manifest', None), + batch_augmentor=batch_augmentor, + manifest_filepath=config['manifest_filepath'], + labels=config.get('labels', None), + sample_rate=config['sample_rate'], + int_values=config.get('int_values', False), + augmentor=augmentor, + max_duration=config.get('max_duration', None), + min_duration=config.get('min_duration', None), + trim=config.get('trim_silence', False), + channel_selector=config.get('channel_selector', None), + ) + return dataset + + +def get_concat_audio_noise_dataset( + config: Dict[str, Any], global_rank: int, world_size: int, augmentor: Any = None, batch_augmentor: Any = None +) -> ConcatDataset: + manifest_filepaths = config['manifest_filepath'] + datasets = [] + + # needed to support validation Concat Datasets that arrive here as + # [[dataset1,dataset2]] otherwise ModelPT would interfere + if len(manifest_filepaths) == 1 and not isinstance(manifest_filepaths[0], str): + logging.info(f"removing an extra nesting level from {manifest_filepaths}") + manifest_filepaths = config['manifest_filepath'][0] + + for manifest_filepath in manifest_filepaths: + conf = copy.deepcopy(config) + conf['manifest_filepath'] = manifest_filepath + + dataset = get_audio_noise_dataset(config=conf, augmentor=augmentor) + datasets.append(dataset) + + dataset = ConcatDataset( + datasets, + sampling_technique=config.get('concat_sampling_technique', 'temperature'), + sampling_temperature=config.get('concat_sampling_temperature', 5), + sampling_scale=config.get('concat_sampling_scale', 1), + sampling_probabilities=config.get('concat_sampling_probabilities', None), + shuffle=config.get('concat_shuffle', True), + seed=config.get('concat_sampling_seed', None), + global_rank=global_rank, + world_size=world_size, + ) + return dataset + + +def get_tarred_audio_noise_dataset(config, shuffle_n, global_rank, world_size, augmentor, batch_augmentor: Any = None): + tarred_audio_filepaths = config['tarred_audio_filepaths'] + manifest_filepaths = config['manifest_filepath'] + datasets = [] + tarred_audio_filepaths = audio_to_text_dataset.convert_to_config_list(tarred_audio_filepaths) + manifest_filepaths = audio_to_text_dataset.convert_to_config_list(manifest_filepaths) + + bucketing_weights = config.get('bucketing_weights', None) # For upsampling buckets + if bucketing_weights: + for idx, weight in enumerate(bucketing_weights): + if not isinstance(weight, int) or weight <= 0: + raise ValueError(f"bucket weights must be positive integers") + + if len(manifest_filepaths) != len(tarred_audio_filepaths): + raise ValueError( + f"manifest_filepaths (length={len(manifest_filepaths)}) and 
tarred_audio_filepaths (length={len(tarred_audio_filepaths)}) need to have the same number of buckets." + ) + + for dataset_idx, (tarred_audio_filepath, manifest_filepath) in enumerate( + zip(tarred_audio_filepaths, manifest_filepaths) + ): + if len(tarred_audio_filepath) == 1: + tarred_audio_filepath = tarred_audio_filepath[0] + if len(manifest_filepath) == 1: + manifest_filepath = manifest_filepath[0] + + is_sharded_manifest = True if "_OP_" in manifest_filepath and "_CL_" in manifest_filepath else False + logging.info( + f"Loading TarredAudioNoiseDataset from {tarred_audio_filepath} and {manifest_filepath}, shard={is_sharded_manifest}" + ) + dataset = TarredAudioNoiseDataset( + noise_manifest=config.get('noise_manifest', None), + batch_augmentor=batch_augmentor, + audio_tar_filepaths=tarred_audio_filepath, + manifest_filepath=manifest_filepath, + labels=config.get('labels', None), + sample_rate=config['sample_rate'], + int_values=config.get('int_values', False), + augmentor=augmentor, + shuffle_n=shuffle_n, + max_duration=config.get('max_duration', None), + min_duration=config.get('min_duration', None), + trim=config.get('trim_silence', False), + shard_strategy=config.get('tarred_shard_strategy', 'scatter'), + shard_manifests=is_sharded_manifest, + global_rank=global_rank, + world_size=world_size, + ) + if bucketing_weights: + [datasets.append(dataset) for _ in range(bucketing_weights[dataset_idx])] + else: + datasets.append(dataset) + + return audio_to_text_dataset.get_chain_dataset(datasets=datasets, ds_config=config, rank=global_rank) + + +def get_concat_tarred_audio_noise_dataset( + config, shuffle_n, global_rank, world_size, augmentor, batch_augmentor: Any = None +): + tarred_audio_filepaths = config['tarred_audio_filepaths'] + manifest_filepaths = config['manifest_filepath'] + datasets = [] + for dataset_idx, (tarred_audio_filepath, manifest_filepath) in enumerate( + zip(tarred_audio_filepaths, manifest_filepaths) + ): + conf = copy.deepcopy(config) + conf['manifest_filepath'] = manifest_filepath + conf['tarred_audio_filepaths'] = tarred_audio_filepath + dataset = get_tarred_audio_noise_dataset( + config=conf, + shuffle_n=shuffle_n, + global_rank=global_rank, + world_size=world_size, + augmentor=augmentor, + batch_augmentor=batch_augmentor, + ) + datasets.append(dataset) + + dataset = ConcatDataset( + datasets, + sampling_technique=config.get('concat_sampling_technique', 'temperature'), + sampling_temperature=config.get('concat_sampling_temperature', 5), + sampling_scale=config.get('concat_sampling_scale', 1), + sampling_probabilities=config.get('concat_sampling_probabilities', None), + shuffle=config.get('concat_shuffle', True), + seed=config.get('concat_sampling_seed', None), + global_rank=global_rank, + world_size=world_size, + ) + return dataset + + +def get_audio_noise_dataset_from_config( + config, + global_rank: int, + world_size: int, +): + if 'augmentor' in config: + augmentor = process_augmentations(config['augmentor'], global_rank=global_rank, world_size=world_size) + else: + augmentor = None + + if 'batch_augmentor' in config: + batch_augmentor = Serialization.from_config_dict(config['batch_augmentor']) + else: + batch_augmentor = None + + is_concat = config.get('is_concat', False) + if is_concat: + if config.get('concat_sampling_technique', None) is None: + logging.warning( + f"Concat dataset requires `concat_sampling_technique` but it was not provided, using round-robin. 
Config: {config}" + ) + config['concat_sampling_technique'] = 'round-robin' + + if config['concat_sampling_technique'] == 'random': + if not 'concat_sampling_probabilities' in config: + logging.warning( + f"Concat dataset requires `concat_sampling_probabilities` list, using uniform weights. Config: {config}" + ) + with open_dict(config): + config['concat_sampling_probabilities'] = [1 / len(config['manifest_filepath'])] * len( + config['manifest_filepath'] + ) + elif not isclose(sum(config['concat_sampling_probabilities']), 1, abs_tol=1e-6): + raise ValueError( + f"`concat_sampling_probabilities` need to sum to 1 with 1e-6 tolerance. Config: {config}" + ) + + shuffle = config['shuffle'] + if config.get('is_tarred', False): + if ('tarred_audio_filepaths' in config and config['tarred_audio_filepaths'] is None) or ( + 'manifest_filepath' in config and config['manifest_filepath'] is None + ): + logging.warning( + "Could not load dataset as `manifest_filepath` was None or " + f"`tarred_audio_filepaths` is None. Provided config : {config}" + ) + return None + + shuffle_n = config.get('shuffle_n', 4 * config['batch_size']) if shuffle else 0 + if is_concat: + dataset = get_concat_tarred_audio_noise_dataset( + config=config, + shuffle_n=shuffle_n, + global_rank=global_rank, + world_size=world_size, + augmentor=augmentor, + batch_augmentor=batch_augmentor, + ) + else: + dataset = get_tarred_audio_noise_dataset( + config=config, + shuffle_n=shuffle_n, + global_rank=global_rank, + world_size=world_size, + augmentor=augmentor, + batch_augmentor=batch_augmentor, + ) + else: + if 'manifest_filepath' in config and config['manifest_filepath'] is None: + logging.warning(f"Could not load dataset as `manifest_filepath` was None. Provided config : {config}") + return None + if is_concat: + dataset = get_concat_audio_noise_dataset( + config=config, + global_rank=global_rank, + world_size=world_size, + augmentor=augmentor, + batch_augmentor=batch_augmentor, + ) + else: + dataset = get_audio_noise_dataset(config=config, augmentor=augmentor, batch_augmentor=batch_augmentor) + return dataset diff --git a/nemo/collections/asr/losses/__init__.py b/nemo/collections/asr/losses/__init__.py index 0747e9a37bea..756a071178d7 100644 --- a/nemo/collections/asr/losses/__init__.py +++ b/nemo/collections/asr/losses/__init__.py @@ -17,5 +17,5 @@ from nemo.collections.asr.losses.lattice_losses import LatticeLoss from nemo.collections.asr.losses.ssl_losses.contrastive import ContrastiveLoss from nemo.collections.asr.losses.ssl_losses.ctc import CTCLossForSSL -from nemo.collections.asr.losses.ssl_losses.mlm import MLMLoss +from nemo.collections.asr.losses.ssl_losses.mlm import MLMLoss, MultiMLMLoss from nemo.collections.asr.losses.ssl_losses.rnnt import RNNTLossForSSL diff --git a/nemo/collections/asr/losses/ssl_losses/mlm.py b/nemo/collections/asr/losses/ssl_losses/mlm.py index 89de01dc1b34..424374869c3d 100644 --- a/nemo/collections/asr/losses/ssl_losses/mlm.py +++ b/nemo/collections/asr/losses/ssl_losses/mlm.py @@ -25,14 +25,14 @@ class MLMLoss(Loss): @property def input_types(self): - """Input types definitions for Contrastive. 
- """ + """Input types definitions for Contrastive.""" return { - "spec_masks": NeuralType(("B", "D", "T"), SpectrogramType()), + "spec_masks": NeuralType(("B", "D", "T"), SpectrogramType(), optional=True), "decoder_outputs": NeuralType(("B", "T", "D"), LogprobsType()), "targets": NeuralType(('B', 'T'), LabelsType()), "decoder_lengths": NeuralType(tuple('B'), LengthsType(), optional=True), "target_lengths": NeuralType(tuple('B'), LengthsType(), optional=True), + "masks": NeuralType(("B", "D", "T"), SpectrogramType(), optional=True), } @property @@ -48,7 +48,9 @@ def needs_labels(self): return True def __init__( - self, combine_time_steps: int = 1, mask_threshold: float = 0.8, + self, + combine_time_steps: int = 1, + mask_threshold: float = 0.8, ): super().__init__() self.nll_loss = nn.NLLLoss() @@ -56,11 +58,15 @@ def __init__( self.mask_threshold = mask_threshold @typecheck() - def forward(self, spec_masks, decoder_outputs, targets, decoder_lengths=None, target_lengths=None): + def forward( + self, decoder_outputs, targets, decoder_lengths=None, target_lengths=None, spec_masks=None, masks=None + ): + + if masks is None: + masks = spec_masks - # outputs are log_probs - masks = spec_masks.transpose(-2, -1) - # BxTxC + # B,D,T -> B,T,D + masks = masks.transpose(1, 2) masks = masks.reshape(masks.shape[0], masks.shape[1] // self.combine_time_steps, -1) masks = masks.mean(-1) > self.mask_threshold @@ -73,3 +79,60 @@ def forward(self, spec_masks, decoder_outputs, targets, decoder_lengths=None, ta loss = torch.mean(loss) return loss + + +class MultiMLMLoss(Loss): + """ + Masked language model loss for multiple decoders, where cross-entropy loss is applied separately on each decoder. + This loss can be used with `nemo.collections.asr.modules.ssl_modules.MultiSoftmaxDecoder` to train a model with multiple targets per frame. 
+ Reference: https://arxiv.org/abs/2202.01855 + """ + + @property + def input_types(self): + if self.squeeze_single and self.num_decoders == 1: + decoder_outputs = NeuralType(("B", "T", "C"), LogprobsType()) + targets = NeuralType(('B', 'T'), LabelsType()) + else: + decoder_outputs = NeuralType(("B", "T", "C", "H"), LogprobsType()) + targets = NeuralType(("B", "T", "H"), LabelsType()) + return { + "masks": NeuralType(("B", "D", "T"), SpectrogramType()), + "decoder_outputs": decoder_outputs, + "targets": targets, + "decoder_lengths": NeuralType(tuple('B'), LengthsType(), optional=True), + "target_lengths": NeuralType(tuple('B'), LengthsType(), optional=True), + } + + def __init__( + self, + combine_time_steps: int = 1, + mask_threshold: float = 0.8, + num_decoders: int = 1, + squeeze_single: bool = False, + ): + super().__init__() + self.num_decoders = num_decoders + self.squeeze_single = squeeze_single + self.mlm_loss = MLMLoss(combine_time_steps, mask_threshold) + + @typecheck() + def forward(self, masks, decoder_outputs, targets, decoder_lengths=None, target_lengths=None): + if self.squeeze_single and self.num_decoders == 1: + return self.mlm_loss( + spec_masks=masks, + decoder_outputs=decoder_outputs, + targets=targets, + decoder_lengths=decoder_lengths, + target_lengths=target_lengths, + ) + loss = 0.0 + for i in range(self.num_decoders): + loss += self.mlm_loss( + spec_masks=masks, + decoder_outputs=decoder_outputs[:, :, :, i], + targets=targets[:, :, i], + decoder_lengths=decoder_lengths, + target_lengths=target_lengths, + ) + return loss / self.num_decoders diff --git a/nemo/collections/asr/models/__init__.py b/nemo/collections/asr/models/__init__.py index 9b339df44f18..e4a1342b9c36 100644 --- a/nemo/collections/asr/models/__init__.py +++ b/nemo/collections/asr/models/__init__.py @@ -35,5 +35,9 @@ from nemo.collections.asr.models.rnnt_bpe_models import EncDecRNNTBPEModel from nemo.collections.asr.models.rnnt_models import EncDecRNNTModel from nemo.collections.asr.models.slu_models import SLUIntentSlotBPEModel -from nemo.collections.asr.models.ssl_models import SpeechEncDecSelfSupervisedModel +from nemo.collections.asr.models.ssl_models import ( + EncDecDenoiseMaskedTokenPredModel, + EncDecMaskedTokenPredModel, + SpeechEncDecSelfSupervisedModel, +) from nemo.collections.asr.models.transformer_bpe_models import EncDecTransfModelBPE diff --git a/nemo/collections/asr/models/configs/asr_models_config.py b/nemo/collections/asr/models/configs/asr_models_config.py index 397c13f30f1a..29dbbe06d1f8 100644 --- a/nemo/collections/asr/models/configs/asr_models_config.py +++ b/nemo/collections/asr/models/configs/asr_models_config.py @@ -63,6 +63,9 @@ class ASRDatasetConfig(nemo.core.classes.dataset.DatasetConfig): bucketing_batch_size: Optional[Any] = None bucketing_weights: Optional[List[int]] = None + # Optional callable function to parse manifest file + manifest_parse_func: Optional[Any] = (None,) + @dataclass class EncDecCTCConfig(model_cfg.ModelConfig): @@ -104,15 +107,23 @@ class EncDecCTCModelConfig(model_cfg.NemoConfig): @dataclass class CacheAwareStreamingConfig: - chunk_size: int = 0 # the size of each chunk at each step, it can be a list of two integers to specify different chunk sizes for the first step and others - shift_size: int = 0 # the size of the shift in each step, it can be a list of two integers to specify different shift sizes for the first step and others + chunk_size: int = ( + 0 # the size of each chunk at each step, it can be a list of two integers to specify different 
chunk sizes for the first step and others + ) + shift_size: int = ( + 0 # the size of the shift in each step, it can be a list of two integers to specify different shift sizes for the first step and others + ) cache_drop_size: int = 0 # the number of steps to drop from the cache last_channel_cache_size: int = 0 # the size of the needed cache for last channel layers - valid_out_len: int = 0 # the number of the steps in the final output which are valid (have the same value as in the offline mode) + valid_out_len: int = ( + 0 # the number of the steps in the final output which are valid (have the same value as in the offline mode) + ) - pre_encode_cache_size: int = 0 # the size of the needed cache for the pre-encoding part of the model to avoid caching inside the pre-encoding layers + pre_encode_cache_size: int = ( + 0 # the size of the needed cache for the pre-encoding part of the model to avoid caching inside the pre-encoding layers + ) drop_extra_pre_encoded: int = 0 # the number of steps to get dropped after the pre-encoding layer last_channel_num: int = 0 # number of the last channel layers (like MHA layers) which need caching in the model diff --git a/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py b/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py index e927faee96d6..c14265325985 100644 --- a/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py +++ b/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py @@ -609,6 +609,9 @@ def test_step(self, batch, batch_idx, dataloader_idx=0): return test_logs def multi_validation_epoch_end(self, outputs, dataloader_idx: int = 0): + if not outputs or not all([isinstance(x, dict) for x in outputs]): + logging.warning("No outputs to process for validation_epoch_end") + return {} if self.compute_eval_loss: val_loss_mean = torch.stack([x['val_loss'] for x in outputs]).mean() val_loss_log = {'val_loss': val_loss_mean} diff --git a/nemo/collections/asr/models/label_models.py b/nemo/collections/asr/models/label_models.py index 62cf2e4608d0..08c304e4c52c 100644 --- a/nemo/collections/asr/models/label_models.py +++ b/nemo/collections/asr/models/label_models.py @@ -26,10 +26,15 @@ from hydra.utils import instantiate from omegaconf import DictConfig, OmegaConf, open_dict from pytorch_lightning import Trainer +from sklearn.metrics import roc_curve from torchmetrics import Accuracy from tqdm import tqdm -from nemo.collections.asr.data.audio_to_label import AudioToSpeechLabelDataset, cache_datastore_manifests +from nemo.collections.asr.data.audio_to_label import ( + AudioPairToLabelDataset, + AudioToSpeechLabelDataset, + cache_datastore_manifests, +) from nemo.collections.asr.data.audio_to_label_dataset import ( get_concat_tarred_speech_label_dataset, get_tarred_speech_label_dataset, @@ -139,7 +144,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): if 'loss' in cfg: cfg_eval_loss = copy.deepcopy(cfg.loss) - if '_target_' in cfg.loss and 'angular' in cfg.loss._target_: + if 'angular' in cfg.loss.get('_target_', {}): OmegaConf.set_struct(cfg, True) with open_dict(cfg): cfg.decoder.angular = True @@ -192,7 +197,7 @@ def extract_labels(data_layer_config): ) labels.update(collection.uniq_labels) labels = list(sorted(labels)) - logging.warning(f"Total number of {len(labels)} found in all the manifest files.") + logging.warning(f"Total number of {len(labels)} labels found in all the manifest files.") return labels def __setup_dataloader_from_config(self, config: Optional[Dict]): @@ -238,7 +243,13 @@ def __setup_dataloader_from_config(self, 
config: Optional[Dict]): logging.warning(f"Could not load dataset as `manifest_filepath` was None. Provided config : {config}") return None - dataset = AudioToSpeechLabelDataset( + if config.get("is_audio_pair", False): + data_cls = AudioPairToLabelDataset + logging.warning("Using AudioPairToLabelDataset, where Angular loss will not be computed.") + else: + data_cls = AudioToSpeechLabelDataset + + dataset = data_cls( manifest_filepath=config['manifest_filepath'], labels=config['labels'], featurizer=featurizer, @@ -304,12 +315,18 @@ def setup_training_data(self, train_data_layer_config: Optional[Union[DictConfig ) def setup_validation_data(self, val_data_layer_config: Optional[Union[DictConfig, Dict]]): - val_data_layer_config['labels'] = self.labels + if val_data_layer_config.get("is_audio_pair", False): + val_data_layer_config['labels'] = ["0", "1"] + else: + val_data_layer_config['labels'] = self.labels self._validation_dl = self.__setup_dataloader_from_config(config=val_data_layer_config) def setup_test_data(self, test_data_layer_params: Optional[Union[DictConfig, Dict]]): if hasattr(self, 'dataset'): - test_data_layer_params['labels'] = self.labels + if test_data_layer_params.get("is_audio_pair", False): + test_data_layer_params['labels'] = ["0", "1"] + else: + test_data_layer_params['labels'] = self.labels self.embedding_dir = test_data_layer_params.get('embedding_dir', './') self._test_dl = self.__setup_dataloader_from_config(config=test_data_layer_params) @@ -342,7 +359,6 @@ def forward_for_export(self, audio_signal, length): logits, embs = self.decoder(encoder_output=encoded, length=length) return logits, embs - @typecheck() def forward(self, input_signal, input_signal_length): processed_signal, processed_signal_len = self.preprocessor( input_signal=input_signal, @@ -352,15 +368,34 @@ def forward(self, input_signal, input_signal_length): if self.spec_augmentation is not None and self.training: processed_signal = self.spec_augmentation(input_spec=processed_signal, length=processed_signal_len) - encoded, length = self.encoder(audio_signal=processed_signal, length=processed_signal_len) - logits, embs = self.decoder(encoder_output=encoded, length=length) + encoder_outputs = self.encoder(audio_signal=processed_signal, length=processed_signal_len) + if isinstance(encoder_outputs, tuple): + encoded, length = encoder_outputs + else: + encoded, length = encoder_outputs, None + decoder_outputs = self.decoder(encoder_output=encoded, length=length) + if isinstance(decoder_outputs, tuple): + logits, embs = decoder_outputs + else: + logits, embs = decoder_outputs, None + return logits, embs # PTL-specific methods def training_step(self, batch, batch_idx): - audio_signal, audio_signal_len, labels, _ = batch - logits, _ = self.forward(input_signal=audio_signal, input_signal_length=audio_signal_len) - loss = self.loss(logits=logits, labels=labels) + if len(batch) > 4: + audio_signal_1, audio_signal_len_1, audio_signal_2, audio_signal_len_2, labels, _ = batch + _, audio_emb1 = self.forward(input_signal=audio_signal_1, input_signal_length=audio_signal_len_1) + _, audio_emb2 = self.forward(input_signal=audio_signal_2, input_signal_length=audio_signal_len_2) + + # convert binary labels to -1, 1 + loss_labels = (labels.float() - 0.5) * 2 + cosine_sim = torch.cosine_similarity(audio_emb1, audio_emb2) + loss = torch.nn.functional.mse_loss(cosine_sim, loss_labels) + else: + audio_signal, audio_signal_len, labels, _ = batch + logits, _ = self.forward(input_signal=audio_signal, 
input_signal_length=audio_signal_len) + loss = self.loss(logits=logits, labels=labels) self.log('loss', loss) self.log('learning_rate', self._optimizer.param_groups[0]['lr']) @@ -375,9 +410,13 @@ def training_step(self, batch, batch_idx): return {'loss': loss} def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str = 'val'): + if len(batch) > 4: + return self.pair_evaluation_step(batch, batch_idx, dataloader_idx, tag) + audio_signal, audio_signal_len, labels, _ = batch logits, _ = self.forward(input_signal=audio_signal, input_signal_length=audio_signal_len) loss_value = self.eval_loss(logits=logits, labels=labels) + acc_top_k = self._accuracy(logits=logits, labels=labels) correct_counts, total_counts = self._accuracy.correct_counts_k, self._accuracy.total_counts_k self._macro_accuracy.update(preds=logits, target=labels) @@ -403,8 +442,93 @@ def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str = return output + def pair_evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str = 'val'): + audio_signal_1, audio_signal_len_1, audio_signal_2, audio_signal_len_2, labels, _ = batch + _, audio_emb1 = self.forward(input_signal=audio_signal_1, input_signal_length=audio_signal_len_1) + _, audio_emb2 = self.forward(input_signal=audio_signal_2, input_signal_length=audio_signal_len_2) + + # convert binary labels to -1, 1 + loss_labels = (labels.float() - 0.5) * 2 + cosine_sim = torch.cosine_similarity(audio_emb1, audio_emb2) + loss_value = torch.nn.functional.mse_loss(cosine_sim, loss_labels) + + logits = torch.stack([1 - cosine_sim, cosine_sim], dim=-1) + acc_top_k = self._accuracy(logits=logits, labels=labels) + correct_counts, total_counts = self._accuracy.correct_counts_k, self._accuracy.total_counts_k + self._macro_accuracy.update(preds=logits, target=labels) + stats = self._macro_accuracy._final_state() + + output = { + f'{tag}_loss': loss_value, + f'{tag}_correct_counts': correct_counts, + f'{tag}_total_counts': total_counts, + f'{tag}_acc_micro_top_k': acc_top_k, + f'{tag}_acc_macro_stats': stats, + f"{tag}_scores": cosine_sim, + f"{tag}_labels": labels, + } + + if tag == 'val': + if isinstance(self.trainer.val_dataloaders, (list, tuple)) and len(self.trainer.val_dataloaders) > 1: + self.validation_step_outputs[dataloader_idx].append(output) + else: + self.validation_step_outputs.append(output) + else: + if isinstance(self.trainer.test_dataloaders, (list, tuple)) and len(self.trainer.test_dataloaders) > 1: + self.test_step_outputs[dataloader_idx].append(output) + else: + self.test_step_outputs.append(output) + + return output + + def pair_multi_eval_epoch_end(self, outputs, dataloader_idx: int = 0, tag: str = 'val'): + loss_mean = torch.stack([x[f'{tag}_loss'] for x in outputs]).mean() + scores = torch.cat([x[f'{tag}_scores'] for x in outputs]).cpu().numpy() + labels = torch.cat([x[f'{tag}_labels'] for x in outputs]).long().cpu().numpy() + fpr, tpr, thresholds = roc_curve(y_true=labels, y_score=scores, pos_label=1) + fnr = 1 - tpr + try: + eer = fpr[np.nanargmin(np.absolute((fnr - fpr)))] * 100 + except ValueError as e: + logging.warning(f"Got ValueError while calculating EER: {e}") + eer = 100.0 + + correct_counts = torch.stack([x[f'{tag}_correct_counts'] for x in outputs]).sum(axis=0) + total_counts = torch.stack([x[f'{tag}_total_counts'] for x in outputs]).sum(axis=0) + + self._accuracy.correct_counts_k = correct_counts + self._accuracy.total_counts_k = total_counts + topk_scores = self._accuracy.compute() + + 
self._macro_accuracy.tp = torch.stack([x[f'{tag}_acc_macro_stats'][0] for x in outputs]).sum(axis=0) + self._macro_accuracy.fp = torch.stack([x[f'{tag}_acc_macro_stats'][1] for x in outputs]).sum(axis=0) + self._macro_accuracy.tn = torch.stack([x[f'{tag}_acc_macro_stats'][2] for x in outputs]).sum(axis=0) + self._macro_accuracy.fn = torch.stack([x[f'{tag}_acc_macro_stats'][3] for x in outputs]).sum(axis=0) + macro_accuracy_score = self._macro_accuracy.compute() + + self._accuracy.reset() + self._macro_accuracy.reset() + + tensorboard_logs = {f'{tag}_loss': loss_mean, f"{tag}_eer": eer} + for top_k, score in zip(self._accuracy.top_k, topk_scores): + tensorboard_logs[f'{tag}_acc_micro_top_{top_k}'] = score + tensorboard_logs[f'{tag}_acc_macro'] = macro_accuracy_score + + return {f'{tag}_loss': loss_mean, 'log': tensorboard_logs} + def multi_evaluation_epoch_end(self, outputs, dataloader_idx: int = 0, tag: str = 'val'): + # Check if all outputs are non-empty + if not outputs or not all([bool(x) for x in outputs]): + logging.warning( + f"Not all outputs are dictionaries. Cannot aggregate results for {tag} dataset in dataloader {dataloader_idx}. Outputs: {outputs}" + ) + return {} + + if f"{tag}_scores" in outputs[0]: + return self.pair_multi_eval_epoch_end(outputs, dataloader_idx, tag) + loss_mean = torch.stack([x[f'{tag}_loss'] for x in outputs]).mean() + correct_counts = torch.stack([x[f'{tag}_correct_counts'] for x in outputs]).sum(axis=0) total_counts = torch.stack([x[f'{tag}_total_counts'] for x in outputs]).sum(axis=0) @@ -421,16 +545,12 @@ def multi_evaluation_epoch_end(self, outputs, dataloader_idx: int = 0, tag: str self._accuracy.reset() self._macro_accuracy.reset() - self.log(f'{tag}_loss', loss_mean, sync_dist=True) + tensorboard_logs = {f'{tag}_loss': loss_mean} for top_k, score in zip(self._accuracy.top_k, topk_scores): - self.log(f'{tag}_acc_micro_top_{top_k}', score, sync_dist=True) - self.log(f'{tag}_acc_macro', macro_accuracy_score, sync_dist=True) + tensorboard_logs[f'{tag}_acc_micro_top_{top_k}'] = score + tensorboard_logs[f'{tag}_acc_macro'] = macro_accuracy_score - return { - f'{tag}_loss': loss_mean, - f'{tag}_acc_micro_top_k': topk_scores, - f'{tag}_acc_macro': macro_accuracy_score, - } + return {f'{tag}_loss': loss_mean, 'log': tensorboard_logs} def validation_step(self, batch, batch_idx, dataloader_idx: int = 0): return self.evaluation_step(batch, batch_idx, dataloader_idx, 'val') diff --git a/nemo/collections/asr/models/ssl_models.py b/nemo/collections/asr/models/ssl_models.py index 787c91e7b84a..5424ed79e751 100644 --- a/nemo/collections/asr/models/ssl_models.py +++ b/nemo/collections/asr/models/ssl_models.py @@ -13,19 +13,21 @@ # limitations under the License. 
from math import ceil -from typing import Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union import torch import torch.nn as nn from omegaconf import DictConfig from pytorch_lightning import Trainer -from nemo.collections.asr.data import audio_to_text_dataset +from nemo.collections.asr.data import audio_to_text_dataset, ssl_dataset from nemo.collections.asr.data.audio_to_text_dali import DALIOutputs from nemo.collections.asr.data.audio_to_text_lhotse import LhotseSpeechToTextBpeDataset +from nemo.collections.asr.modules.ssl_modules.masking import ConvFeatureMaksingWrapper from nemo.collections.asr.parts.mixins import ASRModuleMixin from nemo.collections.asr.parts.preprocessing.perturb import process_augmentations from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config +from nemo.collections.common.data.utils import move_data_to_device from nemo.collections.common.parts.preprocessing.parsers import make_parser from nemo.core.classes import ModelPT from nemo.core.classes.common import PretrainedModelInfo, typecheck @@ -35,12 +37,13 @@ AudioSignal, LabelsType, LengthsType, + LogprobsType, NeuralType, SpectrogramType, ) from nemo.utils import logging -__all__ = ['SpeechEncDecSelfSupervisedModel'] +__all__ = ['SpeechEncDecSelfSupervisedModel', 'EncDecMaskedTokenPredModel', 'EncDecDenoiseMaskedTokenPredModel'] class SpeechEncDecSelfSupervisedModel(ModelPT, ASRModuleMixin, AccessMixin): @@ -261,11 +264,16 @@ def setup_training_data(self, train_data_config: Optional[Union[DictConfig, Dict # We also need to check if limit_train_batches is already set. # If it's an int, we assume that the user has set it to something sane, i.e. <= # training batches, # and don't change it. Otherwise, adjust batches accordingly if it's a float (including 1.0). - if isinstance(self._trainer.limit_train_batches, float): + if self._trainer is not None and isinstance(self._trainer.limit_train_batches, float): self._trainer.limit_train_batches = int( self._trainer.limit_train_batches * ceil((len(self._train_dl.dataset) / self.world_size) / train_data_config['batch_size']) ) + elif self._trainer is None: + logging.warning( + "Model Trainer was not set before constructing the dataset, incorrect number of " + "training batches will be used. Please set the trainer and rebuild the dataset." + ) def setup_validation_data(self, val_data_config: Optional[Union[DictConfig, Dict]]): """ @@ -333,7 +341,11 @@ def output_types(self) -> Optional[Dict[str, NeuralType]]: @typecheck() def forward( - self, input_signal=None, input_signal_length=None, processed_signal=None, processed_signal_length=None, + self, + input_signal=None, + input_signal_length=None, + processed_signal=None, + processed_signal_length=None, ): """ Forward pass of the model. 
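To illustrate the `limit_train_batches` adjustment in `setup_training_data` above (the numbers here are hypothetical): with `limit_train_batches=0.5`, a training set of 10,000 examples, `world_size=8`, and `batch_size=16`, the float fraction is converted to an integer batch count of int(0.5 * ceil((10000 / 8) / 16)) = int(0.5 * 79) = 39 batches per epoch.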
@@ -390,7 +402,8 @@ def forward( if not has_processed_signal: processed_signal, processed_signal_length = self.preprocessor( - input_signal=input_signal, length=input_signal_length, + input_signal=input_signal, + length=input_signal_length, ) if self.pen_factor: @@ -509,11 +522,13 @@ def training_step(self, batch, batch_nb): signal, signal_len, targets, target_lengths = batch if isinstance(batch, DALIOutputs) and batch.has_processed_signal: spectrograms, spec_masks, encoded, encoded_len = self.forward( - processed_signal=signal, processed_signal_length=signal_len, + processed_signal=signal, + processed_signal_length=signal_len, ) else: spectrograms, spec_masks, encoded, encoded_len = self.forward( - input_signal=signal, input_signal_length=signal_len, + input_signal=signal, + input_signal_length=signal_len, ) if self.decoder_losses is not None: @@ -553,11 +568,13 @@ def validation_pass(self, batch, batch_idx, dataloader_idx=0): signal, signal_len, targets, target_lengths = batch if isinstance(batch, DALIOutputs) and batch.has_processed_signal: spectrograms, spec_masks, encoded, encoded_len = self.forward( - processed_signal=signal, processed_signal_length=signal_len, + processed_signal=signal, + processed_signal_length=signal_len, ) else: spectrograms, spec_masks, encoded, encoded_len = self.forward( - input_signal=signal, input_signal_length=signal_len, + input_signal=signal, + input_signal_length=signal_len, ) if self.decoder_losses is not None: @@ -589,3 +606,408 @@ def multi_validation_epoch_end(self, outputs, dataloader_idx: int = 0): val_loss_mean = torch.stack([x['val_loss'] for x in outputs]).mean() tensorboard_logs = {'val_loss': val_loss_mean} return {'val_loss': val_loss_mean, 'log': tensorboard_logs} + + +class EncDecMaskedTokenPredModel(SpeechEncDecSelfSupervisedModel): + def transfer_batch_to_device(self, batch: Any, device: torch.device, dataloader_idx: int) -> Any: + """ + PTL hook: https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#transfer-batch-to-device + """ + batch = move_data_to_device(batch, device) + return batch + + @classmethod + def list_available_models(cls) -> List[PretrainedModelInfo]: + """ + This method returns a list of pre-trained model which can be instantiated directly from NVIDIA's NGC cloud. + + Returns: + List of available pre-trained models. 
+ """ + results = [] + return results + + def __init__(self, cfg: DictConfig, trainer: Trainer = None): + super().__init__(cfg, trainer) + + if self.cfg.get("mask_position", "pre_conv") == "post_conv": + # adjust config for post-convolution masking + self.cfg.quantizer.feat_in = self.cfg.encoder.d_model + self.cfg.masking.feat_in = self.cfg.encoder.d_model + self.cfg.masking.block_size = self.cfg.masking.block_size // self.cfg.encoder.subsampling_factor + self.cfg.loss.combine_time_steps = 1 + + self.quantizer = self.from_config_dict(self.cfg.quantizer) + self.mask_processor = self.from_config_dict(self.cfg.masking) + self.encoder = self.from_config_dict(self.cfg.encoder) + self.decoder = self.from_config_dict(self.cfg.decoder) + self.loss = self.from_config_dict(self.cfg.loss) + + self.pre_encoder = None + if self.cfg.get("mask_position", "pre_conv") == "post_conv": + # hacked to mask features after convolutional sub-sampling + self.pre_encoder = ConvFeatureMaksingWrapper(self.encoder.pre_encode, self.mask_processor) + self.encoder.pre_encode = self.pre_encoder + + @property + def input_types(self) -> Optional[Dict[str, NeuralType]]: + if hasattr(self.preprocessor, '_sample_rate'): + input_signal_eltype = AudioSignal(freq=self.preprocessor._sample_rate) + else: + input_signal_eltype = AudioSignal() + return { + "input_signal": NeuralType(('B', 'T'), input_signal_eltype, optional=True), + "input_signal_length": NeuralType(tuple('B'), LengthsType(), optional=True), + "processed_signal": NeuralType(('B', 'D', 'T'), SpectrogramType(), optional=True), + "processed_signal_length": NeuralType(tuple('B'), LengthsType(), optional=True), + "targets": NeuralType(('B', 'T'), LabelsType(), optional=True), + "target_lengths": NeuralType(tuple('B'), LengthsType(), optional=True), + "apply_mask": NeuralType(optional=True), + } + + @property + def output_types(self) -> Optional[Dict[str, NeuralType]]: + if self.cfg.num_books == 1 and self.cfg.squeeze_single: + logprobs = NeuralType(('B', 'T', 'C'), LogprobsType()) + tokens = NeuralType(('B', 'T'), LabelsType()) + else: + logprobs = NeuralType(('B', 'T', 'C', 'H'), LogprobsType()) + tokens = NeuralType(('B', 'T', 'H'), LabelsType()) + return { + "logprobs": logprobs, + "encoded_len": NeuralType(tuple('B'), LengthsType()), + "masks": NeuralType(('B', 'D', 'T'), SpectrogramType()), + "tokens": tokens, + } + + @typecheck() + def forward( + self, + input_signal=None, + input_signal_length=None, + processed_signal=None, + processed_signal_length=None, + apply_mask=False, + ): + has_input_signal = input_signal is not None and input_signal_length is not None + has_processed_signal = processed_signal is not None and processed_signal_length is not None + if (has_input_signal ^ has_processed_signal) == False: + raise ValueError( + f"{self} Arguments ``input_signal`` and ``input_signal_length`` are mutually exclusive " + " with ``processed_signal`` and ``processed_signal_len`` arguments." 
+ ) + + if not has_processed_signal: + processed_signal, processed_signal_length = self.preprocessor( + input_signal=input_signal, + length=input_signal_length, + ) + + if self.pre_encoder is not None: + # mask after convolutional sub-sampling + self.pre_encoder.set_masking_enabled(apply_mask=apply_mask) + encoded, encoded_len = self.encoder(audio_signal=processed_signal, length=processed_signal_length) + masks = self.pre_encoder.get_current_mask() + feats = self.pre_encoder.get_current_feat() + _, tokens = self.quantizer(input_signal=feats.transpose(1, 2)) + else: + _, tokens = self.quantizer(input_signal=processed_signal) + if apply_mask: + masked_signal, masks = self.mask_processor( + input_feats=processed_signal, input_lengths=processed_signal_length + ) + else: + masked_signal = processed_signal + masks = torch.zeros_like(processed_signal) + encoded, encoded_len = self.encoder(audio_signal=masked_signal, length=processed_signal_length) + + log_probs = self.decoder(encoder_output=encoded) + + return log_probs, encoded_len, masks, tokens + + def training_step(self, batch, batch_idx): + input_signal, input_signal_length, _, _ = batch + if isinstance(batch, DALIOutputs) and batch.has_processed_signal: + log_probs, encoded_len, masks, tokens = self.forward( + processed_signal=input_signal, processed_signal_length=input_signal_length, apply_mask=True + ) + else: + log_probs, encoded_len, masks, tokens = self.forward( + input_signal=input_signal, input_signal_length=input_signal_length, apply_mask=True + ) + + loss_value = self.loss(masks=masks, decoder_outputs=log_probs, targets=tokens, decoder_lengths=encoded_len) + + tensorboard_logs = { + 'learning_rate': self._optimizer.param_groups[0]['lr'], + 'global_step': self.trainer.global_step, + 'train_loss': loss_value, + } + + return {'loss': loss_value, 'log': tensorboard_logs} + + def inference_pass(self, batch, batch_idx, dataloader_idx=0, mode='val', apply_mask=False): + input_signal, input_signal_length, _, _ = batch + if isinstance(batch, DALIOutputs) and batch.has_processed_signal: + log_probs, encoded_len, masks, tokens = self.forward( + processed_signal=input_signal, processed_signal_length=input_signal_length, apply_mask=apply_mask + ) + else: + log_probs, encoded_len, masks, tokens = self.forward( + input_signal=input_signal, input_signal_length=input_signal_length, apply_mask=apply_mask + ) + + loss_value = self.loss(masks=masks, decoder_outputs=log_probs, targets=tokens, decoder_lengths=encoded_len) + + return {f'{mode}_loss': loss_value} + + def validation_step(self, batch, batch_idx, dataloader_idx=0): + metrics = self.inference_pass(batch, batch_idx, dataloader_idx, apply_mask=True) + if type(self.trainer.val_dataloaders) == list and len(self.trainer.val_dataloaders) > 1: + self.validation_step_outputs[dataloader_idx].append(metrics) + else: + self.validation_step_outputs.append(metrics) + return metrics + + def test_step(self, batch, batch_idx, dataloader_idx=0): + metrics = self.inference_pass(batch, batch_idx, dataloader_idx, mode="test", apply_mask=True) + if type(self.trainer.val_dataloaders) == list and len(self.trainer.val_dataloaders) > 1: + self.validation_step_outputs[dataloader_idx].append(metrics) + else: + self.validation_step_outputs.append(metrics) + return metrics + + def multi_validation_epoch_end(self, outputs: list, dataloader_idx: int = 0): + loss_list = [] + for i, x in enumerate(outputs): + if not isinstance(x, dict): + logging.warning(f'Batch {i} output in validation dataloader {dataloader_idx} is not a 
dictionary: {x}') + if 'val_loss' in x: + loss_list.append(x['val_loss']) + else: + logging.warning( + f'Batch {i} output in validation dataloader {dataloader_idx} does not have key `val_loss`: {x}' + ) + + if len(loss_list) == 0: + logging.warning( + f'Epoch {self.current_epoch} received no batches for validation dataloader {dataloader_idx}.' + ) + return {} + + val_loss_mean = torch.stack(loss_list).mean() + tensorboard_logs = {'val_loss': val_loss_mean} + return {'val_loss': val_loss_mean, 'log': tensorboard_logs} + + def multi_test_epoch_end(self, outputs, dataloader_idx: int = 0): + test_loss_mean = torch.stack([x['test_loss'] for x in outputs]).mean() + tensorboard_logs = {'test_loss': test_loss_mean} + return {'test_loss': test_loss_mean, 'log': tensorboard_logs} + + +class EncDecDenoiseMaskedTokenPredModel(EncDecMaskedTokenPredModel): + def __init__(self, cfg: DictConfig, trainer: Trainer = None): + super().__init__(cfg, trainer) + + def _setup_dataloader_from_config(self, config: Optional[Dict]): + audio_to_text_dataset.inject_dataloader_value_from_model_config(self.cfg, config, key='sample_rate') + + if config.get("use_lhotse"): + return get_lhotse_dataloader_from_config( + config, + global_rank=self.global_rank, + world_size=self.world_size, + dataset=ssl_dataset.LhotseAudioNoiseDataset( + noise_manifest=config.get('noise_manifest', None), + batch_augmentor_cfg=config.get('batch_augmentor', None), + ), + ) + + dataset = ssl_dataset.get_audio_noise_dataset_from_config( + config, + global_rank=self.global_rank, + world_size=self.world_size, + ) + + shuffle = config['shuffle'] + if isinstance(dataset, torch.utils.data.IterableDataset): + shuffle = False + + if hasattr(dataset, 'collate_fn'): + collate_fn = dataset.collate_fn + elif hasattr(dataset.datasets[0], 'collate_fn'): + # support datasets that are lists of entries + collate_fn = dataset.datasets[0].collate_fn + else: + # support datasets that are lists of lists + collate_fn = dataset.datasets[0].datasets[0].collate_fn + + return torch.utils.data.DataLoader( + dataset=dataset, + batch_size=config['batch_size'], + collate_fn=collate_fn, + drop_last=config.get('drop_last', False), + shuffle=shuffle, + num_workers=config.get('num_workers', 0), + pin_memory=config.get('pin_memory', False), + ) + + @property + def input_types(self) -> Optional[Dict[str, NeuralType]]: + if hasattr(self.preprocessor, '_sample_rate'): + input_signal_eltype = AudioSignal(freq=self.preprocessor._sample_rate) + else: + input_signal_eltype = AudioSignal() + return { + "input_signal": NeuralType(('B', 'T'), input_signal_eltype, optional=True), + "input_signal_length": NeuralType(tuple('B'), LengthsType(), optional=True), + "processed_signal": NeuralType(('B', 'D', 'T'), SpectrogramType(), optional=True), + "processed_signal_length": NeuralType(tuple('B'), LengthsType(), optional=True), + "noise_signal": NeuralType(('B', 'T'), input_signal_eltype, optional=True), + "noise_signal_length": NeuralType(tuple('B'), LengthsType(), optional=True), + "processed_noise_signal": NeuralType(('B', 'D', 'T'), SpectrogramType(), optional=True), + "processed_noise_length": NeuralType(tuple('B'), LengthsType(), optional=True), + "noisy_input_signal": NeuralType(('B', 'T'), input_signal_eltype, optional=True), + "noisy_input_signal_length": NeuralType(tuple('B'), LengthsType(), optional=True), + "processed_noisy_signal": NeuralType(('B', 'D', 'T'), SpectrogramType(), optional=True), + "processed_noisy_signal_length": NeuralType(tuple('B'), LengthsType(), optional=True), + 
"apply_mask": NeuralType(optional=True), + } + + @property + def output_types(self) -> Optional[Dict[str, NeuralType]]: + if self.cfg.num_books == 1 and self.cfg.squeeze_single: + logprobs = NeuralType(('B', 'T', 'C'), LogprobsType()) + tokens = NeuralType(('B', 'T'), LabelsType()) + else: + logprobs = NeuralType(('B', 'T', 'C', 'H'), LogprobsType()) + tokens = NeuralType(('B', 'T', 'H'), LabelsType()) + return { + "logprobs": logprobs, + "encoded_len": NeuralType(tuple('B'), LengthsType()), + "masks": NeuralType(('B', 'D', 'T'), SpectrogramType()), + "tokens": tokens, + } + + @typecheck() + def forward( + self, + input_signal=None, + input_signal_length=None, + processed_signal=None, + processed_signal_length=None, + noise_signal=None, # noqa + noise_signal_length=None, # noqa + processed_noise_signal=None, # noqa + processed_noise_signal_length=None, # noqa + noisy_input_signal=None, + noisy_input_signal_length=None, + processed_noisy_input_signal=None, + processed_noisy_input_signal_length=None, + apply_mask=False, + ): + has_input_signal = input_signal is not None and input_signal_length is not None + has_processed_signal = processed_signal is not None and processed_signal_length is not None + if (has_input_signal ^ has_processed_signal) == False: + raise ValueError( + f"{self} Arguments ``input_signal`` and ``input_signal_length`` are mutually exclusive " + " with ``processed_signal`` and ``processed_signal_len`` arguments." + ) + if not has_processed_signal: + processed_signal, processed_signal_length = self.preprocessor( + input_signal=input_signal, + length=input_signal_length, + ) + + ### Following code snipet is not used but kept for future reference + # + # has_noise_signal = noise_signal is not None and noise_signal_length is not None + # has_processed_noise_signal = processed_noise_signal is not None and processed_noise_signal_length is not None + # if (has_noise_signal ^ has_processed_noise_signal) == False: + # raise ValueError( + # f"{self} Arguments ``noise_signal`` and ``noise_signal_length`` are mutually exclusive " + # " with ``processed_noise_signal`` and ``processed_noise_signal_len`` arguments." + # ) + # if not has_processed_noise_signal: + # processed_noise_signal, processed_noise_signal_length = self.preprocessor( + # input_signal=noise_signal, + # length=noise_signal_length, + # ) + + has_noisy_input_signal = noisy_input_signal is not None and noisy_input_signal_length is not None + has_processed_noisy_input_signal = ( + processed_noisy_input_signal is not None and processed_noisy_input_signal_length is not None + ) + if (has_noisy_input_signal ^ has_processed_noisy_input_signal) == False: + raise ValueError( + f"{self} Arguments ``noisy_input_signal`` and ``noisy_input_signal_length`` are mutually exclusive " + " with ``processed_noisy_input_signal`` and ``processed_noisy_input_signal_len`` arguments." 
+ ) + if not has_processed_noisy_input_signal: + processed_noisy_input_signal, processed_noisy_input_signal_length = self.preprocessor( + input_signal=noisy_input_signal, + length=noisy_input_signal_length, + ) + + if self.pre_encoder is not None: + # mask after convolutional sub-sampling + feats, _ = self.pre_encoder.pre_encode(x=processed_signal, lengths=processed_signal_length) + _, tokens = self.quantizer(input_signal=feats.transpose(1, 2)) + + self.pre_encoder.set_masking_enabled(apply_mask=apply_mask) + encoded, encoded_len = self.encoder( + audio_signal=processed_noisy_input_signal, length=processed_noisy_input_signal_length + ) + masks = self.pre_encoder.get_current_mask() + else: + _, tokens = self.quantizer(input_signal=processed_signal) + if apply_mask: + masked_signal, masks = self.mask_processor( + input_feats=processed_noisy_input_signal, input_lengths=processed_noisy_input_signal_length + ) + else: + masked_signal = processed_noisy_input_signal + masks = torch.zeros_like(processed_noisy_input_signal) + encoded, encoded_len = self.encoder(audio_signal=masked_signal, length=processed_noisy_input_signal_length) + + log_probs = self.decoder(encoder_output=encoded) + + return log_probs, encoded_len, masks, tokens + + def training_step(self, batch: ssl_dataset.AudioNoiseBatch, batch_idx: int): + log_probs, encoded_len, masks, tokens = self.forward( + input_signal=batch.audio, + input_signal_length=batch.audio_len, + noise_signal=batch.noise, + noise_signal_length=batch.noise_len, + noisy_input_signal=batch.noisy_audio, + noisy_input_signal_length=batch.noisy_audio_len, + apply_mask=True, + ) + + loss_value = self.loss(masks=masks, decoder_outputs=log_probs, targets=tokens, decoder_lengths=encoded_len) + + tensorboard_logs = { + 'learning_rate': self._optimizer.param_groups[0]['lr'], + 'global_step': self.trainer.global_step, + 'train_loss': loss_value, + } + + return {'loss': loss_value, 'log': tensorboard_logs} + + def inference_pass( + self, batch: ssl_dataset.AudioNoiseBatch, batch_idx: int, dataloader_idx: int = 0, mode: str = 'val' + ): + log_probs, encoded_len, masks, tokens = self.forward( + input_signal=batch.audio, + input_signal_length=batch.audio_len, + noise_signal=batch.noise, + noise_signal_length=batch.noise_len, + noisy_input_signal=batch.noisy_audio, + noisy_input_signal_length=batch.noisy_audio_len, + apply_mask=True, + ) + + loss_value = self.loss(masks=masks, decoder_outputs=log_probs, targets=tokens, decoder_lengths=encoded_len) + + return {f'{mode}_loss': loss_value} diff --git a/nemo/collections/asr/modules/__init__.py b/nemo/collections/asr/modules/__init__.py index a412040a3b67..940eb079ae27 100644 --- a/nemo/collections/asr/modules/__init__.py +++ b/nemo/collections/asr/modules/__init__.py @@ -44,3 +44,9 @@ StatelessTransducerDecoder, ) from nemo.collections.asr.modules.squeezeformer_encoder import SqueezeformerEncoder, SqueezeformerEncoderAdapter +from nemo.collections.asr.modules.ssl_modules import ( + ConvFeatureMaksingWrapper, + MultiSoftmaxDecoder, + RandomBlockMasking, + RandomProjectionVectorQuantizer, +) diff --git a/nemo/collections/asr/modules/conv_asr.py b/nemo/collections/asr/modules/conv_asr.py index 03b94ae0b209..3cb9ec13109b 100644 --- a/nemo/collections/asr/modules/conv_asr.py +++ b/nemo/collections/asr/modules/conv_asr.py @@ -85,8 +85,7 @@ def input_example(self, max_batch=1, max_dim=8192): @property def input_types(self): - """Returns definitions of module input ports. 
- """ + """Returns definitions of module input ports.""" return OrderedDict( { "audio_signal": NeuralType(('B', 'D', 'T'), SpectrogramType()), @@ -96,8 +95,7 @@ def input_types(self): @property def output_types(self): - """Returns definitions of module output ports. - """ + """Returns definitions of module output ports.""" return OrderedDict( { "outputs": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()), @@ -273,8 +271,7 @@ def restore_from(cls, restore_path: str): @property def input_types(self): - """Returns definitions of module input ports. - """ + """Returns definitions of module input ports.""" return OrderedDict( { "audio_signal": NeuralType(('B', 'D', 'T'), SpectrogramType()), @@ -284,8 +281,7 @@ def input_types(self): @property def output_types(self): - """Returns definitions of module output ports. - """ + """Returns definitions of module output ports.""" return OrderedDict( { "outputs": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()), @@ -401,10 +397,10 @@ def forward(self, audio_signal, length=None): class ConvASRDecoder(NeuralModule, Exportable, adapter_mixins.AdapterModuleMixin): """Simple ASR Decoder for use with CTC-based models such as JasperNet and QuartzNet - Based on these papers: - https://arxiv.org/pdf/1904.03288.pdf - https://arxiv.org/pdf/1910.10261.pdf - https://arxiv.org/pdf/2005.04290.pdf + Based on these papers: + https://arxiv.org/pdf/1904.03288.pdf + https://arxiv.org/pdf/1910.10261.pdf + https://arxiv.org/pdf/2005.04290.pdf """ @property @@ -502,8 +498,7 @@ def num_classes_with_blank(self): class ConvASRDecoderReconstruction(NeuralModule, Exportable): - """ASR Decoder for reconstructing masked regions of spectrogram - """ + """ASR Decoder for reconstructing masked regions of spectrogram""" @property def input_types(self): @@ -623,8 +618,8 @@ def _prepare_for_export(self, **kwargs): class ConvASRDecoderClassification(NeuralModule, Exportable): """Simple ASR Decoder for use with classification models such as JasperNet and QuartzNet - Based on these papers: - https://arxiv.org/pdf/2005.04290.pdf + Based on these papers: + https://arxiv.org/pdf/2005.04290.pdf """ def input_example(self, max_batch=1, max_dim=256): @@ -668,8 +663,7 @@ def __init__( self.decoder_layers = torch.nn.Sequential(torch.nn.Linear(self._feat_in, self._num_classes, bias=True)) self.apply(lambda x: init_weights(x, mode=init_mode)) - @typecheck() - def forward(self, encoder_output): + def forward(self, encoder_output, **kwargs): batch, in_channels, timesteps = encoder_output.size() encoder_output = self.pooling(encoder_output).view(batch, in_channels) # [B, C] @@ -705,8 +699,7 @@ class ECAPAEncoder(NeuralModule, Exportable): @property def input_types(self): - """Returns definitions of module input ports. - """ + """Returns definitions of module input ports.""" return OrderedDict( { "audio_signal": NeuralType(('B', 'D', 'T'), SpectrogramType()), @@ -716,8 +709,7 @@ def input_types(self): @property def output_types(self): - """Returns definitions of module output ports. 
- """ + """Returns definitions of module output ports.""" return OrderedDict( { "outputs": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()), @@ -853,7 +845,11 @@ def __init__( self.apply(lambda x: init_weights(x, mode=init_mode)) def affine_layer( - self, inp_shape, out_shape, learn_mean=True, affine_type='conv', + self, + inp_shape, + out_shape, + learn_mean=True, + affine_type='conv', ): if affine_type == 'conv': layer = nn.Sequential( @@ -919,12 +915,16 @@ def _update_adapter_cfg_input_dim(self, block: JasperBlock, cfg): cfg = adapter_utils.update_adapter_cfg_input_dim(self, cfg, module_dim=block.planes) return cfg - def get_accepted_adapter_types(self,) -> Set[type]: + def get_accepted_adapter_types( + self, + ) -> Set[type]: types = super().get_accepted_adapter_types() if len(types) == 0: self.set_accepted_adapter_types( - [adapter_utils.LINEAR_ADAPTER_CLASSPATH,] + [ + adapter_utils.LINEAR_ADAPTER_CLASSPATH, + ] ) types = self.get_accepted_adapter_types() return types diff --git a/nemo/collections/asr/modules/ssl_modules/__init__.py b/nemo/collections/asr/modules/ssl_modules/__init__.py new file mode 100644 index 000000000000..0f31d4055df1 --- /dev/null +++ b/nemo/collections/asr/modules/ssl_modules/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nemo.collections.asr.modules.ssl_modules.augmentation import ( + MultiSpeakerNoiseAugmentation, + SpeakerNoiseAugmentation, +) +from nemo.collections.asr.modules.ssl_modules.masking import ConvFeatureMaksingWrapper, RandomBlockMasking +from nemo.collections.asr.modules.ssl_modules.multi_softmax_decoder import MultiSoftmaxDecoder +from nemo.collections.asr.modules.ssl_modules.quantizers import RandomProjectionVectorQuantizer diff --git a/nemo/collections/asr/modules/ssl_modules/augmentation.py b/nemo/collections/asr/modules/ssl_modules/augmentation.py new file mode 100644 index 000000000000..bb63b7f38b9a --- /dev/null +++ b/nemo/collections/asr/modules/ssl_modules/augmentation.py @@ -0,0 +1,290 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import math + +import numpy as np +import torch + +from nemo.collections.asr.data.ssl_dataset import AudioNoiseBatch + + +class SpeakerNoiseAugmentation(object): + def __init__( + self, + prob: float = 0.0, + noise_ratio: float = 0.0, + min_r_speech: float = -5.0, + max_r_speech: float = 5.0, + min_r_noise: float = -5.0, + max_r_noise: float = 20.0, + min_mix_rate: float = 0.0, + max_mix_rate: float = 1.0, + ): + super().__init__() + self.prob = prob + self.noise_ratio = noise_ratio + self.min_r_speech = min_r_speech + self.max_r_speech = max_r_speech + self.min_r_noise = min_r_noise + self.max_r_noise = max_r_noise + self.min_mix_rate = min_mix_rate + self.max_mix_rate = max_mix_rate + + if not (0 <= self.prob <= 1): + raise ValueError(f"prob must be in [0, 1], got: {self.prob}") + if not (0 <= self.noise_ratio <= 1): + raise ValueError(f"noise_ratio must be in [0, 1], got: {self.noise_ratio}") + if not (self.min_r_speech <= self.max_r_speech): + raise ValueError( + f"min_r_speech must be no greater than max_r_speech, got: min={self.min_r_speech} and max={self.max_r_speech}" + ) + if not (self.min_r_noise <= self.max_r_noise): + raise ValueError( + f"min_r_noise must be no greater than max_r_noise, got: min={self.min_r_noise} and max={self.max_r_noise}" + ) + if not (0 <= self.min_mix_rate <= self.max_mix_rate <= 1): + raise ValueError( + f"min_mix_rate must be no greater than max_mix_rate, and both must be in [0, 1], got: {self.min_mix_rate} and {self.max_mix_rate}" + ) + + def repeat_noise(self, noise: torch.Tensor, noise_len: int, max_audio_len: int) -> torch.Tensor: + noise = noise[:noise_len] + if noise_len < max_audio_len: + noise = noise.repeat(max_audio_len // noise_len + 1) + noise = noise[:max_audio_len] + return noise + + def pad_or_trim_noise(self, noise: torch.Tensor, max_audio_len: int, value=0) -> torch.Tensor: + noise_len = noise.size(0) + if noise_len < max_audio_len: + pad = (0, max_audio_len - noise_len) + noise = torch.nn.functional.pad(noise, pad, value=value) + else: + noise = noise[:max_audio_len] + return noise + + def __call__(self, batch: AudioNoiseBatch) -> AudioNoiseBatch: + audio_signal = batch.audio + audio_lengths = batch.audio_len + batch_size = audio_signal.size(0) + max_audio_len = audio_signal.size(1) + + noise = batch.noise + noise_len = batch.noise_len + noisy_audio = batch.noisy_audio + noisy_audio_len = batch.noisy_audio_len + for i in range(batch_size): + if np.random.rand() > self.prob: + continue + + # randomly select the length of mixing segment + if 0 <= self.min_mix_rate < self.max_mix_rate <= 1: + mix_len = np.random.randint( + int(audio_lengths[i] * self.min_mix_rate), int(audio_lengths[i] * self.max_mix_rate) + ) + else: + mix_len = max(1, int(audio_lengths[i] * self.min_mix_rate)) + + # randomly select position to start the mixing + mix_start_idx = np.random.randint(audio_lengths[i] - mix_len) + + # randomly select the energy ratio between speech and noise + if np.random.rand() < self.noise_ratio or batch_size == 1: + energy_ratio = np.random.uniform(self.min_r_noise, self.max_r_noise) + else: + energy_ratio = np.random.uniform(self.min_r_speech, self.max_r_speech) + j = np.random.choice([x for x in range(batch_size) if x != i]) + noise[i] = audio_signal[j].clone() + noise_len[i] = audio_lengths[j] + + # repeat noise to match the length of audio mix length if necessary + if noise_len[i] <= mix_len: + # repeat noise to match the length of audio mix length + noise_start_idx = 0 + noise[i] = 
self.pad_or_trim_noise(self.repeat_noise(noise[i], noise_len[i], mix_len), max_audio_len) + noise_len[i] = mix_len + else: + # randomly select a segment of noise + noise_start_idx = np.random.randint(noise_len[i] - mix_len) + + # calculate the scale factor for noise + audio_energy = torch.sum(audio_signal[i, : audio_lengths[i]] ** 2) / audio_lengths[i] + noise_energy = torch.sum(noise[i, : noise_len[i]] ** 2) / noise_len[i] if noise_len[i] > 0 else 0 + mix_scale = math.sqrt(audio_energy / (10 ** (energy_ratio / 10) * noise_energy)) if noise_energy > 0 else 0 + + # get the residual signal to be added to original audio + noise_clip = noise[i, noise_start_idx : noise_start_idx + mix_len] + noise_signal = torch.zeros_like(audio_signal[i]) + noise_signal[mix_start_idx : mix_start_idx + mix_len] = mix_scale * noise_clip + + # add noise to audio + noisy_audio[i] = audio_signal[i] + noise_signal + noisy_audio_len[i] = audio_lengths[i] + noise[i] = noise_signal + noise_len[i] = audio_lengths[i] + + return AudioNoiseBatch( + sample_id=batch.sample_id, + audio=batch.audio, + audio_len=batch.audio_len, + noise=noise, + noise_len=noise_len, + noisy_audio=noisy_audio, + noisy_audio_len=noisy_audio_len, + ) + + +class MultiSpeakerNoiseAugmentation(SpeakerNoiseAugmentation): + def __init__( + self, + prob: float = 0.0, + noise_ratio: float = 0.0, + min_r_speech: float = -5.0, + max_r_speech: float = 5.0, + min_r_noise: float = -5.0, + max_r_noise: float = 20.0, + min_mix_rate: float = 0.0, + max_mix_rate: float = 1.0, + min_num_segments: int = 1, + max_num_segments: int = 5, + min_num_speakers: int = 1, + max_num_speakers: int = 4, + ): + super().__init__( + prob=prob, + noise_ratio=noise_ratio, + min_r_speech=min_r_speech, + max_r_speech=max_r_speech, + min_r_noise=min_r_noise, + max_r_noise=max_r_noise, + min_mix_rate=min_mix_rate, + max_mix_rate=max_mix_rate, + ) + self.min_num_segments = min_num_segments + self.max_num_segments = max_num_segments + self.min_num_speakers = min_num_speakers + self.max_num_speakers = max_num_speakers + + def __call__(self, batch: AudioNoiseBatch) -> AudioNoiseBatch: + audio_signal = batch.audio + audio_lengths = batch.audio_len + batch_size = audio_signal.size(0) + + noise = batch.noise + noise_len = batch.noise_len + noisy_audio = batch.noisy_audio + noisy_audio_len = batch.noisy_audio_len + for i in range(batch_size): + if np.random.rand() > self.prob: + continue + + # randomly select the length of mixing segment + if 0 <= self.min_mix_rate < self.max_mix_rate <= 1: + mix_rate = np.random.uniform(self.min_mix_rate, self.max_mix_rate) + else: + mix_rate = self.min_mix_rate + mix_len = max(1, int(audio_lengths[i] * mix_rate)) + + # randomly select the number of segments + num_segments = np.random.randint(self.min_num_segments, self.max_num_segments + 1) + num_speakers = np.random.randint(self.min_num_speakers, self.max_num_speakers + 1) + num_speakers = min(num_speakers, batch_size) + + # randomly chunk mix_len into num_segments + segment_lens = np.random.multinomial(mix_len, [1 / num_segments] * num_segments) + + # randomly select the energy ratio between speech and noise + if np.random.rand() < self.noise_ratio or batch_size == 1: + mode = "noise" + energy_ratio = np.random.uniform(self.min_r_noise, self.max_r_noise) + else: + mode = "speech" + energy_ratio = np.random.uniform(self.min_r_speech, self.max_r_speech) + + noise_segments = self.get_noise_segments(i, batch, segment_lens, num_speakers, mode) + noise_signal = torch.zeros_like(audio_signal[i]) + min_start_idx = 
0 + max_start_idx = audio_lengths[i] - mix_len + for j in range(num_segments): + start_idx = min_start_idx + if min_start_idx < max_start_idx: + start_idx = np.random.randint(min_start_idx, max_start_idx) + noise_signal[start_idx : start_idx + segment_lens[j]] = noise_segments[j] + min_start_idx = start_idx + segment_lens[j] + max_start_idx += segment_lens[j] + + # calculate the scale factor for noise + audio_energy = torch.sum(audio_signal[i, : audio_lengths[i]] ** 2) / audio_lengths[i] + noise_energy = torch.sum(noise_signal[: audio_lengths[i]] ** 2) / audio_lengths[i] + mix_scale = math.sqrt(audio_energy / (10 ** (energy_ratio / 10) * noise_energy)) if noise_energy > 0 else 0 + + # get the residual signal to be added to original audio + noise_signal = mix_scale * noise_signal + + # add noise to audio + noisy_audio[i] = audio_signal[i] + noise_signal + noisy_audio_len[i] = audio_lengths[i] + noise[i] = noise_signal + noise_len[i] = audio_lengths[i] + + return AudioNoiseBatch( + sample_id=batch.sample_id, + audio=batch.audio, + audio_len=batch.audio_len, + noise=noise, + noise_len=noise_len, + noisy_audio=noisy_audio, + noisy_audio_len=noisy_audio_len, + ) + + def get_noise_segments(self, batch_idx, batch, segment_lens, num_speakers, mode): + audio_signal = batch.audio + audio_lengths = batch.audio_len + noise = batch.noise + noise_len = batch.noise_len + batch_size = noise.size(0) + max_audio_len = audio_signal.size(1) + noise_segments = [] + if mode == "noise": + noise_padded = self.pad_or_trim_noise( + self.repeat_noise(noise[batch_idx], noise_len[batch_idx], max_audio_len), max_audio_len + ) + start_idx = 0 + for segment_len in segment_lens: + noise_segments.append(noise_padded[start_idx : start_idx + segment_len]) + start_idx += segment_len + return noise_segments + + if mode != "speech": + raise ValueError(f"mode must be either 'noise' or 'speech', got: {mode}") + + speaker_candidates = [x for x in range(batch_size) if x != batch_idx] + speaker_candidates = np.random.choice(speaker_candidates, min(num_speakers, batch_size - 1), replace=False) + sid = 0 + for seg_len in segment_lens: + bid = speaker_candidates[sid] + if seg_len > audio_lengths[bid]: + audio_segment = self.pad_or_trim_noise( + self.repeat_noise(audio_signal[bid], audio_lengths[bid], seg_len), seg_len + ) + else: + start_idx = np.random.randint(audio_lengths[bid] - seg_len) if audio_lengths[bid] > seg_len else 0 + audio_segment = audio_signal[bid][start_idx : start_idx + seg_len].clone() + noise_segments.append(audio_segment) + sid += 1 + if sid >= len(speaker_candidates): + sid = np.random.randint(len(speaker_candidates)) + + return noise_segments diff --git a/nemo/collections/asr/modules/ssl_modules/masking.py b/nemo/collections/asr/modules/ssl_modules/masking.py new file mode 100644 index 000000000000..3c3550dddd4c --- /dev/null +++ b/nemo/collections/asr/modules/ssl_modules/masking.py @@ -0,0 +1,199 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
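+
+"""Feature masking modules for masked token prediction pre-training.
+
+A rough usage sketch (illustrative values; ``feats`` is a (B, D, T) feature tensor and ``lengths``
+holds the valid length of each example):
+
+    masking = RandomBlockMasking(feat_in=80, mask_prob=0.5, block_size=48)
+    masked_feats, masks = masking(input_feats=feats, input_lengths=lengths)
+    # masked_feats/masks: (B, D, T); masks is 1.0 inside the randomly selected blocks
+
+``ConvFeatureMaksingWrapper`` applies the same masking to the output of a ConformerEncoder's
+``pre_encode`` (subsampling) module instead of to the input features.
+"""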
+ + +from typing import Optional, Union + +import numpy as np +import torch +import torch.nn as nn + +from nemo.core.classes import NeuralModule +from nemo.core.neural_types import AcousticEncodedRepresentation, LengthsType, NeuralType + + +class RandomBlockMasking(NeuralModule): + """ + Performs random block masking on sequence of features. + Args: + mask_prob (float): percentage of sequence to mask + block_size (int): size of each block to mask + mask_value (Optional[float]): value to use for masking, if None, use random values + feat_in (Optional[int]): size of input features, required if mask_value is None + freeze (bool): if True, mask embedding is not trainable + allow_overlap (bool): if True, masked blocks can overlap + """ + + def __init__( + self, + feat_in: int, + mask_prob: float = 0.5, + block_size: int = 48, + mask_value: Optional[float] = None, + freeze: bool = True, + allow_overlap: bool = False, + max_mask_ratio: float = 0.8, + ): + super().__init__() + self.block_size = block_size + self.mask_prob = mask_prob + self.allow_overlap = allow_overlap + self.max_mask_ratio = max_mask_ratio + + if mask_value is None: + self.mask_embedding = nn.Parameter(torch.FloatTensor(feat_in)) + nn.init.normal_(self.mask_embedding, mean=0.0, std=0.1) + else: + self.mask_embedding = nn.Parameter(torch.ones(feat_in) * mask_value, requires_grad=False) + if freeze: + self.freeze() + + @property + def input_types(self): + """Returns definitions of module input types""" + return { + "input_feats": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()), + "input_lengths": NeuralType(tuple('B'), LengthsType()), + } + + @property + def output_types(self): + """Returns definitions of module output types""" + return { + "maksed_feats": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()), + "masks": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()), + } + + def forward(self, input_feats: torch.Tensor, input_lengths: torch.Tensor): + """ + Args: + input_feats (Tensor): input sequence features, shape=(batch, features, time) + input_length (Tensor): length of each sequence in the batch, shape=(batch) + Returns: + masked_feats (Tensor): masked features, shape=(batch, features, time) + masks (Tensor): the generated masks, shape=(batch, features, time) + """ + if self.allow_overlap: + return self.forward_with_overlap(input_feats, input_lengths) + else: + return self.forward_without_overlap(input_feats, input_lengths) + + def forward_without_overlap(self, input_feats: torch.Tensor, input_lengths: torch.Tensor): + """ + Args: + input_feats (Tensor): input sequence features, shape=(batch, features, time) + input_length (Tensor): length of each sequence in the batch, shape=(batch) + Returns: + masked_feats (Tensor): masked features, shape=(batch, features, time) + masks (Tensor): the generated masks, shape=(batch, features, time) + """ + batch_size = input_feats.size(0) + mask_value = self.mask_embedding.unsqueeze(-1) + masks = torch.zeros_like(input_feats) + maksed_feats = input_feats.clone() + for i in range(batch_size): + if self.block_size >= input_lengths[i] * self.max_mask_ratio: + # handle case where audio is too short + block_size = 8 + num_patches = 1 + patch_idx = [0] + else: + num_patches = torch.ceil(input_lengths[i] * self.mask_prob / self.block_size).int() + offset = torch.randint(0, self.block_size, (1,), device=input_feats.device)[0] + block_size = self.block_size + if (num_patches + 1) * self.block_size > input_lengths[i]: + block_size = torch.div(input_lengths[i], 
(num_patches + 1), rounding_mode='trunc') + max_num_patches = torch.div(input_lengths[i], block_size, rounding_mode='trunc') + patch_idx = torch.randperm(max_num_patches - 1, device=input_feats.device)[:num_patches] + for j in range(num_patches): + start = patch_idx[j] * block_size + offset + end = start + block_size + masks[i, :, start:end] = 1.0 + maksed_feats[i, :, start:end] = mask_value + return maksed_feats, masks + + def forward_with_overlap(self, input_feats: torch.Tensor, input_lengths: torch.Tensor): + """ + Args: + input_feats (Tensor): input sequence features, shape=(batch, features, time) + input_length (Tensor): length of each sequence in the batch, shape=(batch) + Returns: + masked_feats (Tensor): masked features, shape=(batch, features, time) + masks (Tensor): the generated masks, shape=(batch, features, time) + """ + batch_size = input_feats.size(0) + mask_value = self.mask_embedding.unsqueeze(-1) + masks = torch.zeros_like(input_feats) + maksed_feats = input_feats.clone() + for i in range(batch_size): + if self.block_size >= input_lengths[i] * self.max_mask_ratio: + # handle case where audio is too short + curr_block_size = 8 + num_patches = 1 + patch_idices = [0] + else: + curr_block_size = self.block_size + curr_len = input_lengths[i].detach().cpu().numpy() + num_patches = np.random.binomial(max(0, curr_len - self.block_size), self.mask_prob) + patch_idices = torch.randperm(max(0, curr_len - self.block_size), device=input_feats.device) + patch_idices = patch_idices[:num_patches] + for j in range(num_patches): + start = patch_idices[j] + end = min(start + curr_block_size, input_lengths[i]) + masks[i, :, start:end] = 1.0 + maksed_feats[i, :, start:end] = mask_value + return maksed_feats, masks + + +class ConvFeatureMaksingWrapper(NeuralModule): + """ + A wrapper module that applies masking to the features after subsampling layer of ConformerEncoder. + """ + + def __init__(self, pre_encode_module: nn.Module, masking_module: Union[nn.Module, NeuralModule]) -> None: + """ + Args: + pre_encode_module: the pre_encode module of the ConformerEncoder instance + masking_module: the module that performs masking on the extracted features + """ + super().__init__() + self.pre_encode = pre_encode_module + self.masking = masking_module + self.curr_mask = None + self.curr_feat = None + self.apply_mask = False + + def forward(self, x, lengths): + """ + Same interface as ConformerEncoder.pre_encode + """ + feats, lengths = self.pre_encode(x=x, lengths=lengths) + self.curr_feat = feats.detach() + if self.apply_mask: + feats = feats.transpose(1, 2) + masked_feats, self.curr_mask = self.masking(input_feats=feats, input_lengths=lengths) + masked_feats = masked_feats.transpose(1, 2).detach() + else: + masked_feats = feats + self.curr_mask = torch.zeros_like(feats) + return masked_feats, lengths + + def set_masking_enabled(self, apply_mask: bool): + self.apply_mask = apply_mask + + def get_current_mask(self): + return self.curr_mask + + def get_current_feat(self): + return self.curr_feat diff --git a/nemo/collections/asr/modules/ssl_modules/multi_layer_feat.py b/nemo/collections/asr/modules/ssl_modules/multi_layer_feat.py new file mode 100644 index 000000000000..e38e3abb6774 --- /dev/null +++ b/nemo/collections/asr/modules/ssl_modules/multi_layer_feat.py @@ -0,0 +1,206 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Tuple + +import torch +import torch.distributed +import torch.nn as nn + +from nemo.collections.asr.modules import AudioToMelSpectrogramPreprocessor, ConformerEncoder +from nemo.core.classes import Exportable, NeuralModule +from nemo.core.classes.mixins import AccessMixin +from nemo.utils import logging + + +class Aggregator(nn.Module): + AVAILABLE_POOLING = ["cat", "sum", "mean", "avg", "max", "min", "none", "weighted_sum"] + + def __init__(self, mode, weights, layer_idx_list, channel_idx: int = 1): + """ + Args: + mode: Aggregation mode. One of ["cat", "sum", "mean", "avg", "max", "min", "none", "weighted_sum"] + weights: Weights for weighted sum aggregation. If None, weights are initialized to 1/num_layers. + layer_idx_list: List of layer indices to aggregate. + channel_idx: Channel dimension index of the input tensors. + """ + super().__init__() + self.mode = mode + self.channel_idx = channel_idx + self.weights = weights + if self.mode not in self.AVAILABLE_POOLING: + raise ValueError(f"Unknown mode `{self.mode}`, available modes are {self.AVAILABLE_POOLING}") + if self.mode == "weighted_sum" and self.weights is None: + self.weights = nn.Parameter(torch.ones(len(layer_idx_list)) / len(layer_idx_list)) + + def _forward_for_weighted_sum( + self, encoded: List[torch.Tensor], encoded_len: List[torch.Tensor] + ) -> Tuple[torch.Tensor, torch.Tensor]: + encoded_weighted = [encoded[i] * self.weights[i] for i in range(len(encoded))] + encoded_weighted = torch.sum(torch.stack(encoded_weighted, dim=-1), dim=-1) + return encoded_weighted, encoded_len[0] + + def forward( + self, encoded: List[torch.Tensor], encoded_len: List[torch.Tensor] + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + encoded: List of tensors of shape [B, D, T] representing the encoded features from different layers. + encoded_len: List of tensors of shape [B] representing the lengths of the encoded features. + Returns: + aggregated: Aggregated tensor of shape [B, D, T] representing the aggregated features. + aggregated_len: Tensor of shape [B] representing the lengths of the aggregated features. + """ + + if self.mode == "cat": + return torch.cat(encoded, dim=self.channel_idx), encoded_len[0] + elif self.mode == "sum": + return torch.cat([x.unsqueeze(-1) for x in encoded], dim=-1).sum(dim=-1), encoded_len[0] + elif self.mode == "mean" or self.mode == "avg": + return torch.cat([x.unsqueeze(-1) for x in encoded], dim=-1).mean(dim=-1), encoded_len[0] + elif self.mode == "max": + return torch.cat([x.unsqueeze(-1) for x in encoded], dim=-1).max(dim=-1), encoded_len[0] + elif self.mode == "min": + return torch.cat([x.unsqueeze(-1) for x in encoded], dim=-1).min(dim=-1), encoded_len[0] + elif self.mode == "none": + return encoded, encoded_len + elif self.mode == "weighted_sum": + return self._forward_for_weighted_sum(encoded, encoded_len) + else: + raise ValueError(f"Unknown mode {self.mode}") + + +class ConformerMultiLayerFeatureExtractor(NeuralModule, Exportable): + def __init__(self, encoder, aggregator, layer_idx_list): + """ + Args: + encoder: ConformerEncoder instance. 
+ aggregator: Aggregator instance. + layer_idx_list: List of layer indices to extract features from. + """ + super().__init__() + self.encoder = encoder + self.aggregator = aggregator + self.layer_idx_list = ( + [int(l) for l in layer_idx_list] + if layer_idx_list is not None + else [i for i in range(len(self.encoder.layers))] + ) + for x in self.layer_idx_list: + if x < 0 or x >= len(self.encoder.layers): + raise ValueError(f"layer index {x} out of range [0, {len(self.encoder.layers)})") + logging.info(f"Extracting features from layers {self.layer_idx_list}") + self.access_cfg = { + "interctc": { + "capture_layers": self.layer_idx_list, + }, + "detach": False, + "convert_to_cpu": False, + } + self._is_access_enabled = False + + def forward( + self, audio_signal, length, cache_last_channel=None, cache_last_time=None, cache_last_channel_len=None + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + same interface as ConformerEncoder.forward() + Returns: + tuple of aggregated features of shape [B, D, T] and lengths of shape [B] + """ + self.encoder.update_access_cfg(self.access_cfg, guid=getattr(self, "model_guid", None)) + self.encoder.set_access_enabled(access_enabled=True, guid=getattr(self, "model_guid", None)) + + _ = self.encoder( + audio_signal=audio_signal, + length=length, + cache_last_channel=cache_last_channel, + cache_last_time=cache_last_time, + cache_last_channel_len=cache_last_channel_len, + ) + + total_registry = {} + for module_registry in self.encoder.get_module_registry(self.encoder).values(): + for key in module_registry: + if key.startswith("interctc/") and key in total_registry: + raise RuntimeError(f"layer {key} has been logged multiple times!") + total_registry.update(module_registry) + + encoded_list = [] + encoded_len_list = [] + for layer_idx in self.layer_idx_list: + try: + layer_outputs = total_registry[f"interctc/layer_output_{layer_idx}"] + layer_lengths = total_registry[f"interctc/layer_length_{layer_idx}"] + except KeyError: + raise RuntimeError( + f"Intermediate layer {layer_idx} was not captured! Check the layer index and the number of ConformerEncoder layers." + ) + if len(layer_outputs) > 1 or len(layer_lengths) > 1: + raise RuntimeError("Make sure encoder.forward is called exactly one time") + encoded_list.append(layer_outputs[0]) # [B, D, T] + encoded_len_list.append(layer_lengths[0]) # [B] + + self.encoder.reset_registry() + + return self.aggregator(encoded_list, encoded_len_list) + + +class ConformerMultiLayerFeaturePreprocessor(NeuralModule, Exportable, AccessMixin): + def __init__( + self, + aggregator: nn.Module, + preprocessor: AudioToMelSpectrogramPreprocessor, + encoder: ConformerEncoder, + spec_augment=None, + layer_idx_list: Optional[List[int]] = None, + freeze_encoder: bool = True, + ): + super().__init__() + self.preprocessor = preprocessor + self.spec_augmentation = spec_augment + self.feature_extractor = ConformerMultiLayerFeatureExtractor( + encoder=encoder, aggregator=aggregator, layer_idx_list=layer_idx_list + ) + self.freeze_encoder = freeze_encoder + if freeze_encoder: + self.freeze() + + def forward(self, input_signal, length): + """ + Forward pass of the model. + + Args: + input_signal: Tensor that represents a batch of raw audio signals, + of shape [B, T]. T here represents timesteps, with 1 second of audio represented as + `self.sample_rate` number of floating point values. + length: Vector of length B, that contains the individual lengths of the audio + sequences. 
+ Returns: + encoded: A tensor of shape [B, D, T], where D represents the number of + feature dimensions extracted from the audio signal, and T represents the + number of timesteps in the processed audio signal. + encoded_len: A tensor of shape [B], that contains the lengths of the audio sequences. + """ + + processed_signal, processed_signal_length = self.preprocessor( + input_signal=input_signal, + length=length, + ) + + if self.spec_augmentation is not None and self.training: + processed_signal = self.spec_augmentation(input_spec=processed_signal, length=processed_signal_length) + + encoded, encoded_len = self.feature_extractor(audio_signal=processed_signal, length=processed_signal_length) + return encoded, encoded_len diff --git a/nemo/collections/asr/modules/ssl_modules/multi_softmax_decoder.py b/nemo/collections/asr/modules/ssl_modules/multi_softmax_decoder.py new file mode 100644 index 000000000000..d9311cd1ac21 --- /dev/null +++ b/nemo/collections/asr/modules/ssl_modules/multi_softmax_decoder.py @@ -0,0 +1,84 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from collections import OrderedDict + +import torch + +from nemo.collections.asr.parts.submodules.jasper import init_weights +from nemo.core.classes import NeuralModule, typecheck +from nemo.core.neural_types import AcousticEncodedRepresentation, LogprobsType, NeuralType + + +class MultiSoftmaxDecoder(NeuralModule): + """ + A linear decoder that takes encoder output and produces log probabilities, which also handles the + case where each frame has multiple output targets. This can be used together with + `nemo.collections.asr.losses.ssl_losses.MultiMLMLoss` to train a model with multiple targets per frame. 
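+
+    A rough shape sketch (illustrative values, not taken from any particular config):
+
+        decoder = MultiSoftmaxDecoder(feat_in=256, num_classes=1024, num_decoders=4)
+        log_probs = decoder(encoder_output=encoder_output)  # encoder_output: (B, 256, T)
+        # log_probs: (B, T, 1024, 4); with num_decoders=1 and squeeze_single=True it is (B, T, 1024)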
+ """ + + @property + def input_types(self): + return OrderedDict({"encoder_output": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation())}) + + @property + def output_types(self): + if self.squeeze_single and self.num_decoders == 1: + return OrderedDict({"logprobs": NeuralType(('B', 'T', 'C'), LogprobsType())}) + return OrderedDict({"logprobs": NeuralType(('B', 'T', 'C', 'H'), LogprobsType())}) + + def __init__( + self, + feat_in: int, + num_classes: int, + num_decoders: int = 1, + init_mode: str = "xavier_uniform", + use_bias: bool = False, + squeeze_single: bool = False, + ) -> None: + """ + Args: + feat_in: input feature dimension + num_classes: number of classes + num_decoders: number of decoders + init_mode: initialization mode + use_bias: whether to use bias + squeeze_single: if True, squeeze codebook dimension if num_books is 1 + """ + super().__init__() + self.feat_in = feat_in + self.num_classes = num_classes + self.num_decoders = num_decoders + self.squeeze_single = squeeze_single + + self.decoder_layers = torch.nn.Sequential( + torch.nn.Conv1d(self.feat_in, self.num_classes * self.num_decoders, kernel_size=1, bias=use_bias) + ) + self.apply(lambda x: init_weights(x, mode=init_mode)) + + @typecheck() + def forward(self, encoder_output): + """ + Args: + encoder_output: output from the encoder of shape (B, D, T) + Returns: + logprobs: log probabilities of shape (B, T, C, H), or (B, T, C) if squeeze_single is True + """ + logits = self.decoder_layers(encoder_output).transpose(1, 2) + logits = logits.reshape(logits.shape[0], logits.shape[1], self.num_classes, self.num_decoders) + if self.squeeze_single and self.num_decoders == 1: + logits = logits.squeeze(-1) + + return torch.nn.functional.log_softmax(logits, dim=2) diff --git a/nemo/collections/asr/modules/ssl_modules/quantizers.py b/nemo/collections/asr/modules/ssl_modules/quantizers.py new file mode 100644 index 000000000000..ebc9c65e2e7c --- /dev/null +++ b/nemo/collections/asr/modules/ssl_modules/quantizers.py @@ -0,0 +1,166 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import torch +import torch.nn.functional as F +from torch import nn + +from nemo.core import NeuralModule +from nemo.core.classes import Exportable, NeuralModule, typecheck +from nemo.core.neural_types import LabelsType, NeuralType, SpectrogramType + + +class RandomProjectionVectorQuantizer(NeuralModule, Exportable): + DIST_FN_LIST = ["l2", "cosine"] + + def __init__( + self, + feat_in: int, + code_dim: int, + num_classes: int, + num_books: int, + dist_fn: str = "cosine", + time_ahead: bool = False, + freeze: bool = True, + squeeze_single: bool = False, + combine_time_steps: int = 1, + ): + """Vector quantization using random projection proposed in BEST-RQ paper: + 'Self-Supervised Learning with Random-Projection Quantizer for Speech Recognition' + + Args: + feat_in: input feature dimension + code_dim: dimension of the codebook features + num_classes: number of classes + num_books: number of codebooks + dist_fn: distance function to use, one of "l2" or "cosine" + time_ahead: if Ture, the input is of shape (B, T, D), otherwise (B, D, T) + freeze: whether to freeze the projection matrix + squeeze_single: if True, squeeze codebook dimension if num_books is 1 + """ + super().__init__() + + if dist_fn not in self.DIST_FN_LIST: + raise ValueError(f"Unknown distance function {dist_fn}, must be one of {self.DIST_FN_LIST}") + + self.feat_in = feat_in + self.code_dim = code_dim + self.num_classes = num_classes + self.num_books = num_books + self.dist_fn = dist_fn + self.time_ahead = time_ahead + self.squeeze_single = squeeze_single + self.combine_time_steps = combine_time_steps + + # (B, T, D) -> (B, T, num_books, code_dim) + self.proj = nn.Linear(self.feat_in * combine_time_steps, self.num_books * self.code_dim, bias=False) + torch.nn.init.xavier_normal_(self.proj.weight) + + # (num_books, num_classes, hid_dim) + codebooks = torch.randn(self.num_books, self.num_classes, self.code_dim).double() + torch.nn.init.normal_(codebooks, mean=0, std=1) + codebooks = F.normalize(codebooks, dim=-1) + self.codebooks = nn.Parameter(codebooks) + if freeze: + self.freeze() + + @property + def input_types(self): + """Returns definitions of module input ports.""" + if self.time_ahead: + return {"input_signal": NeuralType(('B', 'T', 'D'), SpectrogramType())} + return {"input_signal": NeuralType(('B', 'D', 'T'), SpectrogramType())} + + @property + def output_types(self): + """Returns definitions of module output ports.""" + if self.time_ahead: + if self.num_books == 1 and self.squeeze_single: + return { + "xq": NeuralType(('B', 'T', 'D'), SpectrogramType()), + "xid": NeuralType(('B', 'T'), LabelsType()), + } + return { + "xq": NeuralType(('B', 'T', 'D', 'H'), SpectrogramType()), + "xid": NeuralType(('B', 'T', 'H'), LabelsType()), + } + if self.num_books == 1 and self.squeeze_single: + return { + "xq": NeuralType(('B', 'D', 'T'), SpectrogramType()), + "xid": NeuralType(('B', 'T'), LabelsType()), + } + return { + "xq": NeuralType(('B', 'D', 'T', 'H'), SpectrogramType()), + "xid": NeuralType(('B', 'T', 'H'), LabelsType()), + } + + @typecheck() + def forward(self, input_signal): + """ + Args: + input_signal: input features of shape (B, T, D) or (B, D, T) + Returns: + xq: quantized features of shape (B, T, D, N) or (B, D, T, N) + xid: quantized tokens of shape (B, T, N) + """ + if not self.time_ahead: + # (B, D, T) -> (B, T, D) + input_signal = input_signal.transpose(1, 2) + + B, T, _ = input_signal.size() + + if self.combine_time_steps > 1: + input_signal = input_signal.contiguous().reshape(B, T // 
self.combine_time_steps, -1) + T = T // self.combine_time_steps + + # (B, T, D) -> (B, T, num_books*code_dim) + x = self.proj(input_signal) + + # normalize each feature vector + # (B, T, num_books*code_dim) -> (B, T, num_books, code_dim) + x = F.normalize(x.view(B, T, self.num_books, self.code_dim), dim=-1) + + # get tokens (xid) of shape (B, T, num_books) + if self.dist_fn == "cosine": + # (B, T, num_books, code_dim) -> (B, T, num_books, num_classes) + xid = torch.einsum('btdh,dch->btdc', x, self.codebooks) + # (B, T, num_books, num_classes) -> (B, T, num_books) + xid = xid.max(dim=-1)[1] + elif self.dist_fn == "l2": + # (B, T, num_books, code_dim) -> (B, T, num_books, code_dim, num_classes) + xid = x.unsqueeze(-1) - self.codebooks.transpose(1, 2).unsqueeze(0).unsqueeze(0) + xid = xid.norm(dim=-2).argmin(dim=-1) + else: + raise ValueError(f"Unknown distance function {self.dist_fn}, must be one of {self.DIST_FN_LIST}") + + # xid2: (B, T, num_books) -> (B, T, num_books) + xid2 = xid + self.num_classes * torch.arange(self.num_books, device=xid.device).unsqueeze(0).unsqueeze(0) + # xid2: (B, T, num_books) -> (B*num_books, T) + xid2 = xid2.transpose(1, 2).contiguous().view(-1, T) + + # get quantized vector (xq) of shape (B, T, code_dim, num_books) + # codebook: (num_books, num_classes, code_dim) -> (num_books*num_classes, code_dim) + xq = F.embedding(xid2.view(-1), self.codebooks.view(-1, self.code_dim)).view( + B, T, self.code_dim, self.num_books + ) + + if not self.time_ahead: + # (B, T, D) -> (B, D, T) + xq = xq.transpose(1, 2) + + if self.num_books == 1 and self.squeeze_single: + xq = xq.squeeze(-1) + xid = xid.squeeze(-1) + + return xq, xid diff --git a/nemo/collections/asr/parts/mixins/transcription.py b/nemo/collections/asr/parts/mixins/transcription.py index c7195f9ba12c..104e6bff81af 100644 --- a/nemo/collections/asr/parts/mixins/transcription.py +++ b/nemo/collections/asr/parts/mixins/transcription.py @@ -30,6 +30,7 @@ from nemo.collections.asr.parts.preprocessing.perturb import process_augmentations from nemo.collections.asr.parts.preprocessing.segment import AudioSegment, ChannelSelectorType from nemo.collections.asr.parts.utils import manifest_utils +from nemo.collections.common.data.utils import move_data_to_device from nemo.utils import logging, logging_mode TranscriptionReturnType = Union[List[str], List['Hypothesis'], Tuple[List[str]], Tuple[List['Hypothesis']]] @@ -68,25 +69,6 @@ class TranscribeConfig: _internal: Optional[InternalTranscribeConfig] = None -def move_to_device(batch, device, non_blocking=False): - """ - Recursively move all tensors in `batch` to `device`. - Supports tensors, lists, tuples, dictionaries, and dataclasses. - """ - if isinstance(batch, torch.Tensor): - return batch.to(device, non_blocking=non_blocking) - elif isinstance(batch, (list, tuple)): - return type(batch)(move_to_device(x, device, non_blocking) for x in batch) - elif isinstance(batch, dict): - return {k: move_to_device(v, device, non_blocking) for k, v in batch.items()} - elif is_dataclass(batch): - return type(batch)( - **{field.name: move_to_device(getattr(batch, field.name), device, non_blocking) for field in fields(batch)} - ) - else: - return batch # do nothing if not supported type - - def get_value_from_transcription_config(trcfg, key, default): """ Utility function to get a value from the transcription config. 
@@ -386,7 +368,7 @@ def transcribe_generator(self, audio, override_config: Optional[TranscribeConfig for test_batch in tqdm(dataloader, desc="Transcribing", disable=not verbose): # Move batch to device - test_batch = move_to_device(test_batch, transcribe_cfg._internal.device) + test_batch = move_data_to_device(test_batch, transcribe_cfg._internal.device) # Run forward pass model_outputs = self._transcribe_forward(test_batch, transcribe_cfg) processed_outputs = self._transcribe_output_processing(model_outputs, transcribe_cfg) diff --git a/nemo/collections/asr/parts/preprocessing/perturb.py b/nemo/collections/asr/parts/preprocessing/perturb.py index 2108da010c52..a63a051d3e0b 100644 --- a/nemo/collections/asr/parts/preprocessing/perturb.py +++ b/nemo/collections/asr/parts/preprocessing/perturb.py @@ -160,9 +160,13 @@ def perturb(self, data): return new_sr = int(self._sr * speed_rate) - data._samples = librosa.core.resample( - data._samples, orig_sr=self._sr, target_sr=new_sr, res_type=self._res_type - ) + try: + data._samples = librosa.core.resample( + data._samples, orig_sr=self._sr, target_sr=new_sr, res_type=self._res_type + ) + except Exception as e: + logging.warning(f"Failed to resample audio from {self._sr} to {new_sr}. Skipping augmentation. Error: {e}") + return class TimeStretchPerturbation(Perturbation): @@ -371,7 +375,10 @@ def __init__( def perturb(self, data): impulse = read_one_audiosegment( - self._manifest, data.sample_rate, tarred_audio=self._tarred_audio, audio_dataset=self._data_iterator, + self._manifest, + data.sample_rate, + tarred_audio=self._tarred_audio, + audio_dataset=self._data_iterator, ) # normalize if necessary @@ -385,6 +392,10 @@ def perturb(self, data): # len of input data samples len_data = len(data._samples) + if max(abs(data._samples)) == 0: + logging.warning("Zero audio input found, skipping impulse perturbation.") + return + # convolve with the full impulse response data._samples = signal.fftconvolve(data._samples, impulse_norm, "full") @@ -397,6 +408,10 @@ def perturb(self, data): # trim to match the input data length data._samples = data._samples[:len_data] + if max(abs(data._samples)) == 0: + logging.warning("Zero audio input found after impulse perturbation.") + return + # normalize data samples to [-1,1] after rir convolution to avoid nans with fp16 training data._samples = data._samples / max(abs(data._samples)) @@ -493,7 +508,10 @@ def perturb(self, data, ref_mic=0): ref_mic (int): reference mic index for scaling multi-channel audios """ noise = read_one_audiosegment( - self._manifest, data.sample_rate, tarred_audio=self._tarred_audio, audio_dataset=self._data_iterator, + self._manifest, + data.sample_rate, + tarred_audio=self._tarred_audio, + audio_dataset=self._data_iterator, ) self.perturb_with_input_noise(data, noise, ref_mic=ref_mic) @@ -516,13 +534,31 @@ def perturb_with_input_noise(self, data, noise, data_rms=None, ref_mic=0): ) snr_db = random.uniform(self._min_snr_db, self._max_snr_db) + + if data.is_empty(): + logging.warning( + f"Empty audio segment found for {data.audio_file} with offset {data.offset} and duration {data.duration}." + ) + if data_rms is None: data_rms = data.rms_db + if noise.is_empty(): + logging.warning( + f"Empty noise segment found for {noise.audio_file} with offset {noise.offset} and duration {noise.duration}." + ) + noise_rms = -float("inf") + else: + noise_rms = noise.rms_db + + if data.is_empty() and noise.is_empty(): + logging.warning("Both data and noise segments are empty. 
Skipping perturbation.") + return + if data.num_channels > 1: - noise_gain_db = data_rms[ref_mic] - noise.rms_db[ref_mic] - snr_db + noise_gain_db = data_rms[ref_mic] - noise_rms[ref_mic] - snr_db else: - noise_gain_db = data_rms - noise.rms_db - snr_db + noise_gain_db = data_rms - noise_rms - snr_db noise_gain_db = min(noise_gain_db, self._max_gain_db) # calculate noise segment to use @@ -736,7 +772,7 @@ def norm_audio_to_db(self, x, norm_to_db): x (numpy array): input audio signal norm_to_db (float): the db to normalise to """ - rms = (x ** 2).mean(axis=0) ** 0.5 + rms = (x**2).mean(axis=0) ** 0.5 rms = np.where(np.isclose(rms, 0), self._epsilon, rms) scalar = 10 ** (norm_to_db / 20.0) / rms return x * scalar @@ -828,33 +864,33 @@ def perturb(self, data): class RirAndNoisePerturbation(Perturbation): """ - RIR augmentation with additive foreground and background noise. - In this implementation audio data is augmented by first convolving the audio with a Room Impulse Response - and then adding foreground noise and background noise at various SNRs. RIR, foreground and background noises - should either be supplied with a manifest file or as tarred audio files (faster). + RIR augmentation with additive foreground and background noise. + In this implementation audio data is augmented by first convolving the audio with a Room Impulse Response + and then adding foreground noise and background noise at various SNRs. RIR, foreground and background noises + should either be supplied with a manifest file or as tarred audio files (faster). - Different sets of noise audio files based on the original sampling rate of the noise. This is useful while - training a mixed sample rate model. For example, when training a mixed model with 8 kHz and 16 kHz audio with a - target sampling rate of 16 kHz, one would want to augment 8 kHz data with 8 kHz noise rather than 16 kHz noise. + Different sets of noise audio files based on the original sampling rate of the noise. This is useful while + training a mixed sample rate model. For example, when training a mixed model with 8 kHz and 16 kHz audio with a + target sampling rate of 16 kHz, one would want to augment 8 kHz data with 8 kHz noise rather than 16 kHz noise. - Args: - rir_manifest_path: Manifest file for RIRs - rir_tar_filepaths: Tar files, if RIR audio files are tarred - rir_prob: Probability of applying a RIR - noise_manifest_paths: Foreground noise manifest path - min_snr_db: Min SNR for foreground noise - max_snr_db: Max SNR for background noise, - noise_tar_filepaths: Tar files, if noise files are tarred - apply_noise_rir: Whether to convolve foreground noise with a a random RIR - orig_sample_rate: Original sampling rate of foreground noise audio - max_additions: Max number of times foreground noise is added to an utterance, - max_duration: Max duration of foreground noise - bg_noise_manifest_paths: Background noise manifest path - bg_min_snr_db: Min SNR for background noise - bg_max_snr_db: Max SNR for background noise - bg_noise_tar_filepaths: Tar files, if noise files are tarred - bg_orig_sample_rate: Original sampling rate of background noise audio - rng: Random seed. 
Default is None + Args: + rir_manifest_path: Manifest file for RIRs + rir_tar_filepaths: Tar files, if RIR audio files are tarred + rir_prob: Probability of applying a RIR + noise_manifest_paths: Foreground noise manifest path + min_snr_db: Min SNR for foreground noise + max_snr_db: Max SNR for background noise, + noise_tar_filepaths: Tar files, if noise files are tarred + apply_noise_rir: Whether to convolve foreground noise with a a random RIR + orig_sample_rate: Original sampling rate of foreground noise audio + max_additions: Max number of times foreground noise is added to an utterance, + max_duration: Max duration of foreground noise + bg_noise_manifest_paths: Background noise manifest path + bg_min_snr_db: Min SNR for background noise + bg_max_snr_db: Max SNR for background noise + bg_noise_tar_filepaths: Tar files, if noise files are tarred + bg_orig_sample_rate: Original sampling rate of background noise audio + rng: Random seed. Default is None """ @@ -959,12 +995,12 @@ def perturb(self, data): class TranscodePerturbation(Perturbation): """ - Audio codec augmentation. This implementation uses sox to transcode audio with low rate audio codecs, - so users need to make sure that the installed sox version supports the codecs used here (G711 and amr-nb). + Audio codec augmentation. This implementation uses sox to transcode audio with low rate audio codecs, + so users need to make sure that the installed sox version supports the codecs used here (G711 and amr-nb). - Args: - codecs (List[str]):A list of codecs to be trancoded to. Default is None. - rng (int): Random seed. Default is None. + Args: + codecs (List[str]):A list of codecs to be trancoded to. Default is None. + rng (int): Random seed. Default is None. """ def __init__(self, codecs=None, rng=None): @@ -1019,9 +1055,9 @@ def perturb(self, data): class RandomSegmentPerturbation(Perturbation): """ - Returns a random segment from input of duration "duration_sec". + Returns a random segment from input of duration "duration_sec". If duration_sec > input audio length, pad_to_duration determines the outcome. - + RandomSegmentPerturbation is intended for self-supervised learning. Not for supervised, as extracting corresponding text is not facilitated. @@ -1030,20 +1066,29 @@ class RandomSegmentPerturbation(Perturbation): duration_sec (float): duration of the segment to be extracted pad_to_duration (bool): zero pad if length of input audio < duration_sec rng: Random seed. Default is None + min_rms_db: Minimum RMS db value for the perturbed audio. Default is None + max_trials: Maximum number of trials to find a segment with RMS db > min_rms_db. Default is 10 + verbose: If True, logs a warning if RMS db < min_rms_db after max_trials. 
Default is False """ - def __init__(self, duration_sec=32.0, pad_to_duration=False, rng=None): + def __init__( + self, duration_sec=32.0, pad_to_duration=False, rng=None, min_rms_db=None, max_trials=10, verbose=False + ): if duration_sec <= 0: raise ValueError("duration_sec should be > 0") self._duration_sec = duration_sec self._pad_to_duration = pad_to_duration + self._min_rms_db = min_rms_db + self._max_trials = max_trials + self._verbose = verbose random.seed(rng) if rng else None - def perturb(self, data): + def perturb(self, data: AudioSegment): if self._duration_sec > data.duration: if not self._pad_to_duration: - raise ValueError(f"audio length < {self._duration_sec} sec and pad_to_duration is set to False") + # don't do anything if pad_to_duration is False + return start_time = 0.0 pad_size = self._duration_sec * data.sample_rate - data.num_samples data.pad(pad_size=pad_size) @@ -1051,6 +1096,22 @@ def perturb(self, data): start_time = random.uniform(0.0, data.duration - self._duration_sec) end_time = start_time + self._duration_sec + new_data = copy.deepcopy(data) + new_data.subsegment(start_time=start_time, end_time=end_time) + if self._min_rms_db is not None: + rms_db = new_data.rms_db if new_data.num_channels == 1 else min(new_data.rms_db) + trial = 0 + while rms_db < self._min_rms_db and trial < self._max_trials: + start_time = random.uniform(0.0, data.duration - self._duration_sec) + end_time = start_time + self._duration_sec + new_data = copy.deepcopy(data) + new_data.subsegment(start_time=start_time, end_time=end_time) + rms_db = new_data.rms_db if new_data.num_channels == 1 else min(new_data.rms_db) + trial += 1 + if self._verbose and trial == self._max_trials and rms_db < self._min_rms_db: + logging.warning( + f"Could not find a segment with RMS db > {self._min_rms_db} after {self._max_trials} trials." + ) data.subsegment(start_time=start_time, end_time=end_time) @@ -1085,14 +1146,14 @@ def __init__(self, perturbations=None, rng=None): self._pipeline = perturbations if perturbations is not None else [] def perturb(self, segment): - for (prob, p) in self._pipeline: + for prob, p in self._pipeline: if random.random() < prob: p.perturb(segment) return def max_augmentation_length(self, length): newlen = length - for (prob, p) in self._pipeline: + for prob, p in self._pipeline: newlen = p.max_augmentation_length(newlen) return newlen @@ -1238,20 +1299,20 @@ class CustomPerturbation(perturb.Perturbation): class AugmentationDataset(IterableDataset): """ - A class that loads tarred audio files and cycles over the files in the dataset. - Accepts a single comma-separated JSON manifest file (in the same style as for the AudioToCharDataset/AudioToBPEDataset), - as well as the path(s) to the tarball(s) containing the wav files. Each line of the manifest should - contain the information for one audio file, including at least the transcript and name of the audio - file within the tarball. - Valid formats for the audio_tar_filepaths argument include: - (1) a single string that can be brace-expanded, e.g. 'path/to/audio.tar' or 'path/to/audio_{1..100}.tar.gz', or - (2) a list of file paths that will not be brace-expanded, e.g. ['audio_1.tar', 'audio_2.tar', ...]. - Note: For brace expansion in (1), there may be cases where `{x..y}` syntax cannot be used due to shell interference. - This occurs most commonly inside SLURM scripts. Therefore we provide a few equivalent replacements. - Supported opening braces - { <=> (, [, < and the special tag _OP_. 
- Supported closing braces - } <=> ), ], > and the special tag _CL_. - For SLURM based tasks, we suggest the use of the special tags for ease of use. - See the WebDataset documentation for more information about accepted data and input formats. + A class that loads tarred audio files and cycles over the files in the dataset. + Accepts a single comma-separated JSON manifest file (in the same style as for the AudioToCharDataset/AudioToBPEDataset), + as well as the path(s) to the tarball(s) containing the wav files. Each line of the manifest should + contain the information for one audio file, including at least the transcript and name of the audio + file within the tarball. + Valid formats for the audio_tar_filepaths argument include: + (1) a single string that can be brace-expanded, e.g. 'path/to/audio.tar' or 'path/to/audio_{1..100}.tar.gz', or + (2) a list of file paths that will not be brace-expanded, e.g. ['audio_1.tar', 'audio_2.tar', ...]. + Note: For brace expansion in (1), there may be cases where `{x..y}` syntax cannot be used due to shell interference. + This occurs most commonly inside SLURM scripts. Therefore we provide a few equivalent replacements. + Supported opening braces - { <=> (, [, < and the special tag _OP_. + Supported closing braces - } <=> ), ], > and the special tag _CL_. + For SLURM based tasks, we suggest the use of the special tags for ease of use. + See the WebDataset documentation for more information about accepted data and input formats. """ def __init__( @@ -1287,8 +1348,7 @@ def __len__(self): return len(self._manifest) def _loop_offsets(self, iterator): - """This function is used to iterate through utterances with different offsets for each file. - """ + """This function is used to iterate through utterances with different offsets for each file.""" class TarredAudioLoopOffsets: def __init__(self, collection): diff --git a/nemo/collections/asr/parts/preprocessing/segment.py b/nemo/collections/asr/parts/preprocessing/segment.py index 310e76cfd0b0..aceab6637006 100644 --- a/nemo/collections/asr/parts/preprocessing/segment.py +++ b/nemo/collections/asr/parts/preprocessing/segment.py @@ -36,7 +36,7 @@ import math import os import random -from typing import Iterable, Optional, Union +from typing import Iterable, List, Optional, Union import librosa import numpy as np @@ -171,6 +171,9 @@ def __init__( channel_selector=None, normalize_db: Optional[float] = None, ref_channel: Optional[int] = None, + audio_file: Optional[Union[str, List[str]]] = None, + offset: Optional[float] = None, + duration: Optional[float] = None, ): """Create audio segment from samples. Samples are convert float32 internally, with int scaled to [-1, 1]. 
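
Two asides on the perturbation and dataset changes above. First, the additive-noise logic (including the new empty-segment guards) reduces to standard SNR bookkeeping in dB; here is a single-channel sketch, with the RMS estimate and epsilon as assumptions rather than the exact AudioSegment implementation:

import numpy as np

def mix_at_snr(speech: np.ndarray, noise: np.ndarray, snr_db: float) -> np.ndarray:
    eps = 1e-12
    speech_rms_db = 10.0 * np.log10(np.mean(speech ** 2) + eps)
    noise_rms_db = 10.0 * np.log10(np.mean(noise ** 2) + eps)
    # gain (in dB) applied to the noise so the speech-to-noise power ratio equals snr_db
    noise_gain_db = speech_rms_db - noise_rms_db - snr_db
    out = speech.copy()
    n = min(len(speech), len(noise))
    out[:n] += noise[:n] * (10.0 ** (noise_gain_db / 20.0))
    return out

Second, the _OP_/_CL_ tags in the AugmentationDataset docstring are just shell-safe stand-ins for braces; mapping them back before brace expansion is a one-liner (the braceexpand package used here is what WebDataset-style loaders typically rely on):

from braceexpand import braceexpand

pattern = "path/to/audio__OP_1..4_CL_.tar".replace("_OP_", "{").replace("_CL_", "}")
print(list(braceexpand(pattern)))
# ['path/to/audio_1.tar', 'path/to/audio_2.tar', 'path/to/audio_3.tar', 'path/to/audio_4.tar']
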
@@ -207,7 +210,9 @@ def __init__( self._orig_sr = orig_sr if orig_sr is not None else sample_rate self._ref_channel = ref_channel self._normalize_db = normalize_db - + self._audio_file = audio_file + self._offset = offset + self._duration = duration if normalize_db is not None: self.normalize_db(normalize_db, ref_channel) @@ -349,7 +354,7 @@ def from_file( samples = Audio.from_file(audio_file, codec=ffmpeg_codecs.get(os.path.splitext(audio_file)[-1])) sample_rate = samples.frame_rate num_channels = samples.channels - if offset > 0: + if offset is not None and offset > 0: # pydub does things in milliseconds seconds = offset * 1000 samples = samples[int(seconds) :] @@ -380,6 +385,9 @@ def from_file( channel_selector=channel_selector, normalize_db=normalize_db, ref_channel=ref_channel, + audio_file=audio_file, + offset=offset, + duration=duration, ) @classmethod @@ -419,7 +427,7 @@ def from_file_list( for a_file in audio_file_list: # Load audio from the current file a_segment = cls.from_file( - a_file, + audio_file=a_file, target_sr=target_sr, int_values=int_values, offset=offset, @@ -465,6 +473,7 @@ def from_file_list( target_sr=target_sr, trim=trim, channel_selector=channel_selector, + audio_file=audio_file_list, *args, **kwargs, ) @@ -572,6 +581,18 @@ def rms_db(self): def orig_sr(self): return self._orig_sr + @property + def offset(self): + return float(self._offset) if self._offset is not None else None + + @property + def audio_file(self): + return str(self._audio_file) if self._audio_file is not None else None + + def is_empty(self): + mean_square = np.sum(np.mean(self._samples**2, axis=0)) + return self.num_samples == 0 or mean_square == 0 + def gain_db(self, gain): self._samples *= 10.0 ** (gain / 20.0) diff --git a/nemo/collections/asr/parts/submodules/ssl_quantizers.py b/nemo/collections/asr/parts/submodules/ssl_quantizers.py index 944589e3f2cc..26e69fa6d087 100644 --- a/nemo/collections/asr/parts/submodules/ssl_quantizers.py +++ b/nemo/collections/asr/parts/submodules/ssl_quantizers.py @@ -94,7 +94,7 @@ def block(input_dim, output_dim): self.codebook_indices = None def set_num_updates(self, num_updates): - self.curr_temp = max(self.max_temp * self.temp_decay ** num_updates, self.min_temp) + self.curr_temp = max(self.max_temp * self.temp_decay**num_updates, self.min_temp) def get_codebook_indices(self): if self.codebook_indices is None: @@ -105,7 +105,7 @@ def get_codebook_indices(self): self.codebook_indices = torch.tensor(inds, dtype=torch.long, device=self.vars.device).flatten() if not self.combine_groups: - self.codebook_indices = self.codebook_indices.view(self.num_vars ** self.groups, -1) + self.codebook_indices = self.codebook_indices.view(self.num_vars**self.groups, -1) for b in range(1, self.groups): self.codebook_indices[:, b] += self.num_vars * b self.codebook_indices = self.codebook_indices.flatten() @@ -124,16 +124,14 @@ def sample_from_codebook(self, b, n): @property def input_types(self): - """Returns definitions of module input ports. - """ + """Returns definitions of module input ports.""" if self.time_first: return {"x": NeuralType(('B', 'T', 'D'), EncodedRepresentation())} return {"x": NeuralType(('B', 'D', 'T'), EncodedRepresentation())} @property def output_types(self): - """Returns definitions of module output ports. 
- """ + """Returns definitions of module output ports.""" if self.time_first: return { "x": NeuralType(('B', 'T', 'D'), EncodedRepresentation()), diff --git a/nemo/collections/common/data/lhotse/cutset.py b/nemo/collections/common/data/lhotse/cutset.py index 7e7fdbc95a61..8e16688a1b32 100644 --- a/nemo/collections/common/data/lhotse/cutset.py +++ b/nemo/collections/common/data/lhotse/cutset.py @@ -421,6 +421,9 @@ def read_nemo_manifest(config, is_tarred: bool) -> CutSet: # Format option 2: # Assume it's [[path1, weight1], [path2, weight2], ...] (while tarred_audio_filepaths remain unchanged). # Note: this option allows to manually set the weights for multiple datasets. + # Format option 3: + # i.e., NeMo concatenated dataset + # Assume it's [path1, path2, ...] (while tarred_audio_filepaths in the same format). logging.info( f"Initializing Lhotse CutSet from multiple tarred NeMo manifest sources with a weighted multiplexer. " f"We found the following sources and weights: " @@ -429,9 +432,14 @@ def read_nemo_manifest(config, is_tarred: bool) -> CutSet: weights = [] tar_paths = config.tarred_audio_filepaths if is_tarred else repeat((None,)) # Create a stream for each dataset. - for manifest_info, (tar_path,) in zip(config.manifest_filepath, tar_paths): + for manifest_info, tar_path in zip(config.manifest_filepath, tar_paths): + if isinstance(tar_path, (list, tuple, ListConfig)): + # if it's in option 1 or 2 + (tar_path,) = tar_path + manifest_path = manifest_info[0] + else: + manifest_path = manifest_info # First, convert manifest_path[+tar_path] to an iterator. - manifest_path = manifest_info[0] if is_tarred and not metadata_only: nemo_iter = LazyNeMoTarredIterator( manifest_path=manifest_path, @@ -441,7 +449,7 @@ def read_nemo_manifest(config, is_tarred: bool) -> CutSet: else: nemo_iter = LazyNeMoIterator(manifest_path, **notar_kwargs, **common_kwargs) # Then, determine the weight or use one provided - if len(manifest_info) == 1: + if isinstance(manifest_info, str) or len(manifest_info) == 1: weight = len(nemo_iter) else: assert ( diff --git a/nemo/collections/common/data/utils.py b/nemo/collections/common/data/utils.py new file mode 100644 index 000000000000..657180974215 --- /dev/null +++ b/nemo/collections/common/data/utils.py @@ -0,0 +1,37 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from dataclasses import fields, is_dataclass +from typing import Any, Union + +import torch + + +def move_data_to_device(inputs: Any, device: Union[str, torch.device], non_blocking: bool = True) -> Any: + """Recursively moves inputs to the specified device""" + if isinstance(inputs, torch.Tensor): + return inputs.to(device, non_blocking=non_blocking) + elif isinstance(inputs, (list, tuple, set)): + return inputs.__class__([move_data_to_device(i, device, non_blocking) for i in inputs]) + elif isinstance(inputs, dict): + return {k: move_data_to_device(v, device, non_blocking) for k, v in inputs.items()} + elif is_dataclass(inputs): + return type(inputs)( + **{ + field.name: move_data_to_device(getattr(inputs, field.name), device, non_blocking) + for field in fields(inputs) + } + ) + else: + return inputs diff --git a/nemo/collections/common/parts/multi_layer_perceptron.py b/nemo/collections/common/parts/multi_layer_perceptron.py index 76c06bf23ea6..1ae4800ab8b9 100644 --- a/nemo/collections/common/parts/multi_layer_perceptron.py +++ b/nemo/collections/common/parts/multi_layer_perceptron.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional import torch @@ -34,6 +35,7 @@ def __init__( num_layers: int = 2, activation: str = 'relu', log_softmax: bool = True, + channel_idx: Optional[int] = None, ): super().__init__() self.layers = 0 @@ -46,16 +48,37 @@ def __init__( setattr(self, f'layer{self.layers}', layer) self.layers += 1 self.log_softmax = log_softmax + self.channel_idx = channel_idx @property def last_linear_layer(self): return getattr(self, f'layer{self.layers - 1}') - def forward(self, hidden_states): - output_states = hidden_states[:] + def forward(self, hidden_states: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor: + """ + Multi-layer perceptron forward function compatible with multiple types of input keyword arguments + """ + if hidden_states is None: + if "audio_signal" in kwargs: + hidden_states = kwargs["audio_signal"] + elif "encoder_output" in kwargs: + hidden_states = kwargs["encoder_output"] + else: + raise ValueError("No input tensor found") + + if self.channel_idx is not None: + # compatible with transformers/conformer output + output_states = hidden_states.transpose(-1, self.channel_idx) + else: + output_states = hidden_states + for i in range(self.layers): output_states = getattr(self, f'layer{i}')(output_states) if self.log_softmax: output_states = torch.log_softmax(output_states, dim=-1) + + if self.channel_idx is not None: + # compatible with transformers/conformer output + output_states = output_states.transpose(-1, self.channel_idx) return output_states diff --git a/nemo/collections/common/parts/preprocessing/collections.py b/nemo/collections/common/parts/preprocessing/collections.py index 0cb81c115d05..b16ac50e4d56 100644 --- a/nemo/collections/common/parts/preprocessing/collections.py +++ b/nemo/collections/common/parts/preprocessing/collections.py @@ -16,7 +16,8 @@ import json import os from itertools import combinations -from typing import Any, Dict, Iterable, List, Optional, Union +from typing import Any, Callable, Dict, Iterable, List, Optional, Union + import numpy as np import pandas as pd @@ -310,7 +311,7 @@ def __init__( class ASRAudioText(AudioText): """`AudioText` collector from asr structured json files.""" - def __init__(self, manifests_files: Union[str, List[str]], *args, **kwargs): + def __init__(self, manifests_files: Union[str, List[str]], 
parse_func: Optional[Callable] = None, *args, **kwargs): """Parse lists of audio files, durations and transcripts texts. Args: @@ -333,8 +334,9 @@ def __init__(self, manifests_files: Union[str, List[str]], *args, **kwargs): [], [], ) + speakers, orig_srs, token_labels, langs = [], [], [], [] - for item in manifest.item_iter(manifests_files): + for item in manifest.item_iter(manifests_files, parse_func=parse_func): ids.append(item['id']) audio_files.append(item['audio_file']) durations.append(item['duration']) diff --git a/nemo/collections/multimodal/speech_llm/models/modular_models.py b/nemo/collections/multimodal/speech_llm/models/modular_models.py index edabbfd82f87..46b2ca3e26fd 100644 --- a/nemo/collections/multimodal/speech_llm/models/modular_models.py +++ b/nemo/collections/multimodal/speech_llm/models/modular_models.py @@ -28,8 +28,8 @@ from pytorch_lightning.utilities import rank_zero_only from nemo.collections.asr.models import ASRModel, EncDecSpeakerLabelModel -from nemo.collections.asr.parts.mixins.transcription import move_to_device from nemo.collections.asr.parts.utils.eval_utils import remove_punctuations +from nemo.collections.common.data.utils import move_data_to_device from nemo.collections.common.metrics import MetricStringToTorchMetric, TextMetricsSet from nemo.collections.multimodal.speech_llm.data.build_dataset import ( build_speechllm_dataloader, @@ -480,7 +480,7 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_ if self.get_attention_mask_from_fusion and 'attention_mask' in required_keys: required_keys.remove('attention_mask') - batch = move_to_device(batch, self.device) + batch = move_data_to_device(batch, self.device) batch = self.get_batch_on_this_context_parallel_rank(batch) if not self.mcore_gpt: diff --git a/nemo/collections/multimodal/speech_llm/models/modular_t5_models.py b/nemo/collections/multimodal/speech_llm/models/modular_t5_models.py index fce31d031abd..79fc0468e819 100644 --- a/nemo/collections/multimodal/speech_llm/models/modular_t5_models.py +++ b/nemo/collections/multimodal/speech_llm/models/modular_t5_models.py @@ -27,7 +27,7 @@ from pytorch_lightning.trainer.trainer import Trainer from nemo.collections.asr.models import ASRModel, SpeechEncDecSelfSupervisedModel -from nemo.collections.asr.parts.mixins.transcription import move_to_device +from nemo.collections.common.data.utils import move_data_to_device from nemo.collections.common.metrics import MetricStringToTorchMetric, TextMetricsSet from nemo.collections.multimodal.speech_llm.data.build_dataset import ( build_speechllm_dataloader, @@ -921,7 +921,7 @@ def inference_step(self, dataloader_iter, mode, dataloader_idx=0): def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: - batch = move_to_device(batch, device=self.device) + batch = move_data_to_device(batch, device=self.device) encoder_input, attention_mask, enc_mask = self.prepare_llm_input(batch) # enc_input = speech and text prompt # dec_input and label = text output label @@ -1329,7 +1329,7 @@ def forward( def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> Any: - batch = move_to_device(batch, device=self.device) + batch = move_data_to_device(batch, device=self.device) encoder_input, _, enc_mask = self.prepare_llm_input(batch) # enc_input = speech prompt # dec_input and label = text prompt and text output label diff --git a/nemo/core/classes/dataset.py b/nemo/core/classes/dataset.py index 789fc0b863d7..3e8652367734 100644 --- 
a/nemo/core/classes/dataset.py +++ b/nemo/core/classes/dataset.py @@ -100,10 +100,6 @@ def collate_fn(self, batch): @dataclass class DatasetConfig: - """ - - """ - # ... batch_size: int = 32 drop_last: bool = False diff --git a/nemo/core/classes/modelPT.py b/nemo/core/classes/modelPT.py index 5a0bbc2bea37..5b8d414ac85b 100644 --- a/nemo/core/classes/modelPT.py +++ b/nemo/core/classes/modelPT.py @@ -1027,7 +1027,6 @@ def on_validation_epoch_end(self) -> Optional[Dict[str, Dict[str, torch.Tensor]] if 'log' in output_dict: self.log_dict(output_dict.pop('log'), on_epoch=True) - # return everything else return output_dict diff --git a/nemo/core/neural_types/elements.py b/nemo/core/neural_types/elements.py index 7e95acebd91f..0aa5fc6b4cad 100644 --- a/nemo/core/neural_types/elements.py +++ b/nemo/core/neural_types/elements.py @@ -55,6 +55,7 @@ 'Length', 'IntType', 'FloatType', + 'BoolType', 'NormalDistributionSamplesType', 'NormalDistributionMeanType', 'NormalDistributionLogVarianceType', @@ -153,8 +154,7 @@ def compare(self, second) -> NeuralTypeComparisonResult: # TODO: Consider moving these files elsewhere class ChannelType(ElementType): - """Element to represent convolutional input/output channel. - """ + """Element to represent convolutional input/output channel.""" def __init__(self): """Dummy init for TorchScript compatibility""" @@ -164,8 +164,7 @@ def __init__(self): class EmbeddedTextType(ChannelType): - """Element to represent output on word/text embedding layers - """ + """Element to represent output on word/text embedding layers""" def __init__(self): """Dummy init for TorchScript compatibility""" @@ -379,7 +378,7 @@ def __init__(self): class Target(ElementType): """ - Type representing an element being a target value. + Type representing an element being a target value. """ def __init__(self): @@ -391,7 +390,7 @@ def __init__(self): class ClassificationTarget(Target): """ - Type representing an element being target value in the classification task, i.e. identifier of a desired class. + Type representing an element being target value in the classification task, i.e. identifier of a desired class. """ def __init__(self): @@ -403,8 +402,8 @@ def __init__(self): class ImageValue(ElementType): """ - Type representing an element/value of a single image channel, - e.g. a single element (R) of RGB image. + Type representing an element/value of a single image channel, + e.g. a single element (R) of RGB image. """ def __init__(self): @@ -416,8 +415,8 @@ def __init__(self): class NormalizedImageValue(ImageValue): """ - Type representing an element/value of a single image channel normalized to <0-1> range, - e.g. a single element (R) of normalized RGB image. + Type representing an element/value of a single image channel normalized to <0-1> range, + e.g. a single element (R) of normalized RGB image. """ def __init__(self): @@ -449,7 +448,7 @@ def __init__(self): class StringLabel(StringType): """ - Type representing an label being a string with class name (e.g. the "hamster" class in CIFAR100). + Type representing an label being a string with class name (e.g. the "hamster" class in CIFAR100). 
""" def __init__(self): @@ -510,8 +509,7 @@ def __init__(self): class ProbabilityDistributionSamplesType(ElementType): - """Element to represent tensors that meant to be sampled from a valid probability distribution - """ + """Element to represent tensors that meant to be sampled from a valid probability distribution""" def __init__(self): """Dummy init for TorchScript compatibility""" @@ -521,8 +519,7 @@ def __init__(self): class NormalDistributionSamplesType(ProbabilityDistributionSamplesType): - """Element to represent tensors that meant to be sampled from a valid normal distribution - """ + """Element to represent tensors that meant to be sampled from a valid normal distribution""" def __init__(self): """Dummy init for TorchScript compatibility""" diff --git a/nemo/utils/model_utils.py b/nemo/utils/model_utils.py index f4eefd39a9ea..24b110de229a 100644 --- a/nemo/utils/model_utils.py +++ b/nemo/utils/model_utils.py @@ -338,13 +338,28 @@ def resolve_validation_dataloaders(model: 'ModelPT'): # using the name of each of the nested dataset model._validation_names = [ds.name for ds in ds_values] else: - model._validation_names = [parse_dataset_as_name(ds) for ds in ds_values] + ds_names = cfg.validation_ds.get('name', []) + if len(ds_names) > 0: + if len(ds_names) != len(ds_values): + raise ValueError( + f"Number of names ({len(ds_names)}) does not match number of datasets ({len(ds_values)}). Got {ds_names} and {ds_values}" + ) + model._validation_names = [parse_dataset_as_name(n) for n in ds_names] + else: + model._validation_names = [parse_dataset_as_name(ds) for ds in ds_values] unique_names_check(name_list=model._validation_names) + return else: model.setup_validation_data(cfg.validation_ds) - model._validation_names = [parse_dataset_as_name(ds_values)] + ds_names = cfg.validation_ds.get('name', None) + if ds_names is not None: + if not isinstance(ds_names, str): + raise ValueError(f"`name` must be a string for single manifest, got {ds_names}") + model._validation_names = [parse_dataset_as_name(ds_names)] + else: + model._validation_names = [parse_dataset_as_name(ds_values)] unique_names_check(name_list=model._validation_names) @@ -417,14 +432,28 @@ def resolve_test_dataloaders(model: 'ModelPT'): # using the name of each of the nested dataset model._test_names = [ds.name for ds in ds_values] else: - model._test_names = [parse_dataset_as_name(ds) for ds in ds_values] + ds_names = cfg.test_ds.get('name', []) + if len(ds_names) > 0: + if len(ds_names) != len(ds_values): + raise ValueError( + f"Number of names ({len(ds_names)}) does not match number of datasets ({len(ds_values)}). Got {ds_names} and {ds_values}" + ) + model._test_names = [parse_dataset_as_name(n) for n in ds_names] + else: + model._test_names = [parse_dataset_as_name(ds) for ds in ds_values] unique_names_check(name_list=model._test_names) return else: model.setup_test_data(cfg.test_ds) - model._test_names = [parse_dataset_as_name(ds_values)] + ds_names = cfg.test_ds.get('name', None) + if ds_names is not None: + if not isinstance(ds_names, str): + raise ValueError(f"`name` must be a string for single manifest, got {ds_names}") + model._test_names = [parse_dataset_as_name(ds_names)] + else: + model._test_names = [parse_dataset_as_name(ds_values)] unique_names_check(name_list=model._test_names) @@ -468,7 +497,7 @@ def convert_model_config_to_dict_config(cfg: Union['DictConfig', 'NemoConfig']) def _convert_config(cfg: 'OmegaConf'): - """ Recursive function convertint the configuration from old hydra format to the new one. 
""" + """Recursive function convertint the configuration from old hydra format to the new one.""" if not _HAS_HYDRA: logging.error("This function requires Hydra/Omegaconf and it was not installed.") exit(1) @@ -671,9 +700,9 @@ def inject_model_parallel_rank(filepath, fsdp_sharded_ckpt=False): def ckpt_to_dir(filepath: Union[str, Path]) -> Path: - """ PTL considers checkpoints as .ckpt files. - This method removes the extension and returns a path - to be used as a directory for distributed checkpoints + """PTL considers checkpoints as .ckpt files. + This method removes the extension and returns a path + to be used as a directory for distributed checkpoints """ filepath = Path(filepath) diff --git a/tests/collections/asr/test_ssl_models.py b/tests/collections/asr/test_ssl_models.py index 932dd5760ac8..50aba8375061 100644 --- a/tests/collections/asr/test_ssl_models.py +++ b/tests/collections/asr/test_ssl_models.py @@ -12,13 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy - import pytest import torch -from omegaconf import DictConfig, ListConfig +from omegaconf import DictConfig -from nemo.collections.asr.models import SpeechEncDecSelfSupervisedModel +from nemo.collections.asr.models import EncDecDenoiseMaskedTokenPredModel, SpeechEncDecSelfSupervisedModel +from nemo.core.classes.common import typecheck @pytest.fixture() @@ -132,6 +131,157 @@ def ssl_model(): return ssl_model +@pytest.fixture() +def denoise_mlm_ssl_model(): + + model_defaults = { + "subsampling_factor": 1, + 'enc_hidden': 32, + 'dec_out': 128, + "sample_rate": 16000, + "num_classes": 32, + "num_books": 1, + "code_dim": 16, + "squeeze_single": False, + "mask_position": "pre_conv", # position to apply masking, before or after conv subsampling, choices in ['pre_conv', 'post_conv'] + } + + preprocessor = { + "_target_": "nemo.collections.asr.modules.AudioToMelSpectrogramPreprocessor", + "sample_rate": model_defaults["sample_rate"], + "normalize": "per_feature", + "window_size": 0.025, + "window_stride": 0.01, + "window": "hann", + "features": 80, + "n_fft": 512, + "log": True, + "frame_splicing": 1, + "dither": 0.00001, + "pad_to": 16, + "pad_value": 0.0, + } + + encoder = { + 'cls': 'nemo.collections.asr.modules.ConvASREncoder', + 'params': { + 'feat_in': preprocessor["features"], + 'activation': 'relu', + 'conv_mask': True, + 'jasper': [ + { + 'filters': model_defaults['enc_hidden'], + 'repeat': 1, + 'kernel': [1], + 'stride': [1], + 'dilation': [1], + 'dropout': 0.0, + 'residual': False, + 'separable': True, + 'se': True, + 'se_context_size': -1, + }, + { + 'filters': model_defaults['enc_hidden'], + 'repeat': 1, + 'kernel': [1], + 'stride': [1], + 'dilation': [1], + 'dropout': 0.0, + 'residual': False, + 'separable': True, + 'se': True, + 'se_context_size': -1, + }, + { + 'filters': model_defaults['enc_hidden'], + 'repeat': 1, + 'kernel': [1], + 'stride': [1], + 'dilation': [1], + 'dropout': 0.0, + 'residual': False, + 'separable': True, + 'se': True, + 'se_context_size': -1, + }, + ], + }, + } + + spec_augment = { + '_target_': 'nemo.collections.asr.modules.SpectrogramAugmentation', + 'freq_masks': 0, + 'time_masks': 0, + 'freq_width': 16, + 'time_width': 0.05, + } + + masking = { + "_target_": "nemo.collections.asr.modules.RandomBlockMasking", + "block_size": 40, # for pre_conv masking, 10ms per frame, 400ms per block with block_size=40 + "mask_prob": 0.01, # for allow_overlap=True, this means the mask prob for each frame; otherwise it means the 
overall masked proportion + "feat_in": preprocessor["features"], + "freeze": True, + "allow_overlap": True, + } + + quantizer = { + "_target_": "nemo.collections.asr.modules.RandomProjectionVectorQuantizer", + "feat_in": preprocessor["features"], + "code_dim": model_defaults["code_dim"], + "num_books": model_defaults["num_books"], + "num_classes": model_defaults["num_classes"], + "dist_fn": "l2", # choices=["l2", "cosine"] + "freeze": True, + "squeeze_single": model_defaults["squeeze_single"], + "combine_time_steps": model_defaults["subsampling_factor"], # conformer sub-sampling ratio + } + + decoder = { + "_target_": "nemo.collections.asr.modules.MultiSoftmaxDecoder", + "feat_in": model_defaults["enc_hidden"], + "num_classes": model_defaults["num_classes"], + "num_decoders": model_defaults["num_books"], + "squeeze_single": model_defaults["squeeze_single"], + "use_bias": True, + } + + loss = { + "_target_": "nemo.collections.asr.losses.MultiMLMLoss", + "combine_time_steps": model_defaults[ + "subsampling_factor" + ], # conformer sub-sampling ratio for 'pre_conv', 1 for 'post_conv' + "mask_threshold": 0.8, + "num_decoders": model_defaults["num_books"], + "squeeze_single": model_defaults["squeeze_single"], + } + + optim = { + "name": "adamw", + "lr": 5.0, + # optimizer arguments + "betas": [0.9, 0.98], + "weight_decay": 1e-3, + } + + model_config = DictConfig( + { + "preprocessor": DictConfig(preprocessor), + "spec_augment": DictConfig(spec_augment), + 'model_defaults': DictConfig(model_defaults), + "masking": DictConfig(masking), + "quantizer": DictConfig(quantizer), + "encoder": DictConfig(encoder), + "decoder": DictConfig(decoder), + "loss": DictConfig(loss), + "optim": DictConfig(optim), + } + ) + ssl_model = EncDecDenoiseMaskedTokenPredModel(cfg=model_config) + return ssl_model + + class TestSSLModel: @pytest.mark.unit def test_constructor(self, ssl_model): @@ -221,3 +371,58 @@ def test_contr_mlm_multi(self, ssl_model): loss_value, loss_val_dict = ssl_model.decoder_loss_step(spectrograms, spec_masks, encoded, encoded_len) assert len(loss_val_dict) == 4 + + +class TestDenoiseMLMSSLModel: + @pytest.mark.unit + def test_forward(self, denoise_mlm_ssl_model): + input_signal = torch.randn(size=(4, 64000)) + input_length = torch.randint(low=48000, high=64000, size=[4]) + noise = 0.1 * torch.ones_like(input_signal) + noisy_input_signal = input_signal + noise + noisy_input_length = input_length + with torch.no_grad(): + with typecheck.disable_checks(): + log_probs, encoded_len, masks, tokens = denoise_mlm_ssl_model.forward( + input_signal=input_signal, + input_signal_length=input_length, + noisy_input_signal=noisy_input_signal, + noisy_input_signal_length=noisy_input_length, + ) + + assert log_probs.size(0) == 4 + assert log_probs.size(2) == denoise_mlm_ssl_model.cfg.model_defaults.num_classes + assert encoded_len.size(0) == 4 + assert masks.size(0) == 4 + assert tokens.size(0) == 4 + assert masks.sum() == 0.0 # no mask should be applied to the input by default + + @pytest.mark.unit + def test_forward_masked(self, denoise_mlm_ssl_model: EncDecDenoiseMaskedTokenPredModel): + input_signal = torch.randn(size=(4, 64000)) + input_length = torch.randint(low=48000, high=64000, size=[4]) + noise = 0.1 * torch.ones_like(input_signal) + noisy_input_signal = input_signal + noise + noisy_input_length = input_length + + with torch.no_grad(): + with typecheck.disable_checks(): + log_probs, encoded_len, masks, tokens = denoise_mlm_ssl_model.forward( + input_signal=input_signal, + 
input_signal_length=input_length, + noisy_input_signal=noisy_input_signal, + noisy_input_signal_length=noisy_input_length, + apply_mask=True, + ) + + loss_value = denoise_mlm_ssl_model.loss( + masks=masks, decoder_outputs=log_probs, targets=tokens, decoder_lengths=encoded_len + ) + + assert log_probs.size(0) == 4 + assert log_probs.size(2) == denoise_mlm_ssl_model.cfg.model_defaults.num_classes + assert encoded_len.size(0) == 4 + assert masks.size(0) == 4 + assert tokens.size(0) == 4 + assert masks.sum() > 0.0 # mask should be applied to the input + assert not torch.isnan(loss_value) diff --git a/tests/collections/asr/utils/test_transcription_move_to_device.py b/tests/collections/common/test_data_utils.py similarity index 82% rename from tests/collections/asr/utils/test_transcription_move_to_device.py rename to tests/collections/common/test_data_utils.py index 6e95e66c5b26..4e4b8d519c1f 100644 --- a/tests/collections/asr/utils/test_transcription_move_to_device.py +++ b/tests/collections/common/test_data_utils.py @@ -3,7 +3,7 @@ import pytest import torch -from nemo.collections.asr.parts.mixins.transcription import move_to_device +from nemo.collections.common.data.utils import move_data_to_device @dataclass @@ -23,8 +23,8 @@ class _Batch: "not a tensor", ], ) -def test_transcription_move_to_device(batch): - cuda_batch = move_to_device(batch, device="cuda") +def test_move_data_to_device(batch): + cuda_batch = move_data_to_device(batch, device="cuda") assert type(batch) == type(cuda_batch) if isinstance(batch, _Batch): assert cuda_batch.data.is_cuda
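
Closing notes on the new pieces exercised by the tests above. The relocated move_data_to_device helper also handles nested containers and dataclasses; a CPU-only usage sketch (so it runs without a GPU):

from dataclasses import dataclass

import torch

from nemo.collections.common.data.utils import move_data_to_device


@dataclass
class Batch:
    feats: torch.Tensor
    meta: dict


batch = Batch(feats=torch.zeros(2, 3), meta={"lens": torch.tensor([3, 2]), "ids": ["a", "b"]})
moved = move_data_to_device(batch, device="cpu")  # tensors are moved with .to(); the strings pass through unchanged

And a quick sanity check of the masking arithmetic assumed in the denoise-SSL fixture (values taken from the config comments above):

window_stride_sec = 0.01               # preprocessor frame shift: 10 ms per frame
block_size = 40                        # frames per masked block with pre_conv masking
print(block_size * window_stride_sec)  # 0.4 s of audio masked per block, matching the inline comment
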