Commit

[WIP] Add docs for NEST SSL (NVIDIA#10804)
* add docs

Signed-off-by: stevehuang52 <[email protected]>

* update doc and fix missing param

Signed-off-by: stevehuang52 <[email protected]>

---------

Signed-off-by: stevehuang52 <[email protected]>
stevehuang52 authored and HuiyingLi committed Nov 15, 2024
1 parent 9520062 commit 70eb548
Showing 6 changed files with 27 additions and 4 deletions.
4 changes: 4 additions & 0 deletions docs/source/asr/ssl/api.rst
@@ -4,6 +4,10 @@ NeMo SSL collection API

Model Classes
-------------
.. autoclass:: nemo.collections.asr.models.EncDecDenoiseMaskedTokenPredModel
:show-inheritance:
:members:

.. autoclass:: nemo.collections.asr.models.SpeechEncDecSelfSupervisedModel
:show-inheritance:
:members:
4 changes: 4 additions & 0 deletions docs/source/asr/ssl/intro.rst
@@ -19,6 +19,10 @@ encoder module of neural ASR models. Here too, the majority of SSL effort is focused
While the AM is commonly the focus of SSL in ASR, SSL can also be used to improve other parts of
ASR models (e.g., the predictor module in transducer-based ASR models).

In NeMo, we provide two types of SSL models: `Wav2Vec-BERT <https://arxiv.org/abs/2108.06209>`_ and `NEST <https://arxiv.org/abs/2408.13106>`_.
The training scripts for them can be found in the `speech_pretraining examples <https://github.com/NVIDIA/NeMo/tree/main/examples/asr/speech_pretraining>`_ directory.


The full documentation tree is as follows:

.. toctree::
4 changes: 2 additions & 2 deletions examples/asr/conf/ssl/nest/nest_fast-conformer.yaml
@@ -28,8 +28,8 @@ model:
mask_position: pre_conv # position to apply masking, before or after conv subsampling, choices in ['pre_conv', 'post_conv']

train_ds:
manifest_filepath: ???
noise_manifest: null
manifest_filepath: ??? # path to training manifest, can be a string or list of strings
noise_manifest: ??? # path to noise manifest, can be a string or list of strings
sample_rate: ${model.sample_rate}
batch_size: 8 # you may increase batch_size if your memory allows
shuffle: true
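The `manifest_filepath` and `noise_manifest` fields in the config above point at NeMo-style JSON-lines manifests. As a minimal sketch (the file paths and durations here are hypothetical placeholders), such a manifest can be written like this:

```python
import json

# Hypothetical entries; real manifests point at actual audio files on disk.
entries = [
    {"audio_filepath": "/data/clean/utt1.wav", "duration": 3.2},
    {"audio_filepath": "/data/clean/utt2.wav", "duration": 5.0},
]

# NeMo-style manifests are JSON-lines: one JSON object per line.
with open("train_manifest.json", "w") as f:
    for entry in entries:
        f.write(json.dumps(entry) + "\n")
```

Per the config comments, `manifest_filepath` and `noise_manifest` then accept either a single manifest path or a list of paths.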
8 changes: 8 additions & 0 deletions examples/asr/speech_pretraining/README.md
@@ -5,3 +5,11 @@ This directory contains example scripts to train self-supervised speech models.
There are two main types of supported self-supervised learning methods:
- [Wav2vec-BERT](https://arxiv.org/abs/2108.06209): `speech_pre_training.py`
- [NEST](https://arxiv.org/abs/2408.13106): `masked_token_pred_pretrain.py`
- For downstream tasks that use NEST as a multi-layer feature extractor, please refer to `./downstream/speech_classification_mfa_train.py`


For example usage, please refer to the corresponding YAML configs:
- Wav2vec-BERT: `examples/asr/conf/ssl/fastconformer/fast-conformer.yaml`
- NEST: `examples/asr/conf/ssl/nest/nest_fast-conformer.yaml`


2 changes: 2 additions & 0 deletions examples/asr/speech_pretraining/masked_token_pred_pretrain.py
@@ -28,7 +28,9 @@
python masked_token_pred_pretrain.py \
# (Optional: --config-path=<path to dir of configs> --config-name=<name of config without .yaml>) \
model.train_ds.manifest_filepath=<path to train manifest> \
model.train_ds.noise_manifest=<path to noise manifest> \
model.validation_ds.manifest_filepath=<path to val/test manifest> \
model.validation_ds.noise_manifest=<path to noise manifest> \
trainer.devices=-1 \
trainer.accelerator="gpu" \
strategy="ddp" \
9 changes: 7 additions & 2 deletions nemo/collections/asr/models/ssl_models.py
@@ -996,7 +996,12 @@ def training_step(self, batch: ssl_dataset.AudioNoiseBatch, batch_idx: int):
return {'loss': loss_value, 'log': tensorboard_logs}

def inference_pass(
self, batch: ssl_dataset.AudioNoiseBatch, batch_idx: int, dataloader_idx: int = 0, mode: str = 'val'
self,
batch: ssl_dataset.AudioNoiseBatch,
batch_idx: int,
dataloader_idx: int = 0,
mode: str = 'val',
apply_mask: bool = True,
):
log_probs, encoded_len, masks, tokens = self.forward(
input_signal=batch.audio,
@@ -1005,7 +1010,7 @@ def inference_pass(
noise_signal_length=batch.noise_len,
noisy_input_signal=batch.noisy_audio,
noisy_input_signal_length=batch.noisy_audio_len,
apply_mask=True,
apply_mask=apply_mask,
)

loss_value = self.loss(masks=masks, decoder_outputs=log_probs, targets=tokens, decoder_lengths=encoded_len)
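The change above exposes `apply_mask` as a parameter of `inference_pass` instead of hard-coding `apply_mask=True`, so callers can run validation either with or without input masking. A minimal standalone sketch of that pattern (toy lists and a hypothetical `inference_forward`, not the NeMo API):

```python
import random

def inference_forward(signal, apply_mask=True, mask_prob=0.3, seed=0):
    """Toy forward pass: optionally zero out a random subset of frames."""
    rng = random.Random(seed)
    # When apply_mask is False, no frame is masked at all.
    masks = [apply_mask and rng.random() < mask_prob for _ in signal]
    output = [0.0 if m else x for m, x in zip(masks, signal)]
    return output, masks

signal = [0.5, -0.2, 0.8, 0.1]
masked_out, masks = inference_forward(signal, apply_mask=True)
clean_out, no_masks = inference_forward(signal, apply_mask=False)

# With apply_mask=False, the input passes through unchanged.
assert clean_out == signal and not any(no_masks)
```

Threading the flag through as a defaulted keyword argument, as the diff does, keeps existing call sites (which relied on masking always being applied) unchanged.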
