From 70eb54810fdf479ecabe62fc8e9955e257bc7df5 Mon Sep 17 00:00:00 2001
From: "He Huang (Steve)" <105218074+stevehuang52@users.noreply.github.com>
Date: Thu, 24 Oct 2024 14:49:58 -0400
Subject: [PATCH] [WIP] Add docs for NEST SSL (#10804)

* add docs

Signed-off-by: stevehuang52

* update doc and fix missing param

Signed-off-by: stevehuang52

---------

Signed-off-by: stevehuang52
---
 docs/source/asr/ssl/api.rst                              | 4 ++++
 docs/source/asr/ssl/intro.rst                            | 4 ++++
 examples/asr/conf/ssl/nest/nest_fast-conformer.yaml      | 4 ++--
 examples/asr/speech_pretraining/README.md                | 8 ++++++++
 .../asr/speech_pretraining/masked_token_pred_pretrain.py | 2 ++
 nemo/collections/asr/models/ssl_models.py                | 9 +++++++--
 6 files changed, 27 insertions(+), 4 deletions(-)

diff --git a/docs/source/asr/ssl/api.rst b/docs/source/asr/ssl/api.rst
index 77614e9ad5e3..16b21bdfb12e 100644
--- a/docs/source/asr/ssl/api.rst
+++ b/docs/source/asr/ssl/api.rst
@@ -4,6 +4,10 @@ NeMo SSL collection API
 Model Classes
 -------------
 
+.. autoclass:: nemo.collections.asr.models.EncDecDenoiseMaskedTokenPredModel
+    :show-inheritance:
+    :members:
+
 .. autoclass:: nemo.collections.asr.models.SpeechEncDecSelfSupervisedModel
     :show-inheritance:
     :members:
diff --git a/docs/source/asr/ssl/intro.rst b/docs/source/asr/ssl/intro.rst
index 76a3a75dcf37..89002711be97 100644
--- a/docs/source/asr/ssl/intro.rst
+++ b/docs/source/asr/ssl/intro.rst
@@ -19,6 +19,10 @@ encoder module of neural ASR models. Here too, majority of SSL effort is focused
 While it is common that AM is the focus of SSL in ASR, it can also be utilized in improving other
 parts of ASR models (e.g., predictor module in transducer based ASR models).
 
+In NeMo, we provide two types of SSL models: `Wav2Vec-BERT <https://arxiv.org/abs/2108.06209>`_ and `NEST <https://arxiv.org/abs/2408.13106>`_.
+The training scripts for these models can be found in `examples/asr/speech_pretraining <https://github.com/NVIDIA/NeMo/tree/main/examples/asr/speech_pretraining>`_.
+
+
 The full documentation tree is as follows:
 
 .. toctree::
diff --git a/examples/asr/conf/ssl/nest/nest_fast-conformer.yaml b/examples/asr/conf/ssl/nest/nest_fast-conformer.yaml
index 054c66830d65..2124e6e6f7f1 100644
--- a/examples/asr/conf/ssl/nest/nest_fast-conformer.yaml
+++ b/examples/asr/conf/ssl/nest/nest_fast-conformer.yaml
@@ -28,8 +28,8 @@ model:
   mask_position: pre_conv # position to apply masking, before or after conv subsampling, choices in ['pre_conv', 'post_conv']
 
   train_ds:
-    manifest_filepath: ???
-    noise_manifest: null
+    manifest_filepath: ??? # path to training manifest, can be a string or list of strings
+    noise_manifest: ??? # the manifest for noise data, can be a string or list of strings
     sample_rate: ${model.sample_rate}
     batch_size: 8 # you may increase batch_size if your memory allows
     shuffle: true
diff --git a/examples/asr/speech_pretraining/README.md b/examples/asr/speech_pretraining/README.md
index aeafcf69292b..777ea0602789 100644
--- a/examples/asr/speech_pretraining/README.md
+++ b/examples/asr/speech_pretraining/README.md
@@ -5,3 +5,11 @@ This directory contains example scripts to self-supervised speech models.
 There are two main types of supported self-supervised learning methods:
 - [Wav2vec-BERT](https://arxiv.org/abs/2108.06209): `speech_pre_training.py`
 - [NEST](https://arxiv.org/abs/2408.13106): `masked_token_pred_pretrain.py`
+  - For downstream tasks that use NEST as a multi-layer feature extractor, please refer to `./downstream/speech_classification_mfa_train.py`
+
+
+For example usage, please refer to the corresponding YAML configs:
+- Wav2vec-BERT: `examples/asr/conf/ssl/fastconformer/fast-conformer.yaml`
+- NEST: `examples/asr/conf/ssl/nest/nest_fast-conformer.yaml`
+
+
diff --git a/examples/asr/speech_pretraining/masked_token_pred_pretrain.py b/examples/asr/speech_pretraining/masked_token_pred_pretrain.py
index 83729dfd9d67..1ea88d696643 100644
--- a/examples/asr/speech_pretraining/masked_token_pred_pretrain.py
+++ b/examples/asr/speech_pretraining/masked_token_pred_pretrain.py
@@ -28,7 +28,9 @@
 python pretrain_masked_token_pred.py \
     # (Optional: --config-path=<path to dir of configs> --config-name=<name of config without .yaml>) \
     model.train_ds.manifest_filepath=<path to train manifest> \
+    model.train_ds.noise_manifest=<path to noise manifest> \
     model.validation_ds.manifest_filepath=<path to validation manifest> \
+    model.validation_ds.noise_manifest=<path to noise manifest> \
     trainer.devices=-1 \
     trainer.accelerator="gpu" \
     strategy="ddp" \
diff --git a/nemo/collections/asr/models/ssl_models.py b/nemo/collections/asr/models/ssl_models.py
index 5424ed79e751..633a00d73f5e 100644
--- a/nemo/collections/asr/models/ssl_models.py
+++ b/nemo/collections/asr/models/ssl_models.py
@@ -996,7 +996,12 @@ def training_step(self, batch: ssl_dataset.AudioNoiseBatch, batch_idx: int):
         return {'loss': loss_value, 'log': tensorboard_logs}
 
     def inference_pass(
-        self, batch: ssl_dataset.AudioNoiseBatch, batch_idx: int, dataloader_idx: int = 0, mode: str = 'val'
+        self,
+        batch: ssl_dataset.AudioNoiseBatch,
+        batch_idx: int,
+        dataloader_idx: int = 0,
+        mode: str = 'val',
+        apply_mask: bool = True,
     ):
         log_probs, encoded_len, masks, tokens = self.forward(
             input_signal=batch.audio,
@@ -1005,7 +1010,7 @@
             noise_signal_length=batch.noise_len,
             noisy_input_signal=batch.noisy_audio,
             noisy_input_signal_length=batch.noisy_audio_len,
-            apply_mask=True,
+            apply_mask=apply_mask,
         )
 
         loss_value = self.loss(masks=masks, decoder_outputs=log_probs, targets=tokens, decoder_lengths=encoded_len)
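
Note on the config change above: `manifest_filepath` and `noise_manifest` are now both mandatory OmegaConf values (`???`), and per the new comments each accepts a single path or a list of paths. A minimal sketch (not part of the patch) of setting them programmatically; the manifest filenames below are hypothetical placeholders:

```python
# Sketch: override the two mandatory train_ds fields from the NEST config.
# Assumes the repo layout referenced in the diff; manifest names are made up.
from omegaconf import OmegaConf

cfg = OmegaConf.load("examples/asr/conf/ssl/nest/nest_fast-conformer.yaml")

# Both fields accept a string or a list of strings, per the added comments.
cfg.model.train_ds.manifest_filepath = ["train_clean.json", "train_other.json"]
cfg.model.train_ds.noise_manifest = "noise_manifest.json"

print(OmegaConf.to_yaml(cfg.model.train_ds))
```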
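Note on the `ssl_models.py` change: `inference_pass` now threads `apply_mask` through to `self.forward`, so callers can disable masking instead of always getting the previously hard-coded `apply_mask=True`. A hedged sketch of an unmasked pass, assuming `model` is an `EncDecDenoiseMaskedTokenPredModel` and `batch` is an `ssl_dataset.AudioNoiseBatch` from its dataloader; the `input_signal_length` and `noise_signal` keyword names are assumed from code the hunk elides:

```python
import torch

def unmasked_inference(model, batch):
    """Run the forward pass with masking disabled via the new apply_mask flag."""
    with torch.no_grad():
        # batch.audio_len and batch.noise are assumed field names; the other
        # fields appear verbatim in the hunk above.
        return model.forward(
            input_signal=batch.audio,
            input_signal_length=batch.audio_len,
            noise_signal=batch.noise,
            noise_signal_length=batch.noise_len,
            noisy_input_signal=batch.noisy_audio,
            noisy_input_signal_length=batch.noisy_audio_len,
            apply_mask=False,  # new flag; defaults to True, preserving old behavior
        )
```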