Lhotse support for transcribe_speech_parallel (#11249)
* Lhotse support for transcribe_speech_parallel

Signed-off-by: Nune <[email protected]>

* Apply isort and black reformatting

Signed-off-by: nune-tadevosyan <[email protected]>

* Removing prints

Signed-off-by: Nune <[email protected]>

* Remove

Signed-off-by: Nune <[email protected]>

* Adding shard_id

Signed-off-by: Nune <[email protected]>

* Handling empty text fields

Signed-off-by: Nune <[email protected]>

* Apply isort and black reformatting

Signed-off-by: nune-tadevosyan <[email protected]>

* Changing keys

Signed-off-by: Nune <[email protected]>

* Key

Signed-off-by: Nune <[email protected]>

* Commented issues

Signed-off-by: Nune <[email protected]>

* Apply isort and black reformatting

Signed-off-by: nune-tadevosyan <[email protected]>

* Commented issues

Signed-off-by: Nune <[email protected]>

* Apply isort and black reformatting

Signed-off-by: nune-tadevosyan <[email protected]>

* test for lhotse metadata return

Signed-off-by: Nune <[email protected]>

* test for lhotse metadata return

Signed-off-by: Nune <[email protected]>

* Small change

Signed-off-by: Nune <[email protected]>

* Apply isort and black reformatting

Signed-off-by: nune-tadevosyan <[email protected]>

* Support for RNNT and CTC model

Signed-off-by: Nune <[email protected]>

* Support for all models

Signed-off-by: Nune <[email protected]>

* Small change

Signed-off-by: Nune <[email protected]>

* Apply isort and black reformatting

Signed-off-by: nune-tadevosyan <[email protected]>

* Tests for predict_step

Signed-off-by: Nune <[email protected]>

* Apply isort and black reformatting

Signed-off-by: nune-tadevosyan <[email protected]>

* Adding support for force_map_dataset

Signed-off-by: Nune <[email protected]>

* Apply isort and black reformatting

Signed-off-by: nune-tadevosyan <[email protected]>

---------

Signed-off-by: Nune <[email protected]>
Signed-off-by: nune-tadevosyan <[email protected]>
Co-authored-by: nune-tadevosyan <[email protected]>
nune-tadevosyan authored Nov 25, 2024
1 parent 5094b2e commit ee07261
Showing 19 changed files with 296 additions and 42 deletions.
10 changes: 9 additions & 1 deletion examples/asr/transcribe_speech_parallel.py
@@ -163,6 +163,14 @@ def main(cfg: ParallelTranscriptionConfig):
     cfg.predict_ds.return_sample_id = True
     cfg.predict_ds = match_train_config(predict_ds=cfg.predict_ds, train_ds=model.cfg.train_ds)
 
+    if cfg.predict_ds.use_lhotse:
+        OmegaConf.set_struct(cfg.predict_ds, False)
+        cfg.trainer.use_distributed_sampler = False
+        cfg.predict_ds.force_finite = True
+        cfg.predict_ds.force_map_dataset = True
+        cfg.predict_ds.do_transcribe = True
+        OmegaConf.set_struct(cfg.predict_ds, True)
+
     if isinstance(model, EncDecMultiTaskModel):
         cfg.trainer.use_distributed_sampler = False
         OmegaConf.set_struct(cfg.predict_ds, False)
@@ -172,7 +180,7 @@ def main(cfg: ParallelTranscriptionConfig):
 
     trainer = ptl.Trainer(**cfg.trainer)
 
-    if isinstance(model, EncDecMultiTaskModel):
+    if cfg.predict_ds.use_lhotse:
         OmegaConf.set_struct(cfg.predict_ds, False)
         cfg.predict_ds.global_rank = trainer.global_rank
         cfg.predict_ds.world_size = trainer.world_size
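Taken together, the two hunks above wire Lhotse into parallel transcription: before the trainer exists, the config forces a finite, map-style Lhotse dataset and flags the run as transcription; once the trainer is built, its real rank and world size are written back into predict_ds. A minimal sketch of that flow, using a throwaway config stub rather than the real ParallelTranscriptionConfig:

    # Sketch only; the field names mirror the diff, the config object is a stand-in.
    from omegaconf import OmegaConf

    cfg = OmegaConf.create({"predict_ds": {"use_lhotse": True}, "trainer": {}})

    OmegaConf.set_struct(cfg.predict_ds, False)
    cfg.predict_ds.force_finite = True       # predict must see each utterance exactly once
    cfg.predict_ds.force_map_dataset = True  # one sampler per GPU keeps sample ids addressable
    cfg.predict_ds.do_transcribe = True      # tells the model's dataloader setup to return cuts
    OmegaConf.set_struct(cfg.predict_ds, True)

    # ...after trainer = ptl.Trainer(**cfg.trainer):
    # cfg.predict_ds.global_rank = trainer.global_rank
    # cfg.predict_ds.world_size = trainer.world_size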
8 changes: 7 additions & 1 deletion nemo/collections/asr/data/audio_to_text_dataset.py
@@ -867,10 +867,16 @@ def write_on_batch_end(
             sample = sample_id
             if isinstance(sample, lhotse.cut.MixedCut):
                 sample = sample.first_non_padding_cut
-            item["audio_filepath"] = sample.recording.sources[0].source
+            if sample.recording.sources[0].source != '':
+                item["audio_filepath"] = sample.recording.sources[0].source
+            else:
+                item["audio_filepath"] = sample.id
             item["offset"] = sample.start
             item["duration"] = sample.duration
-            item["text"] = sample.supervisions[0].text
+            item["text"] = sample.supervisions[0].text or ''
+            if hasattr(sample, 'shard_id'):
+                item["shard_id"] = sample.shard_id
             item["pred_text"] = transcribed_text
             self.outf.write(json.dumps(item) + "\n")
             self.samples_num += 1
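With the branch above, every line written by the manifest writer now has a usable audio_filepath (falling back to the cut id for in-memory audio), a text field that is never null, and, when the cut came from a tarred source, its shard id. One illustrative output line (all values invented):

    {"audio_filepath": "/data/wavs/utt_0001.wav", "offset": 0.0, "duration": 3.2, "text": "hello world", "shard_id": 4, "pred_text": "hello world"}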
7 changes: 5 additions & 2 deletions nemo/collections/asr/data/audio_to_text_lhotse.py
@@ -43,17 +43,18 @@ def output_types(self) -> Optional[Dict[str, NeuralType]]:
             'sample_id': NeuralType(tuple('B'), LengthsType(), optional=True),
         }
 
-    def __init__(self, tokenizer):
+    def __init__(self, tokenizer, return_cuts=False):
         super().__init__()
         self.tokenizer = TokenizerWrapper(tokenizer)
         self.load_audio = AudioSamples(fault_tolerant=True)
+        self.return_cuts = return_cuts
 
     def __getitem__(self, cuts) -> Tuple[torch.Tensor, ...]:
         audio, audio_lens, cuts = self.load_audio(cuts)
         tokens = [
             torch.cat(
                 [
-                    torch.as_tensor(s.tokens if hasattr(s, "tokens") else self.tokenizer(s.text, s.language))
+                    torch.as_tensor(s.tokens if hasattr(s, "tokens") else self.tokenizer(s.text or "", s.language))
                     for s in c.supervisions
                 ],
                 dim=0,
@@ -62,6 +63,8 @@ def __getitem__(self, cuts) -> Tuple[torch.Tensor, ...]:
         ]
         token_lens = torch.tensor([t.size(0) for t in tokens], dtype=torch.long)
         tokens = collate_vectors(tokens, padding_value=0)
+        if self.return_cuts:
+            return audio, audio_lens, tokens, token_lens, cuts.drop_in_memory_data()
         return audio, audio_lens, tokens, token_lens
 
 
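A short sketch of the two return arities (cuts is any Lhotse CutSet; tok is whatever tokenizer the model carries):

    # Training path: the usual 4-tuple.
    audio, audio_lens, tokens, token_lens = LhotseSpeechToTextBpeDataset(tok)[cuts]

    # Transcription path: a 5-tuple whose trailing CutSet keeps per-utterance
    # metadata (id, offset, duration, supervisions) but drops in-memory audio,
    # so it is cheap to carry through predict_step and back to the writer.
    audio, audio_lens, tokens, token_lens, cut_refs = LhotseSpeechToTextBpeDataset(
        tok, return_cuts=True
    )[cuts]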
11 changes: 11 additions & 0 deletions nemo/collections/asr/models/configs/asr_models_config.py
@@ -41,6 +41,17 @@ class ASRDatasetConfig(nemo.core.classes.dataset.DatasetConfig):
     shard_manifests: bool = False
     shuffle_n: int = 0
 
+    # lhotse support
+    use_lhotse: bool = False
+    tarred_random_access: bool = False
+    use_bucketing: bool = False
+    batch_duration: Optional[int] = None
+    quadratic_duration: Optional[int] = None
+    bucket_batch_size: Optional[int] = None
+    bucket_duration_bins: Optional[list] = None
+    num_buckets: Optional[int] = 0
+    pin_memory: bool = False
+
     # Optional
     int_values: Optional[int] = None
     augmentor: Optional[Dict[str, Any]] = None
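For illustration, these fields let a dataset config opt into Lhotse dataloading and bucketing. A hedged sketch (field values are arbitrary, and any constructor arguments beyond the ones added in this diff are assumptions about the surrounding dataclass):

    ds_cfg = ASRDatasetConfig(
        manifest_filepath="path/to/manifest.json",  # hypothetical path
        sample_rate=16000,
        use_lhotse=True,
        use_bucketing=True,
        num_buckets=30,
        batch_duration=360,     # total seconds of audio per batch
        quadratic_duration=20,  # penalize long cuts quadratically past 20 s
        pin_memory=True,
    )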
12 changes: 9 additions & 3 deletions nemo/collections/asr/models/ctc_bpe_models.py
@@ -97,9 +97,15 @@ def _setup_dataloader_from_config(self, config: Optional[Dict]):
         if config.get("use_lhotse"):
             return get_lhotse_dataloader_from_config(
                 config,
-                global_rank=self.global_rank,
-                world_size=self.world_size,
-                dataset=LhotseSpeechToTextBpeDataset(tokenizer=self.tokenizer),
+                # During transcription, the model is initially loaded on the CPU.
+                # To ensure the correct global_rank and world_size are set,
+                # these values must be passed from the configuration.
+                global_rank=self.global_rank if not config.get("do_transcribe", False) else config.get("global_rank"),
+                world_size=self.world_size if not config.get("do_transcribe", False) else config.get("world_size"),
+                dataset=LhotseSpeechToTextBpeDataset(
+                    tokenizer=self.tokenizer,
+                    return_cuts=config.get("do_transcribe", False),
+                ),
                 tokenizer=self.tokenizer,
             )
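The same conditional reappears in every model below; the idea, stated on its own (sketch, not library code):

    # When transcribe_speech_parallel drives the model, do_transcribe is set and the
    # model may still be on CPU, so self.global_rank / self.world_size are not yet
    # valid; the launcher has already copied the trainer's values into the config.
    do_transcribe = config.get("do_transcribe", False)
    global_rank = config.get("global_rank") if do_transcribe else self.global_rank
    world_size = config.get("world_size") if do_transcribe else self.world_size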
11 changes: 8 additions & 3 deletions nemo/collections/asr/models/ctc_models.py
@@ -309,8 +309,11 @@ def _setup_dataloader_from_config(self, config: Optional[Dict]):
         if config.get("use_lhotse"):
             return get_lhotse_dataloader_from_config(
                 config,
-                global_rank=self.global_rank,
-                world_size=self.world_size,
+                # During transcription, the model is initially loaded on the CPU.
+                # To ensure the correct global_rank and world_size are set,
+                # these values must be passed from the configuration.
+                global_rank=self.global_rank if not config.get("do_transcribe", False) else config.get("global_rank"),
+                world_size=self.world_size if not config.get("do_transcribe", False) else config.get("world_size"),
                 dataset=LhotseSpeechToTextBpeDataset(
                     tokenizer=make_parser(
                         labels=config.get('labels', None),
@@ -319,6 +322,7 @@ def _setup_dataloader_from_config(self, config: Optional[Dict]):
                         blank_id=config.get('blank_index', -1),
                         do_normalize=config.get('normalize_transcripts', False),
                     ),
+                    return_cuts=config.get("do_transcribe", False),
                 ),
             )

@@ -614,7 +618,8 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0):
             return_hypotheses=False,
         )
 
-        sample_id = sample_id.cpu().detach().numpy()
+        if isinstance(sample_id, torch.Tensor):
+            sample_id = sample_id.cpu().detach().numpy()
         return list(zip(sample_id, transcribed_texts))
 
     def validation_pass(self, batch, batch_idx, dataloader_idx=0):
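The new isinstance guard exists because the meaning of sample_id now depends on the dataset flavor; roughly:

    # Legacy path: sample_id is a LongTensor of manifest row indices.
    # Lhotse path (return_cuts=True): sample_id is a CutSet of Cut objects, which
    # has no .cpu(), so it passes through unchanged and is zipped with the texts.
    if isinstance(sample_id, torch.Tensor):
        sample_id = sample_id.cpu().detach().numpy()
    return list(zip(sample_id, transcribed_texts))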
8 changes: 6 additions & 2 deletions nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py
@@ -140,10 +140,14 @@ def _setup_dataloader_from_config(self, config: Optional[Dict]):
         if config.get("use_lhotse"):
             return get_lhotse_dataloader_from_config(
                 config,
-                global_rank=self.global_rank,
-                world_size=self.world_size,
+                # During transcription, the model is initially loaded on the CPU.
+                # To ensure the correct global_rank and world_size are set,
+                # these values must be passed from the configuration.
+                global_rank=self.global_rank if not config.get("do_transcribe", False) else config.get("global_rank"),
+                world_size=self.world_size if not config.get("do_transcribe", False) else config.get("world_size"),
                 dataset=LhotseSpeechToTextBpeDataset(
                     tokenizer=self.tokenizer,
+                    return_cuts=config.get("do_transcribe", False),
                 ),
                 tokenizer=self.tokenizer,
             )
4 changes: 2 additions & 2 deletions nemo/collections/asr/models/hybrid_rnnt_ctc_models.py
@@ -519,8 +519,8 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0):
         best_hyp_text, all_hyp_text = self.decoding.rnnt_decoder_predictions_tensor(
             encoder_output=encoded, encoded_lengths=encoded_len, return_hypotheses=False
         )
-
-        sample_id = sample_id.cpu().detach().numpy()
+        if isinstance(sample_id, torch.Tensor):
+            sample_id = sample_id.cpu().detach().numpy()
         return list(zip(sample_id, best_hyp_text))
 
     def validation_pass(self, batch, batch_idx, dataloader_idx):
8 changes: 6 additions & 2 deletions nemo/collections/asr/models/rnnt_bpe_models.py
@@ -509,10 +509,14 @@ def _setup_dataloader_from_config(self, config: Optional[Dict]):
         if config.get("use_lhotse"):
             return get_lhotse_dataloader_from_config(
                 config,
-                global_rank=self.global_rank,
-                world_size=self.world_size,
+                # During transcription, the model is initially loaded on the CPU.
+                # To ensure the correct global_rank and world_size are set,
+                # these values must be passed from the configuration.
+                global_rank=self.global_rank if not config.get("do_transcribe", False) else config.get("global_rank"),
+                world_size=self.world_size if not config.get("do_transcribe", False) else config.get("world_size"),
                 dataset=LhotseSpeechToTextBpeDataset(
                     tokenizer=self.tokenizer,
+                    return_cuts=config.get("do_transcribe", False),
                 ),
                 tokenizer=self.tokenizer,
             )
11 changes: 8 additions & 3 deletions nemo/collections/asr/models/rnnt_models.py
@@ -469,8 +469,11 @@ def _setup_dataloader_from_config(self, config: Optional[Dict]):
         if config.get("use_lhotse"):
             return get_lhotse_dataloader_from_config(
                 config,
-                global_rank=self.global_rank,
-                world_size=self.world_size,
+                # During transcription, the model is initially loaded on the CPU.
+                # To ensure the correct global_rank and world_size are set,
+                # these values must be passed from the configuration.
+                global_rank=self.global_rank if not config.get("do_transcribe", False) else config.get("global_rank"),
+                world_size=self.world_size if not config.get("do_transcribe", False) else config.get("world_size"),
                 dataset=LhotseSpeechToTextBpeDataset(
                     tokenizer=make_parser(
                         labels=config.get('labels', None),
@@ -479,6 +482,7 @@ def _setup_dataloader_from_config(self, config: Optional[Dict]):
                         blank_id=config.get('blank_index', -1),
                         do_normalize=config.get('normalize_transcripts', False),
                     ),
+                    return_cuts=config.get("do_transcribe", False),
                 ),
             )

@@ -814,7 +818,8 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0):
             encoder_output=encoded, encoded_lengths=encoded_len, return_hypotheses=False
         )
 
-        sample_id = sample_id.cpu().detach().numpy()
+        if isinstance(sample_id, torch.Tensor):
+            sample_id = sample_id.cpu().detach().numpy()
         return list(zip(sample_id, best_hyp_text))
 
     def validation_pass(self, batch, batch_idx, dataloader_idx=0):
8 changes: 6 additions & 2 deletions nemo/collections/asr/models/transformer_bpe_models.py
@@ -225,10 +225,14 @@ def _setup_dataloader_from_config(self, config: Optional[Dict]):
             config = self._update_default_values(config)
             return get_lhotse_dataloader_from_config(
                 config,
-                global_rank=self.global_rank,
-                world_size=self.world_size,
+                # During transcription, the model is initially loaded on the CPU.
+                # To ensure the correct global_rank and world_size are set,
+                # these values must be passed from the configuration.
+                global_rank=self.global_rank if not config.get("do_transcribe", False) else config.get("global_rank"),
+                world_size=self.world_size if not config.get("do_transcribe", False) else config.get("world_size"),
                 dataset=LhotseSpeechToTextBpeDataset(
                     tokenizer=self.tokenizer,
+                    return_cuts=config.get("do_transcribe", False),
                 ),
                 tokenizer=self.tokenizer,
            )
43 changes: 32 additions & 11 deletions nemo/collections/common/data/lhotse/dataloader.py
@@ -147,6 +147,28 @@ class LhotseDataLoadingConfig:
     # In most cases (such as regular multi-GPU training) it will result in a deadlock due to
     # a different number of steps on different DDP ranks.
     force_finite: bool = False
+    # The following two options may be used to override auto-detection of appropriate PyTorch dataset flavor
+    # for your data types. PyTorch DataLoader uses two objects to yield data: dataset and sampler.
+    # *Map-dataset flavor.* There is one sampler per GPU that lives in the training loop process;
+    # it selects the examples to be prepared by map-dataset class. Each batch selection determined by the sampler
+    # is then passed by the dataloader to one of its worker processes to be processed by the dataset class.
+    # *Iterable-dataset flavor.* Each dataloading worker has its own sampler replica instead;
+    # the sampler must have the logic for either data deduplication or unique order shuffling to avoid
+    # duplicated data across workers and GPUs. Lhotse relies on unique order shuffling.
+    # The default settings are:
+    # * use iterable dataset for tarred audio data.
+    # * use iterable dataset for any text data.
+    # * use map dataset for non-tarred audio data (we might change this in the future)
+    force_map_dataset: bool = False
+    force_iterable_dataset: bool = False
+
+
+def determine_use_iterable_dataset(use_iterable_dataset: bool, config: DictConfig) -> bool:
+    assert not (
+        config.force_map_dataset and config.force_iterable_dataset
+    ), "Conflicting options: force_map_dataset=True and force_iterable_dataset=True"
+    use_iterable_dataset = (use_iterable_dataset or config.force_iterable_dataset) and not config.force_map_dataset
+    return use_iterable_dataset
 
 
 def get_lhotse_dataloader_from_config(
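A quick sanity sketch of the precedence encoded above (config stubs built ad hoc; not part of the commit):

    from omegaconf import OmegaConf

    def _cfg(force_map, force_iter):
        return OmegaConf.create(
            {"force_map_dataset": force_map, "force_iterable_dataset": force_iter}
        )

    # Tarred data auto-detects as iterable, but force_map_dataset overrides it;
    # this is exactly what transcribe_speech_parallel relies on.
    assert determine_use_iterable_dataset(True, _cfg(True, False)) is False
    # Loose files default to map-style but can be forced into the iterable flavor.
    assert determine_use_iterable_dataset(False, _cfg(False, True)) is True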
@@ -176,7 +198,6 @@ def get_lhotse_dataloader_from_config(
     Note that ``tokenizer`` can be any tokenizer type (e.g. both SentencePiece and Aggregate tokenizers work).
     """
     logging.info("We will be using a Lhotse DataLoader.")
-
     config = make_structured_with_schema_warnings(config)
 
     maybe_set_cuda_expandable_segments(enabled=config.cuda_expandable_segments)
@@ -186,8 +207,8 @@ def get_lhotse_dataloader_from_config(
     fix_random_seed(seed)
 
     # 1. Load a manifest as a Lhotse CutSet.
-    cuts, is_tarred = read_cutset_from_config(config)
-
+    cuts, use_iterable_dataset = read_cutset_from_config(config)
+    use_iterable_dataset = determine_use_iterable_dataset(use_iterable_dataset, config)
     # Apply channel selector
     if config.channel_selector is not None:
         logging.info('Using channel selector %s.', config.channel_selector)
@@ -202,7 +223,7 @@ def get_lhotse_dataloader_from_config(
     if tokenizer is not None and config.pretokenize:
         from nemo.collections.asr.data.audio_to_text_lhotse import TokenizerWrapper
 
-        if not is_tarred:
+        if not use_iterable_dataset:
             logging.warning(
                 "You are using a non-tarred dataset and requested tokenization during data sampling (pretokenize=True). "
                 "This will cause the tokenization to happen in the main (GPU) process, possibly impacting the training speed "
Expand Down Expand Up @@ -317,8 +338,8 @@ def get_lhotse_dataloader_from_config(
duration_bins=determine_bucket_duration_bins(config),
num_cuts_for_bins_estimate=config.num_cuts_for_bins_estimate,
buffer_size=config.bucket_buffer_size,
rank=0 if is_tarred else global_rank,
world_size=1 if is_tarred else world_size,
rank=0 if use_iterable_dataset else global_rank,
world_size=1 if use_iterable_dataset else world_size,
)
else:
# Non-bucketing sampler, similar to original NeMo dataloading without bucketing,
@@ -335,8 +356,8 @@ def get_lhotse_dataloader_from_config(
             drop_last=config.drop_last,
             shuffle_buffer_size=config.shuffle_buffer_size,
             seed=config.shard_seed,
-            rank=0 if is_tarred else global_rank,
-            world_size=1 if is_tarred else world_size,
+            rank=0 if use_iterable_dataset else global_rank,
+            world_size=1 if use_iterable_dataset else world_size,
         )
 
     if config.concatenate_samples:
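Both sampler branches make the same choice; pulled out as a sketch:

    # Iterable flavor: every dataloading worker re-creates the sampler and
    # deduplication is handled by unique shuffling order, so the sampler itself
    # must not shard by rank. Map flavor: a single sampler lives in the loop
    # process and must shard with the true rank and world size.
    rank = 0 if use_iterable_dataset else global_rank
    world_size = 1 if use_iterable_dataset else world_size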
@@ -368,7 +389,7 @@ def get_lhotse_dataloader_from_config(
     )
 
     # 4. Creating dataloader.
-    if is_tarred and not config.tarred_random_access:
+    if use_iterable_dataset and not config.tarred_random_access:
         # Wrapper here is necessary when using NeMo tarred data or Lhotse Shar data,
         # because then I/O happens upon sampler iteration. Normally, the sampler resides
         # in the training loop process, but when we use iterable dataset, we can move it to
@@ -601,8 +622,8 @@ class DurationFilter:
     """Callable, returns ``True`` if a cut's duration is in range [d_min, d_max] and ``False`` otherwise."""
 
     def __init__(self, d_min: float, d_max: float) -> None:
-        self.d_min = d_min
-        self.d_max = d_max
+        self.d_min = d_min if d_min is not None else -1.0
+        self.d_max = d_max if d_max is not None else float("inf")
 
     def __call__(self, example) -> bool:
         if isinstance(example, Cut):
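With the new defaults the filter degrades to a pass-through when a bound is unset; for example (cut_5s standing in for any 5-second cut):

    keep_all = DurationFilter(None, None)    # bounds become -1.0 .. inf
    short_only = DurationFilter(None, 4.0)   # rejects anything longer than 4 s
    # keep_all(cut_5s)   -> True
    # short_only(cut_5s) -> False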
33 changes: 33 additions & 0 deletions tests/collections/asr/test_asr_ctc_encoder_model_bpe.py
@@ -19,9 +19,12 @@
 
 import pytest
 import torch
+from lhotse import CutSet, MonoCut
+from lhotse.testing.dummies import DummyManifest
 from omegaconf import DictConfig
 
 from nemo.collections.asr.data import audio_to_text
+from nemo.collections.asr.data.audio_to_text_lhotse import LhotseSpeechToTextBpeDataset
 from nemo.collections.asr.models import configs
 from nemo.collections.asr.models.ctc_bpe_models import EncDecCTCModelBPE
 from nemo.collections.asr.parts.submodules import ctc_beam_decoding as beam_decode
@@ -118,6 +121,18 @@ def test_forward(self, asr_model):
         diff = torch.max(torch.abs(logprobs_instance - logprobs_batch))
         assert diff <= 1e-6
 
+    @pytest.mark.unit
+    def test_predict_step(self, asr_model):
+        asr_model = asr_model.eval()
+        cuts = DummyManifest(CutSet, begin_id=0, end_id=1, with_data=True)
+        dataset = LhotseSpeechToTextBpeDataset(tokenizer=asr_model.tokenizer, return_cuts=True)
+        batch = dataset[cuts]
+        outputs = asr_model.predict_step(batch, 0)
+        assert len(outputs) == 1
+        assert len(outputs[0]) == 2
+        assert isinstance(outputs[0][0], MonoCut)
+        assert isinstance(outputs[0][1], str)
+
     @pytest.mark.with_downloads()
     @pytest.mark.unit
     def test_save_restore_artifact(self, asr_model):
@@ -333,6 +348,15 @@ def test_ASRDatasetConfig_for_AudioToBPEDataset(self):
             'bucketing_strategy',
             'bucketing_weights',
             'channel_selector',
+            'use_lhotse',
+            'tarred_random_access',
+            'use_bucketing',
+            'batch_duration',
+            'quadratic_duration',
+            'bucket_batch_size',
+            'bucket_duration_bins',
+            'num_buckets',
+            'pin_memory',
         ]
 
         REMAP_ARGS = {'trim_silence': 'trim', 'labels': 'tokenizer'}
@@ -372,6 +396,15 @@ def test_ASRDatasetConfig_for_TarredAudioToBPEDataset(self):
             'bucketing_strategy',
             'bucketing_weights',
             'max_utts',
+            'use_lhotse',
+            'tarred_random_access',
+            'use_bucketing',
+            'batch_duration',
+            'quadratic_duration',
+            'bucket_batch_size',
+            'bucket_duration_bins',
+            'num_buckets',
+            'pin_memory',
         ]
 
         REMAP_ARGS = {
(Diff for the remaining changed files not shown.)