Skip to content

Commit

Permalink
Adding docstrings and some changes.
Browse files Browse the repository at this point in the history
  • Loading branch information
nune-tadevosyan committed Jul 10, 2024
1 parent a106b21 commit 8400e74
Show file tree
Hide file tree
Showing 14 changed files with 400 additions and 106 deletions.
15 changes: 9 additions & 6 deletions nemo/collections/asr/data/audio_to_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,7 @@ class _AudioTextDataset(Dataset):
pad_id: Id of pad symbol. Defaults to 0
return_sample_id (bool): whether to return the sample_id as a part of each sample
channel_selector (int | Iterable[int] | str): select a single channel or a subset of channels from multi-channel audio. If set to `'average'`, it performs averaging across channels. Disabled if set to `None`. Defaults to `None`. Uses zero-based indexing.
cache_audio: If True, will cache manifests and audio from object store
"""

@property
Expand Down Expand Up @@ -455,13 +456,13 @@ def __init__(
pad_id: int = 0,
return_sample_id: bool = False,
channel_selector: Optional[ChannelSelectorType] = None,
do_caching: bool = True,
cache_audio: bool = True,
):
if type(manifest_filepath) == str:
manifest_filepath = manifest_filepath.split(",")

# If necessary, cache manifests and audio from object store
if do_caching:
if cache_audio:
cache_datastore_manifests(manifest_filepaths=manifest_filepath, cache_audio=True)
self.manifest_processor = ASRManifestProcessor(
manifest_filepath=manifest_filepath,
Expand Down Expand Up @@ -551,6 +552,7 @@ class AudioToCharDataset(_AudioTextDataset):
eos_id: Id of end of sequence symbol to append if not None
return_sample_id (bool): whether to return the sample_id as a part of each sample
channel_selector (int | Iterable[int] | str): select a single channel or a subset of channels from multi-channel audio. If set to `'average'`, it performs averaging across channels. Disabled if set to `None`. Defaults to `None`. Uses zero-based indexing.
cache_audio: If True, will cache manifests and audio from object store
"""

@property
Expand Down Expand Up @@ -584,7 +586,7 @@ def __init__(
parser: Union[str, Callable] = 'en',
return_sample_id: bool = False,
channel_selector: Optional[ChannelSelectorType] = None,
do_caching: bool = True,
cache_audio: bool = True,
):
self.labels = labels

Expand All @@ -607,7 +609,7 @@ def __init__(
pad_id=pad_id,
return_sample_id=return_sample_id,
channel_selector=channel_selector,
do_caching=do_caching,
cache_audio=cache_audio,
)


Expand Down Expand Up @@ -646,6 +648,7 @@ class AudioToBPEDataset(_AudioTextDataset):
tokens to beginning and ending of speech respectively.
return_sample_id (bool): whether to return the sample_id as a part of each sample
channel_selector (int | Iterable[int] | str): select a single channel or a subset of channels from multi-channel audio. If set to `'average'`, it performs averaging across channels. Disabled if set to `None`. Defaults to `None`. Uses zero-based indexing.
cache_audio: If True, will cache manifests and audio from object store
"""

@property
Expand Down Expand Up @@ -673,7 +676,7 @@ def __init__(
use_start_end_token: bool = True,
return_sample_id: bool = False,
channel_selector: Optional[ChannelSelectorType] = None,
do_caching: bool = True,
cache_audio: bool = True,
):
if use_start_end_token and hasattr(tokenizer, "bos_id") and tokenizer.bos_id > 0:
bos_id = tokenizer.bos_id
Expand Down Expand Up @@ -723,7 +726,7 @@ def __call__(self, *args):
trim=trim,
return_sample_id=return_sample_id,
channel_selector=channel_selector,
do_caching=do_caching,
cache_audio=cache_audio,
)


Expand Down
16 changes: 8 additions & 8 deletions nemo/collections/asr/data/audio_to_text_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def get_concat_char_dataset(


def get_char_dataset(
config: dict, augmentor: Optional['AudioAugmentor'] = None, do_caching: bool = True
config: dict, augmentor: Optional['AudioAugmentor'] = None, cache_audio: bool = True
) -> audio_to_text.AudioToCharDataset:
"""
Instantiates a Character Encoding based AudioToCharDataset.
Expand Down Expand Up @@ -158,7 +158,7 @@ def get_char_dataset(
parser=config.get('parser', 'en'),
return_sample_id=config.get('return_sample_id', False),
channel_selector=config.get('channel_selector', None),
do_caching=do_caching,
cache_audio=cache_audio,
)
return dataset

Expand Down Expand Up @@ -213,7 +213,7 @@ def get_concat_bpe_dataset(


def get_bpe_dataset(
config: dict, tokenizer: 'TokenizerSpec', augmentor: Optional['AudioAugmentor'] = None, do_caching=True
config: dict, tokenizer: 'TokenizerSpec', augmentor: Optional['AudioAugmentor'] = None, cache_audio=True
) -> audio_to_text.AudioToBPEDataset:
"""
Instantiates a Byte Pair Encoding / Word Piece Encoding based AudioToBPEDataset.
Expand All @@ -239,7 +239,7 @@ def get_bpe_dataset(
use_start_end_token=config.get('use_start_end_token', True),
return_sample_id=config.get('return_sample_id', False),
channel_selector=config.get('channel_selector', None),
do_caching=do_caching,
cache_audio=cache_audio,
)
return dataset

Expand Down Expand Up @@ -592,7 +592,7 @@ def get_audio_to_text_char_dataset_from_config(
global_rank: int,
world_size: int,
preprocessor_cfg: Optional[DictConfig] = None,
do_caching: bool = True,
cache_audio: bool = True,
):
"""
Construct Audio-To-Text Char dataset from a config.
Expand Down Expand Up @@ -710,7 +710,7 @@ def get_audio_to_text_char_dataset_from_config(
config=config, global_rank=global_rank, world_size=world_size, augmentor=augmentor
)
else:
dataset = get_char_dataset(config=config, augmentor=augmentor, do_caching=do_caching)
dataset = get_char_dataset(config=config, augmentor=augmentor, cache_audio=cache_audio)
return dataset


Expand All @@ -721,7 +721,7 @@ def get_audio_to_text_bpe_dataset_from_config(
world_size: int,
tokenizer,
preprocessor_cfg: Optional[DictConfig] = None,
do_caching: bool = True,
cache_audio: bool = True,
):
"""
Construct Audio-To-Text BPE dataset from a config.
Expand Down Expand Up @@ -848,7 +848,7 @@ def get_audio_to_text_bpe_dataset_from_config(
augmentor=augmentor,
)
else:
dataset = get_bpe_dataset(config=config, tokenizer=tokenizer, augmentor=augmentor, do_caching=do_caching)
dataset = get_bpe_dataset(config=config, tokenizer=tokenizer, augmentor=augmentor, cache_audio=cache_audio)
return dataset


Expand Down
8 changes: 4 additions & 4 deletions nemo/collections/asr/models/ctc_bpe_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def __init__(self, cfg: DictConfig, trainer=None):
log_prediction=self._cfg.get("log_prediction", False),
)

def _setup_dataloader_from_config(self, config: Optional[Dict], do_caching: bool = True):
def _setup_dataloader_from_config(self, config: Optional[Dict], cache_audio: bool = True):
if config.get("use_lhotse"):
return get_lhotse_dataloader_from_config(
config,
Expand All @@ -109,7 +109,7 @@ def _setup_dataloader_from_config(self, config: Optional[Dict], do_caching: bool
world_size=self.world_size,
tokenizer=self.tokenizer,
preprocessor_cfg=self.cfg.get("preprocessor", None),
do_caching=do_caching,
cache_audio=cache_audio,
)

if dataset is None:
Expand Down Expand Up @@ -164,7 +164,7 @@ def _setup_pseudo_label_dataloader(
batch_size: int = 64,
):
"""
Setup function for a data loader for unlabeled dataset
Setup function for a data loader for pseudo-labelled dataset
Args:
manifest_filepaths: Manifests containing information of unlabeled dataset. For tarred dataset manifests should be sharded
Expand Down Expand Up @@ -217,7 +217,7 @@ def _setup_pseudo_label_dataloader(
}

dataset = audio_to_text_dataset.get_bpe_dataset(
config=dl_config, tokenizer=self.tokenizer, augmentor=None, do_caching=False
config=dl_config, tokenizer=self.tokenizer, augmentor=None, cache_audio=False
)
if hasattr(dataset, 'collate_fn'):
collate_fn = dataset.collate_fn
Expand Down
34 changes: 17 additions & 17 deletions nemo/collections/asr/models/ctc_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import json
import os
import tempfile
from math import ceil
from typing import Any, Dict, List, Optional, Union

import editdistance
import numpy as np
import torch
from omegaconf import DictConfig, OmegaConf, open_dict
Expand All @@ -30,7 +27,6 @@
from nemo.collections.asr.data.audio_to_text import (
_AudioTextDataset,
cache_datastore_manifests,
expand_sharded_filepaths,
)
from nemo.collections.asr.data.audio_to_text_dali import AudioToCharDALIDataset, DALIOutputs
from nemo.collections.asr.data.audio_to_text_lhotse import LhotseSpeechToTextBpeDataset
Expand Down Expand Up @@ -129,6 +125,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):

# Adapter modules setup (from ASRAdapterModelMixin)
self.setup_adapters()
self.setup_ipl(model_type="ctc")

def on_fit_start(self):
"""
Expand All @@ -142,9 +139,9 @@ def on_fit_start(self):

def on_train_epoch_end(self):
"""
This function is mainly used for iterative pseudo labeling algorithm.
This function is mainly used for IPL algorithm.
To make it work in config file 'ipl' parameters should be provided.
For details, see: SlimIPL:(https://arxiv.org/pdf/2010.11524).
"""
self.maybe_do_ipl()

Expand Down Expand Up @@ -304,7 +301,7 @@ def change_decoding_strategy(self, decoding_cfg: DictConfig):

logging.info(f"Changed decoding strategy to \n{OmegaConf.to_yaml(self.cfg.decoding)}")

def _setup_dataloader_from_config(self, config: Optional[Dict], do_caching: bool = True):
def _setup_dataloader_from_config(self, config: Optional[Dict], cache_audio: bool = True):
# Automatically inject args from model config to dataloader config
audio_to_text_dataset.inject_dataloader_value_from_model_config(self.cfg, config, key='sample_rate')
audio_to_text_dataset.inject_dataloader_value_from_model_config(self.cfg, config, key='labels')
Expand All @@ -331,7 +328,7 @@ def _setup_dataloader_from_config(self, config: Optional[Dict], do_caching: bool
global_rank=self.global_rank,
world_size=self.world_size,
preprocessor_cfg=self._cfg.get("preprocessor", None),
do_caching=do_caching,
cache_audio=cache_audio,
)

if dataset is None:
Expand Down Expand Up @@ -382,7 +379,7 @@ def _setup_dataloader_from_config(self, config: Optional[Dict], do_caching: bool
def setup_training_data(
self,
train_data_config: Optional[Union[DictConfig, Dict]],
do_caching: bool = True,
cache_audio: bool = True,
update_limit_train_batches: bool = False,
):
"""
Expand All @@ -405,7 +402,7 @@ def setup_training_data(
# preserve config
self._update_dataset_config(dataset_name='train', config=train_data_config)

self._train_dl = self._setup_dataloader_from_config(config=train_data_config, do_caching=do_caching)
self._train_dl = self._setup_dataloader_from_config(config=train_data_config, cache_audio=cache_audio)

# Need to set this because if using an IterableDataset, the length of the dataloader is the total number
# of samples rather than the number of batches, and this messes up the tqdm progress bar.
Expand All @@ -429,7 +426,7 @@ def setup_training_data(
"training batches will be used. Please set the trainer and rebuild the dataset."
)
elif update_limit_train_batches:
# after generation of pseud-labels for tarred datasets.
# after generation of pseudo-labels for tarred datasets.

self._trainer.limit_train_batches = int(
ceil((len(self._train_dl.dataset) / self.world_size) / train_data_config['batch_size'])
Expand Down Expand Up @@ -705,13 +702,16 @@ def test_dataloader(self):
if self._test_dl is not None:
return self._test_dl

def _setup_pseudo_label_dataloader(self, cache_manifest: str, audio_tar: str = None, batch_size: int = 64):

def _setup_pseudo_label_dataloader(self,
manifest_filepaths: Union[List[List[str]], str],
tarred_audio_filepaths: Union[List[List[str]], str] = None,
batch_size: int = 64,
):
if self.cfg.train_ds.get("is_tarred", False):

dl_config = {
'manifest_filepath': cache_manifest,
'tarred_audio_filepaths': audio_tar,
'manifest_filepath': manifest_filepaths,
'tarred_audio_filepaths': tarred_audio_filepaths,
'sample_rate': self.preprocessor._sample_rate,
'labels': OmegaConf.to_container(self.decoder.vocabulary),
'is_tarred': True,
Expand Down Expand Up @@ -741,7 +741,7 @@ def _setup_pseudo_label_dataloader(self, cache_manifest: str, audio_tar: str = N
else:

dl_config = {
'manifest_filepath': cache_manifest,
'manifest_filepath': manifest_filepaths,
'sample_rate': self.preprocessor._sample_rate,
'labels': self.joint.vocabulary,
'batch_size': batch_size,
Expand All @@ -751,7 +751,7 @@ def _setup_pseudo_label_dataloader(self, cache_manifest: str, audio_tar: str = N
'pin_memory': True,
}

dataset = audio_to_text_dataset.get_char_dataset(config=dl_config, augmentor=None, do_caching=False)
dataset = audio_to_text_dataset.get_char_dataset(config=dl_config, augmentor=None, cache_audio=False)
if hasattr(dataset, 'collate_fn'):
collate_fn = dataset.collate_fn
elif hasattr(dataset.datasets[0], 'collate_fn'):
Expand Down
6 changes: 3 additions & 3 deletions nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
# setting the RNNT decoder as the default one
self.cur_decoder = "rnnt"

def _setup_dataloader_from_config(self, config: Optional[Dict], do_caching: bool = True):
def _setup_dataloader_from_config(self, config: Optional[Dict], cache_audio: bool = True):

if config.get("use_lhotse"):
return get_lhotse_dataloader_from_config(
Expand All @@ -154,7 +154,7 @@ def _setup_dataloader_from_config(self, config: Optional[Dict], do_caching: bool
world_size=self.world_size,
tokenizer=self.tokenizer,
preprocessor_cfg=self.cfg.get("preprocessor", None),
do_caching=do_caching,
cache_audio=cache_audio,
)

if dataset is None:
Expand Down Expand Up @@ -263,7 +263,7 @@ def _setup_pseudo_label_dataloader(
}

dataset = audio_to_text_dataset.get_bpe_dataset(
config=dl_config, tokenizer=self.tokenizer, augmentor=None, do_caching=False
config=dl_config, tokenizer=self.tokenizer, augmentor=None, cache_audio=False
)
if hasattr(dataset, 'collate_fn'):
collate_fn = dataset.collate_fn
Expand Down
17 changes: 6 additions & 11 deletions nemo/collections/asr/models/hybrid_rnnt_ctc_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,14 @@
# limitations under the License.

import copy
import json
import os
import tempfile
from typing import Any, List, Optional, Tuple
from typing import Any, List, Optional, Tuple, Union

import editdistance
import torch
from omegaconf import DictConfig, OmegaConf, open_dict
from pytorch_lightning import Trainer
from tqdm.auto import tqdm

from nemo.collections.asr.data import audio_to_text_dataset
from nemo.collections.asr.data.audio_to_text import cache_datastore_manifests, expand_sharded_filepaths
from nemo.collections.asr.data.audio_to_text import cache_datastore_manifests
from nemo.collections.asr.data.audio_to_text_dali import DALIOutputs
from nemo.collections.asr.data.audio_to_text_lhotse import LhotseSpeechToTextBpeDataset
from nemo.collections.asr.losses.ctc import CTCLoss
Expand Down Expand Up @@ -100,7 +95,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):

# setting up interCTC loss (from InterCTCMixin)
self.setup_interctc(decoder_name='ctc_decoder', loss_name='ctc_loss', wer_name='ctc_wer')
self.setup_ipl()
self.setup_ipl(model_type="hybrid")

def on_fit_start(self):
"""
Expand All @@ -114,9 +109,9 @@ def on_fit_start(self):

def on_train_epoch_end(self):
"""
This function is mainly used for iterative pseudo labeling algorithm.
This function is mainly used for IPL algorithm.
To make it work in config file 'ipl' parameters should be provided.
For details, see: SlimIPL:(https://arxiv.org/pdf/2010.11524).
"""
self.maybe_do_ipl()

Expand Down Expand Up @@ -182,7 +177,7 @@ def _setup_pseudo_label_dataloader(
'pin_memory': True,
}

dataset = audio_to_text_dataset.get_char_dataset(config=dl_config, augmentor=None, do_caching=False)
dataset = audio_to_text_dataset.get_char_dataset(config=dl_config, augmentor=None, cache_audio=False)
if hasattr(dataset, 'collate_fn'):
collate_fn = dataset.collate_fn
elif hasattr(dataset.datasets[0], 'collate_fn'):
Expand Down
Loading

0 comments on commit 8400e74

Please sign in to comment.