diff --git a/examples/asr/transcribe_speech_parallel.py b/examples/asr/transcribe_speech_parallel.py
index bdf54ea67f7d..d60099acd379 100644
--- a/examples/asr/transcribe_speech_parallel.py
+++ b/examples/asr/transcribe_speech_parallel.py
@@ -163,6 +163,14 @@ def main(cfg: ParallelTranscriptionConfig):
     cfg.predict_ds.return_sample_id = True
     cfg.predict_ds = match_train_config(predict_ds=cfg.predict_ds, train_ds=model.cfg.train_ds)

+    if cfg.predict_ds.use_lhotse:
+        OmegaConf.set_struct(cfg.predict_ds, False)
+        cfg.trainer.use_distributed_sampler = False
+        cfg.predict_ds.force_finite = True
+        cfg.predict_ds.force_map_dataset = True
+        cfg.predict_ds.do_transcribe = True
+        OmegaConf.set_struct(cfg.predict_ds, True)
+
     if isinstance(model, EncDecMultiTaskModel):
         cfg.trainer.use_distributed_sampler = False
         OmegaConf.set_struct(cfg.predict_ds, False)
@@ -172,7 +180,7 @@ def main(cfg: ParallelTranscriptionConfig):

     trainer = ptl.Trainer(**cfg.trainer)

-    if isinstance(model, EncDecMultiTaskModel):
+    if cfg.predict_ds.use_lhotse:
         OmegaConf.set_struct(cfg.predict_ds, False)
         cfg.predict_ds.global_rank = trainer.global_rank
         cfg.predict_ds.world_size = trainer.world_size
diff --git a/nemo/collections/asr/data/audio_to_text_dataset.py b/nemo/collections/asr/data/audio_to_text_dataset.py
index 76537a8b2b78..f91710de3cb3 100644
--- a/nemo/collections/asr/data/audio_to_text_dataset.py
+++ b/nemo/collections/asr/data/audio_to_text_dataset.py
@@ -867,10 +867,15 @@ def write_on_batch_end(
             sample = sample_id
             if isinstance(sample, lhotse.cut.MixedCut):
                 sample = sample.first_non_padding_cut
-            item["audio_filepath"] = sample.recording.sources[0].source
+            if sample.recording.sources[0].source != '':
+                item["audio_filepath"] = sample.recording.sources[0].source
+            else:
+                item["audio_filepath"] = sample.id
             item["offset"] = sample.start
             item["duration"] = sample.duration
-            item["text"] = sample.supervisions[0].text
+            item["text"] = sample.supervisions[0].text or ''
+            if hasattr(sample, 'shard_id'):
+                item["shard_id"] = sample.shard_id
             item["pred_text"] = transcribed_text
             self.outf.write(json.dumps(item) + "\n")
             self.samples_num += 1
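With the change above, each record written by write_on_batch_end falls back to the cut id when the recording has no source path (e.g. audio held in memory), keeps "text" as an empty string for unlabeled cuts, and carries "shard_id" only when the cut exposes one. A sketch of one resulting manifest line, with purely illustrative values:

    {"audio_filepath": "path/to/audio.wav", "offset": 0.0, "duration": 4.25,
     "text": "", "shard_id": 3, "pred_text": "hypothetical transcript"}
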
diff --git a/nemo/collections/asr/data/audio_to_text_lhotse.py b/nemo/collections/asr/data/audio_to_text_lhotse.py
index f916ae1de56b..0ae3059a9296 100644
--- a/nemo/collections/asr/data/audio_to_text_lhotse.py
+++ b/nemo/collections/asr/data/audio_to_text_lhotse.py
@@ -43,17 +43,18 @@ def output_types(self) -> Optional[Dict[str, NeuralType]]:
             'sample_id': NeuralType(tuple('B'), LengthsType(), optional=True),
         }

-    def __init__(self, tokenizer):
+    def __init__(self, tokenizer, return_cuts=False):
         super().__init__()
         self.tokenizer = TokenizerWrapper(tokenizer)
         self.load_audio = AudioSamples(fault_tolerant=True)
+        self.return_cuts = return_cuts

     def __getitem__(self, cuts) -> Tuple[torch.Tensor, ...]:
         audio, audio_lens, cuts = self.load_audio(cuts)
         tokens = [
             torch.cat(
                 [
-                    torch.as_tensor(s.tokens if hasattr(s, "tokens") else self.tokenizer(s.text, s.language))
+                    torch.as_tensor(s.tokens if hasattr(s, "tokens") else self.tokenizer(s.text or "", s.language))
                     for s in c.supervisions
                 ],
                 dim=0,
@@ -62,6 +63,8 @@ def __getitem__(self, cuts) -> Tuple[torch.Tensor, ...]:
         ]
         token_lens = torch.tensor([t.size(0) for t in tokens], dtype=torch.long)
         tokens = collate_vectors(tokens, padding_value=0)
+        if self.return_cuts:
+            return audio, audio_lens, tokens, token_lens, cuts.drop_in_memory_data()
         return audio, audio_lens, tokens, token_lens
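The return_cuts flag switches the batch from a 4-tuple to a 5-tuple whose last element is the CutSet itself, stripped of in-memory audio payloads so that only metadata travels with the batch. A minimal sketch, assuming tokenizer is any tokenizer accepted by TokenizerWrapper:

    from lhotse import CutSet
    from lhotse.testing.dummies import DummyManifest

    from nemo.collections.asr.data.audio_to_text_lhotse import LhotseSpeechToTextBpeDataset

    cuts = DummyManifest(CutSet, begin_id=0, end_id=2, with_data=True)
    dataset = LhotseSpeechToTextBpeDataset(tokenizer=tokenizer, return_cuts=True)
    audio, audio_lens, tokens, token_lens, cut_meta = dataset[cuts]
    assert len(cut_meta) == 2  # ids, offsets, durations and supervisions are preserved
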
diff --git a/nemo/collections/asr/models/configs/asr_models_config.py b/nemo/collections/asr/models/configs/asr_models_config.py
index 29dbbe06d1f8..081233da5d32 100644
--- a/nemo/collections/asr/models/configs/asr_models_config.py
+++ b/nemo/collections/asr/models/configs/asr_models_config.py
@@ -41,6 +41,17 @@ class ASRDatasetConfig(nemo.core.classes.dataset.DatasetConfig):
     shard_manifests: bool = False
     shuffle_n: int = 0

+    # lhotse support
+    use_lhotse: bool = False
+    tarred_random_access: bool = False
+    use_bucketing: bool = False
+    batch_duration: Optional[int] = None
+    quadratic_duration: Optional[int] = None
+    bucket_batch_size: Optional[int] = None
+    bucket_duration_bins: Optional[list] = None
+    num_buckets: Optional[int] = 0
+    pin_memory: bool = False
+
     # Optional
     int_values: Optional[int] = None
     augmentor: Optional[Dict[str, Any]] = None
diff --git a/nemo/collections/asr/models/ctc_bpe_models.py b/nemo/collections/asr/models/ctc_bpe_models.py
index 79c22794de01..1f84989c8ebe 100644
--- a/nemo/collections/asr/models/ctc_bpe_models.py
+++ b/nemo/collections/asr/models/ctc_bpe_models.py
@@ -97,9 +97,15 @@ def _setup_dataloader_from_config(self, config: Optional[Dict]):
         if config.get("use_lhotse"):
             return get_lhotse_dataloader_from_config(
                 config,
-                global_rank=self.global_rank,
-                world_size=self.world_size,
-                dataset=LhotseSpeechToTextBpeDataset(tokenizer=self.tokenizer),
+                # During transcription, the model is initially loaded on the CPU.
+                # To ensure the correct global_rank and world_size are set,
+                # these values must be passed from the configuration.
+                global_rank=self.global_rank if not config.get("do_transcribe", False) else config.get("global_rank"),
+                world_size=self.world_size if not config.get("do_transcribe", False) else config.get("world_size"),
+                dataset=LhotseSpeechToTextBpeDataset(
+                    tokenizer=self.tokenizer,
+                    return_cuts=config.get("do_transcribe", False),
+                ),
                 tokenizer=self.tokenizer,
             )
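This conditional is repeated verbatim in every _setup_dataloader_from_config touched below; it reads as the following equivalent helper (the name resolve_rank_and_world_size is ours for illustration, not part of the change):

    def resolve_rank_and_world_size(model, config):
        # During transcription the model is first loaded on the CPU, so the
        # trainer has not yet populated model.global_rank / model.world_size;
        # transcribe_speech_parallel.py injects the true values into the config.
        if config.get("do_transcribe", False):
            return config.get("global_rank"), config.get("world_size")
        return model.global_rank, model.world_size
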
diff --git a/nemo/collections/asr/models/ctc_models.py b/nemo/collections/asr/models/ctc_models.py
index 76dcd13cca50..ae8c35220931 100644
--- a/nemo/collections/asr/models/ctc_models.py
+++ b/nemo/collections/asr/models/ctc_models.py
@@ -309,8 +309,11 @@ def _setup_dataloader_from_config(self, config: Optional[Dict]):
         if config.get("use_lhotse"):
             return get_lhotse_dataloader_from_config(
                 config,
-                global_rank=self.global_rank,
-                world_size=self.world_size,
+                # During transcription, the model is initially loaded on the CPU.
+                # To ensure the correct global_rank and world_size are set,
+                # these values must be passed from the configuration.
+                global_rank=self.global_rank if not config.get("do_transcribe", False) else config.get("global_rank"),
+                world_size=self.world_size if not config.get("do_transcribe", False) else config.get("world_size"),
                 dataset=LhotseSpeechToTextBpeDataset(
                     tokenizer=make_parser(
                         labels=config.get('labels', None),
@@ -319,6 +322,7 @@ def _setup_dataloader_from_config(self, config: Optional[Dict]):
                         blank_id=config.get('blank_index', -1),
                         do_normalize=config.get('normalize_transcripts', False),
                     ),
+                    return_cuts=config.get("do_transcribe", False),
                 ),
             )

@@ -614,7 +618,8 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0):
             return_hypotheses=False,
         )

-        sample_id = sample_id.cpu().detach().numpy()
+        if isinstance(sample_id, torch.Tensor):
+            sample_id = sample_id.cpu().detach().numpy()
         return list(zip(sample_id, transcribed_texts))

     def validation_pass(self, batch, batch_idx, dataloader_idx=0):
diff --git a/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py b/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py
index 7e8720ee3ad8..cd04a5ad2462 100644
--- a/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py
+++ b/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py
@@ -140,10 +140,14 @@ def _setup_dataloader_from_config(self, config: Optional[Dict]):
         if config.get("use_lhotse"):
             return get_lhotse_dataloader_from_config(
                 config,
-                global_rank=self.global_rank,
-                world_size=self.world_size,
+                # During transcription, the model is initially loaded on the CPU.
+                # To ensure the correct global_rank and world_size are set,
+                # these values must be passed from the configuration.
+                global_rank=self.global_rank if not config.get("do_transcribe", False) else config.get("global_rank"),
+                world_size=self.world_size if not config.get("do_transcribe", False) else config.get("world_size"),
                 dataset=LhotseSpeechToTextBpeDataset(
                     tokenizer=self.tokenizer,
+                    return_cuts=config.get("do_transcribe", False),
                 ),
                 tokenizer=self.tokenizer,
             )
diff --git a/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py b/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py
index 34dd9aae5711..1f63c617cea2 100644
--- a/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py
+++ b/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py
@@ -519,8 +519,8 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0):
             best_hyp_text, all_hyp_text = self.decoding.rnnt_decoder_predictions_tensor(
                 encoder_output=encoded, encoded_lengths=encoded_len, return_hypotheses=False
             )
-
-        sample_id = sample_id.cpu().detach().numpy()
+        if isinstance(sample_id, torch.Tensor):
+            sample_id = sample_id.cpu().detach().numpy()
         return list(zip(sample_id, best_hyp_text))

     def validation_pass(self, batch, batch_idx, dataloader_idx):
diff --git a/nemo/collections/asr/models/rnnt_bpe_models.py b/nemo/collections/asr/models/rnnt_bpe_models.py
index c92bcfaaef7a..cd8667f2f0fe 100644
--- a/nemo/collections/asr/models/rnnt_bpe_models.py
+++ b/nemo/collections/asr/models/rnnt_bpe_models.py
@@ -509,10 +509,14 @@ def _setup_dataloader_from_config(self, config: Optional[Dict]):
         if config.get("use_lhotse"):
             return get_lhotse_dataloader_from_config(
                 config,
-                global_rank=self.global_rank,
-                world_size=self.world_size,
+                # During transcription, the model is initially loaded on the CPU.
+                # To ensure the correct global_rank and world_size are set,
+                # these values must be passed from the configuration.
+                global_rank=self.global_rank if not config.get("do_transcribe", False) else config.get("global_rank"),
+                world_size=self.world_size if not config.get("do_transcribe", False) else config.get("world_size"),
                 dataset=LhotseSpeechToTextBpeDataset(
                     tokenizer=self.tokenizer,
+                    return_cuts=config.get("do_transcribe", False),
                 ),
                 tokenizer=self.tokenizer,
             )
diff --git a/nemo/collections/asr/models/rnnt_models.py b/nemo/collections/asr/models/rnnt_models.py
index e4d1abd0b50c..78038d404107 100644
--- a/nemo/collections/asr/models/rnnt_models.py
+++ b/nemo/collections/asr/models/rnnt_models.py
@@ -469,8 +469,11 @@ def _setup_dataloader_from_config(self, config: Optional[Dict]):
         if config.get("use_lhotse"):
             return get_lhotse_dataloader_from_config(
                 config,
-                global_rank=self.global_rank,
-                world_size=self.world_size,
+                # During transcription, the model is initially loaded on the CPU.
+                # To ensure the correct global_rank and world_size are set,
+                # these values must be passed from the configuration.
+                global_rank=self.global_rank if not config.get("do_transcribe", False) else config.get("global_rank"),
+                world_size=self.world_size if not config.get("do_transcribe", False) else config.get("world_size"),
                 dataset=LhotseSpeechToTextBpeDataset(
                     tokenizer=make_parser(
                         labels=config.get('labels', None),
@@ -479,6 +482,7 @@ def _setup_dataloader_from_config(self, config: Optional[Dict]):
                         blank_id=config.get('blank_index', -1),
                         do_normalize=config.get('normalize_transcripts', False),
                     ),
+                    return_cuts=config.get("do_transcribe", False),
                 ),
             )

@@ -814,7 +818,8 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0):
             encoder_output=encoded, encoded_lengths=encoded_len, return_hypotheses=False
         )

-        sample_id = sample_id.cpu().detach().numpy()
+        if isinstance(sample_id, torch.Tensor):
+            sample_id = sample_id.cpu().detach().numpy()
         return list(zip(sample_id, best_hyp_text))

     def validation_pass(self, batch, batch_idx, dataloader_idx=0):
diff --git a/nemo/collections/asr/models/transformer_bpe_models.py b/nemo/collections/asr/models/transformer_bpe_models.py
index 8d0f2b2223a3..4692cb662b4b 100644
--- a/nemo/collections/asr/models/transformer_bpe_models.py
+++ b/nemo/collections/asr/models/transformer_bpe_models.py
@@ -225,10 +225,14 @@ def _setup_dataloader_from_config(self, config: Optional[Dict]):
             config = self._update_default_values(config)
             return get_lhotse_dataloader_from_config(
                 config,
-                global_rank=self.global_rank,
-                world_size=self.world_size,
+                # During transcription, the model is initially loaded on the CPU.
+                # To ensure the correct global_rank and world_size are set,
+                # these values must be passed from the configuration.
+                global_rank=self.global_rank if not config.get("do_transcribe", False) else config.get("global_rank"),
+                world_size=self.world_size if not config.get("do_transcribe", False) else config.get("world_size"),
                 dataset=LhotseSpeechToTextBpeDataset(
                     tokenizer=self.tokenizer,
+                    return_cuts=config.get("do_transcribe", False),
                 ),
                 tokenizer=self.tokenizer,
             )
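With return_cuts=True, the sample_id element the dataloader yields is a CutSet rather than a tensor of integer ids, so the isinstance guards above skip the .cpu() conversion and predict_step zips cuts directly with their transcriptions. A sketch of consuming that output under this assumption (model and batch built as in the unit tests further below):

    outputs = model.predict_step(batch, 0)
    for sample_id, pred_text in outputs:
        # sample_id is a lhotse Cut in the return_cuts path, a numpy id otherwise
        print(getattr(sample_id, "id", sample_id), pred_text)
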
diff --git a/nemo/collections/common/data/lhotse/dataloader.py b/nemo/collections/common/data/lhotse/dataloader.py
index 98b63a07fa9d..bf6b77ad907e 100644
--- a/nemo/collections/common/data/lhotse/dataloader.py
+++ b/nemo/collections/common/data/lhotse/dataloader.py
@@ -147,6 +147,28 @@ class LhotseDataLoadingConfig:
     # In most cases (such as regular multi-GPU training) it will result in a deadlock due to
     # a different number of steps on different DDP ranks.
     force_finite: bool = False
+    # The following two options may be used to override auto-detection of the appropriate PyTorch dataset flavor
+    # for your data types. PyTorch DataLoader uses two objects to yield data: a dataset and a sampler.
+    # *Map-dataset flavor.* There is one sampler per GPU that lives in the training loop process;
+    # it selects the examples to be prepared by the map-dataset class. Each batch selection determined by the sampler
+    # is then passed by the dataloader to one of its worker processes to be processed by the dataset class.
+    # *Iterable-dataset flavor.* Each dataloading worker has its own sampler replica instead;
+    # the sampler must implement either data deduplication or unique-order shuffling to avoid
+    # duplicated data across workers and GPUs. Lhotse relies on unique-order shuffling.
+    # The default settings are:
+    # * use an iterable dataset for tarred audio data,
+    # * use an iterable dataset for any text data,
+    # * use a map dataset for non-tarred audio data (we might change this in the future).
+    force_map_dataset: bool = False
+    force_iterable_dataset: bool = False
+
+
+def determine_use_iterable_dataset(use_iterable_dataset: bool, config: DictConfig) -> bool:
+    assert not (
+        config.force_map_dataset and config.force_iterable_dataset
+    ), "Conflicting options: force_map_dataset=True and force_iterable_dataset=True"
+    use_iterable_dataset = (use_iterable_dataset or config.force_iterable_dataset) and not config.force_map_dataset
+    return use_iterable_dataset


 def get_lhotse_dataloader_from_config(
@@ -176,7 +198,6 @@ def get_lhotse_dataloader_from_config(
         Note that ``tokenizer`` can be any tokenizer type (e.g. both SentencePiece and Aggregate tokenizers work).
     """
     logging.info("We will be using a Lhotse DataLoader.")
-
     config = make_structured_with_schema_warnings(config)

     maybe_set_cuda_expandable_segments(enabled=config.cuda_expandable_segments)
@@ -186,8 +207,8 @@ def get_lhotse_dataloader_from_config(
     fix_random_seed(seed)

     # 1. Load a manifest as a Lhotse CutSet.
-    cuts, is_tarred = read_cutset_from_config(config)
-
+    cuts, use_iterable_dataset = read_cutset_from_config(config)
+    use_iterable_dataset = determine_use_iterable_dataset(use_iterable_dataset, config)
     # Apply channel selector
     if config.channel_selector is not None:
         logging.info('Using channel selector %s.', config.channel_selector)
@@ -202,7 +223,7 @@ def get_lhotse_dataloader_from_config(
     if tokenizer is not None and config.pretokenize:
         from nemo.collections.asr.data.audio_to_text_lhotse import TokenizerWrapper

-        if not is_tarred:
+        if not use_iterable_dataset:
             logging.warning(
                 "You are using a non-tarred dataset and requested tokenization during data sampling (pretokenize=True). "
                 "This will cause the tokenization to happen in the main (GPU) process, possibly impacting the training speed "
@@ -317,8 +338,8 @@ def get_lhotse_dataloader_from_config(
             duration_bins=determine_bucket_duration_bins(config),
             num_cuts_for_bins_estimate=config.num_cuts_for_bins_estimate,
             buffer_size=config.bucket_buffer_size,
-            rank=0 if is_tarred else global_rank,
-            world_size=1 if is_tarred else world_size,
+            rank=0 if use_iterable_dataset else global_rank,
+            world_size=1 if use_iterable_dataset else world_size,
         )
     else:
         # Non-bucketing sampler, similar to original NeMo dataloading without bucketing,
@@ -335,8 +356,8 @@ def get_lhotse_dataloader_from_config(
             drop_last=config.drop_last,
             shuffle_buffer_size=config.shuffle_buffer_size,
             seed=config.shard_seed,
-            rank=0 if is_tarred else global_rank,
-            world_size=1 if is_tarred else world_size,
+            rank=0 if use_iterable_dataset else global_rank,
+            world_size=1 if use_iterable_dataset else world_size,
         )

     if config.concatenate_samples:
@@ -368,7 +389,7 @@ def get_lhotse_dataloader_from_config(
         )

     # 4. Creating dataloader.
-    if is_tarred and not config.tarred_random_access:
+    if use_iterable_dataset and not config.tarred_random_access:
         # Wrapper here is necessary when using NeMo tarred data or Lhotse Shar data,
         # because then I/O happens upon sampler iteration. Normally, the sampler resides
         # in the training loop process, but when we use iterable dataset, we can move it to
@@ -601,8 +622,8 @@ class DurationFilter:
     """Callable, returns ``True`` if a cut's duration is in range [d_min, d_max] and ``False`` otherwise."""

     def __init__(self, d_min: float, d_max: float) -> None:
-        self.d_min = d_min
-        self.d_max = d_max
+        self.d_min = d_min if d_min is not None else -1.0
+        self.d_max = d_max if d_max is not None else float("inf")

     def __call__(self, example) -> bool:
         if isinstance(example, Cut):
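The precedence implemented by determine_use_iterable_dataset can be checked directly; a sketch using a minimal stand-in config:

    from omegaconf import OmegaConf

    from nemo.collections.common.data.lhotse.dataloader import determine_use_iterable_dataset

    # Tarred audio is auto-detected as iterable, but force_map_dataset wins;
    # transcribe_speech_parallel.py relies on this to keep per-rank subsets deterministic.
    cfg = OmegaConf.create({"force_map_dataset": True, "force_iterable_dataset": False})
    assert determine_use_iterable_dataset(True, cfg) is False

    # Requesting both flavors at once is rejected.
    cfg = OmegaConf.create({"force_map_dataset": True, "force_iterable_dataset": True})
    try:
        determine_use_iterable_dataset(True, cfg)
    except AssertionError:
        pass
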
" "This will cause the tokenization to happen in the main (GPU) process, possibly impacting the training speed " @@ -317,8 +338,8 @@ def get_lhotse_dataloader_from_config( duration_bins=determine_bucket_duration_bins(config), num_cuts_for_bins_estimate=config.num_cuts_for_bins_estimate, buffer_size=config.bucket_buffer_size, - rank=0 if is_tarred else global_rank, - world_size=1 if is_tarred else world_size, + rank=0 if use_iterable_dataset else global_rank, + world_size=1 if use_iterable_dataset else world_size, ) else: # Non-bucketing sampler, similar to original NeMo dataloading without bucketing, @@ -335,8 +356,8 @@ def get_lhotse_dataloader_from_config( drop_last=config.drop_last, shuffle_buffer_size=config.shuffle_buffer_size, seed=config.shard_seed, - rank=0 if is_tarred else global_rank, - world_size=1 if is_tarred else world_size, + rank=0 if use_iterable_dataset else global_rank, + world_size=1 if use_iterable_dataset else world_size, ) if config.concatenate_samples: @@ -368,7 +389,7 @@ def get_lhotse_dataloader_from_config( ) # 4. Creating dataloader. - if is_tarred and not config.tarred_random_access: + if use_iterable_dataset and not config.tarred_random_access: # Wrapper here is necessary when using NeMo tarred data or Lhotse Shar data, # because then I/O happens upon sampler iteration. Normally, the sampler resides # in the training loop process, but when we use iterable dataset, we can move it to @@ -601,8 +622,8 @@ class DurationFilter: """Callable, returns ``True`` if a cut's duration is in range [d_min, d_max] and ``False`` otherwise.""" def __init__(self, d_min: float, d_max: float) -> None: - self.d_min = d_min - self.d_max = d_max + self.d_min = d_min if d_min is not None else -1.0 + self.d_max = d_max if d_max is not None else float("inf") def __call__(self, example) -> bool: if isinstance(example, Cut): diff --git a/tests/collections/asr/test_asr_ctc_encoder_model_bpe.py b/tests/collections/asr/test_asr_ctc_encoder_model_bpe.py index 247906247091..02442291a918 100644 --- a/tests/collections/asr/test_asr_ctc_encoder_model_bpe.py +++ b/tests/collections/asr/test_asr_ctc_encoder_model_bpe.py @@ -19,9 +19,12 @@ import pytest import torch +from lhotse import CutSet, MonoCut +from lhotse.testing.dummies import DummyManifest from omegaconf import DictConfig from nemo.collections.asr.data import audio_to_text +from nemo.collections.asr.data.audio_to_text_lhotse import LhotseSpeechToTextBpeDataset from nemo.collections.asr.models import configs from nemo.collections.asr.models.ctc_bpe_models import EncDecCTCModelBPE from nemo.collections.asr.parts.submodules import ctc_beam_decoding as beam_decode @@ -118,6 +121,18 @@ def test_forward(self, asr_model): diff = torch.max(torch.abs(logprobs_instance - logprobs_batch)) assert diff <= 1e-6 + @pytest.mark.unit + def test_predict_step(self, asr_model): + asr_model = asr_model.eval() + cuts = DummyManifest(CutSet, begin_id=0, end_id=1, with_data=True) + dataset = LhotseSpeechToTextBpeDataset(tokenizer=asr_model.tokenizer, return_cuts=True) + batch = dataset[cuts] + outputs = asr_model.predict_step(batch, 0) + assert len(outputs) == 1 + assert len(outputs[0]) == 2 + assert isinstance(outputs[0][0], MonoCut) + assert isinstance(outputs[0][1], str) + @pytest.mark.with_downloads() @pytest.mark.unit def test_save_restore_artifact(self, asr_model): @@ -333,6 +348,15 @@ def test_ASRDatasetConfig_for_AudioToBPEDataset(self): 'bucketing_strategy', 'bucketing_weights', 'channel_selector', + 'use_lhotse', + 'tarred_random_access', + 
diff --git a/tests/collections/asr/test_asr_ctcencdec_model.py b/tests/collections/asr/test_asr_ctcencdec_model.py
index 28a07fd54663..55451758578f 100644
--- a/tests/collections/asr/test_asr_ctcencdec_model.py
+++ b/tests/collections/asr/test_asr_ctcencdec_model.py
@@ -15,12 +15,16 @@

 import pytest
 import torch
+from lhotse import CutSet, MonoCut
+from lhotse.testing.dummies import DummyManifest
 from omegaconf import DictConfig, OmegaConf, open_dict

 import nemo.collections.asr as nemo_asr
 from nemo.collections.asr.data import audio_to_text
+from nemo.collections.asr.data.audio_to_text_lhotse import LhotseSpeechToTextBpeDataset
 from nemo.collections.asr.models import EncDecCTCModel, configs
 from nemo.collections.asr.parts.submodules.ctc_decoding import CTCDecoding, CTCDecodingConfig
+from nemo.collections.common.parts.preprocessing.parsers import make_parser
 from nemo.utils.config_utils import assert_dataclass_signature_match, update_model_config


@@ -131,6 +135,19 @@ def test_forward(self, asr_model):
         diff = torch.max(torch.abs(logprobs_instance - logprobs_batch))
         assert diff <= 1e-6

+    @pytest.mark.unit
+    def test_predict_step(self, asr_model):
+        token_list = [" ", "a", "b", "c"]
+        asr_model = asr_model.eval()
+        cuts = DummyManifest(CutSet, begin_id=0, end_id=1, with_data=True)
+        dataset = LhotseSpeechToTextBpeDataset(tokenizer=make_parser(labels=token_list), return_cuts=True)
+        batch = dataset[cuts]
+        outputs = asr_model.predict_step(batch, 0)
+        assert len(outputs) == 1
+        assert len(outputs[0]) == 2
+        assert isinstance(outputs[0][0], MonoCut)
+        assert isinstance(outputs[0][1], str)
+
     @pytest.mark.unit
     def test_vocab_change(self, asr_model):
         old_vocab = copy.deepcopy(asr_model.decoder.vocabulary)
@@ -274,6 +291,15 @@ def test_ASRDatasetConfig_for_AudioToCharDataset(self):
             'bucketing_strategy',
             'bucketing_weights',
             'channel_selector',
+            'use_lhotse',
+            'tarred_random_access',
+            'use_bucketing',
+            'batch_duration',
+            'quadratic_duration',
+            'bucket_batch_size',
+            'bucket_duration_bins',
+            'num_buckets',
+            'pin_memory',
         ]

         REMAP_ARGS = {'trim_silence': 'trim'}
@@ -307,6 +333,15 @@ def test_ASRDatasetConfig_for_TarredAudioToCharDataset(self):
             'bucketing_strategy',
             'bucketing_weights',
             'max_utts',
+            'use_lhotse',
+            'tarred_random_access',
+            'use_bucketing',
+            'batch_duration',
+            'quadratic_duration',
+            'bucket_batch_size',
+            'bucket_duration_bins',
+            'num_buckets',
+            'pin_memory',
         ]

         REMAP_ARGS = {
diff --git a/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_bpe.py b/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_bpe.py
index 1743acc6878c..d13c879e47f9 100644
--- a/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_bpe.py
+++ b/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_bpe.py
@@ -18,8 +18,11 @@

 import pytest
 import torch
+from lhotse import CutSet, MonoCut
+from lhotse.testing.dummies import DummyManifest
 from omegaconf import DictConfig

+from nemo.collections.asr.data.audio_to_text_lhotse import LhotseSpeechToTextBpeDataset
 from nemo.collections.asr.models.hybrid_rnnt_ctc_bpe_models import EncDecHybridRNNTCTCBPEModel
 from nemo.collections.asr.parts.submodules import rnnt_beam_decoding as beam_decode
 from nemo.collections.asr.parts.submodules import rnnt_greedy_decoding as greedy_decode
@@ -166,6 +169,18 @@ def test_forward(self, hybrid_asr_model):
         diff = torch.max(torch.abs(logits_instance - logprobs_batch))
         assert diff <= 1e-6

+    @pytest.mark.unit
+    def test_predict_step(self, hybrid_asr_model):
+        hybrid_asr_model = hybrid_asr_model.eval()
+        cuts = DummyManifest(CutSet, begin_id=0, end_id=1, with_data=True)
+        dataset = LhotseSpeechToTextBpeDataset(tokenizer=hybrid_asr_model.tokenizer, return_cuts=True)
+        batch = dataset[cuts]
+        outputs = hybrid_asr_model.predict_step(batch, 0)
+        assert len(outputs) == 1
+        assert len(outputs[0]) == 2
+        assert isinstance(outputs[0][0], MonoCut)
+        assert isinstance(outputs[0][1], str)
+
     @pytest.mark.with_downloads()
     @pytest.mark.skipif(
         not NUMBA_RNNT_LOSS_AVAILABLE,
diff --git a/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_char.py b/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_char.py
index 5362966e2e9e..b5c34e197237 100644
--- a/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_char.py
+++ b/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_char.py
@@ -16,14 +16,18 @@

 import pytest
 import torch
+from lhotse import CutSet, MonoCut
+from lhotse.testing.dummies import DummyManifest
 from omegaconf import DictConfig, ListConfig

+from nemo.collections.asr.data.audio_to_text_lhotse import LhotseSpeechToTextBpeDataset
 from nemo.collections.asr.models import EncDecHybridRNNTCTCModel
 from nemo.collections.asr.modules import RNNTDecoder, RNNTJoint, SampledRNNTJoint, StatelessTransducerDecoder
 from nemo.collections.asr.parts.submodules import rnnt_beam_decoding as beam_decode
 from nemo.collections.asr.parts.submodules import rnnt_greedy_decoding as greedy_decode
 from nemo.collections.asr.parts.submodules.ctc_decoding import CTCDecoding, CTCDecodingConfig
 from nemo.collections.asr.parts.utils import rnnt_utils
+from nemo.collections.common.parts.preprocessing.parsers import make_parser
 from nemo.core.utils import numba_utils
 from nemo.core.utils.numba_utils import __NUMBA_MINIMUM_VERSION__
 from nemo.utils.config_utils import assert_dataclass_signature_match
@@ -164,6 +168,19 @@ def test_forward(self, hybrid_asr_model):
         diff = torch.max(torch.abs(logprobs_instance - logprobs_batch))
         assert diff <= 1e-6

+    @pytest.mark.unit
+    def test_predict_step(self, hybrid_asr_model):
+        token_list = [" ", "a", "b", "c"]
+        hybrid_asr_model = hybrid_asr_model.eval()
+        cuts = DummyManifest(CutSet, begin_id=0, end_id=1, with_data=True)
+        dataset = LhotseSpeechToTextBpeDataset(tokenizer=make_parser(labels=token_list), return_cuts=True)
+        batch = dataset[cuts]
+        outputs = hybrid_asr_model.predict_step(batch, 0)
+        assert len(outputs) == 1
+        assert len(outputs[0]) == 2
+        assert isinstance(outputs[0][0], MonoCut)
+        assert isinstance(outputs[0][1], str)
+
     @pytest.mark.skipif(
         not NUMBA_RNNT_LOSS_AVAILABLE,
         reason='RNNTLoss has not been compiled with appropriate numba version.',
diff --git a/tests/collections/asr/test_asr_lhotse_dataset.py b/tests/collections/asr/test_asr_lhotse_dataset.py
index 5a1450e606ac..c131fac70310 100644
--- a/tests/collections/asr/test_asr_lhotse_dataset.py
+++ b/tests/collections/asr/test_asr_lhotse_dataset.py
@@ -65,3 +65,35 @@ def test_lhotse_asr_dataset(tokenizer):
     assert tokens[2].tolist() == [1, 7, 10, 19, 20, 21, 1, 20, 6, 4, 16, 15, 5]

     assert token_lens.tolist() == [11, 11, 13]
+
+
+def test_lhotse_asr_dataset_metadata(tokenizer):
+
+    cuts = DummyManifest(CutSet, begin_id=0, end_id=2, with_data=True)
+
+    cuts[0].id = "cuts0"
+    cuts[1].id = "cuts1"
+    cuts[0].supervisions = [
+        SupervisionSegment(id="cuts0-sup0", recording_id=cuts[0].recording_id, start=0.2, duration=0.5, text="first"),
+    ]
+    cuts[1].supervisions = [
+        SupervisionSegment(id="cuts1-sup0", recording_id=cuts[1].recording_id, start=0, duration=1, text=""),
+    ]
+
+    datasets_metadata = LhotseSpeechToTextBpeDataset(tokenizer=tokenizer, return_cuts=True)
+    batch = datasets_metadata[cuts]
+    assert isinstance(batch, tuple)
+    assert len(batch) == 5
+
+    _, _, _, _, cuts_metadata = batch
+
+    assert cuts_metadata[0].supervisions[0].text == "first"
+    assert cuts_metadata[1].supervisions[0].text == ""
+    assert cuts_metadata[0].id == "cuts0"
+    assert cuts_metadata[1].id == "cuts1"
+
+    assert cuts_metadata[0].supervisions[0].duration == 0.5
+    assert cuts_metadata[0].supervisions[0].start == 0.2
+
+    assert cuts_metadata[1].supervisions[0].duration == 1
+    assert cuts_metadata[1].supervisions[0].start == 0.0
diff --git a/tests/collections/asr/test_asr_rnnt_encdec_model.py b/tests/collections/asr/test_asr_rnnt_encdec_model.py
index d68088fce376..5e810243c919 100644
--- a/tests/collections/asr/test_asr_rnnt_encdec_model.py
+++ b/tests/collections/asr/test_asr_rnnt_encdec_model.py
@@ -17,13 +17,17 @@
 import pytest
 import torch
 import torch.nn.functional as F
+from lhotse import CutSet, MonoCut
+from lhotse.testing.dummies import DummyManifest
 from omegaconf import DictConfig, ListConfig

+from nemo.collections.asr.data.audio_to_text_lhotse import LhotseSpeechToTextBpeDataset
 from nemo.collections.asr.models import EncDecRNNTModel
 from nemo.collections.asr.modules import HATJoint, RNNTDecoder, RNNTJoint, SampledRNNTJoint, StatelessTransducerDecoder
 from nemo.collections.asr.parts.submodules import rnnt_beam_decoding as beam_decode
 from nemo.collections.asr.parts.submodules import rnnt_greedy_decoding as greedy_decode
 from nemo.collections.asr.parts.utils import rnnt_utils
+from nemo.collections.common.parts.preprocessing.parsers import make_parser
 from nemo.core.utils import numba_utils
 from nemo.core.utils.numba_utils import __NUMBA_MINIMUM_VERSION__
 from nemo.utils.config_utils import assert_dataclass_signature_match
@@ -296,6 +300,19 @@ def test_forward(self, asr_model):
         diff = torch.max(torch.abs(logprobs_instance - logprobs_batch))
         assert diff <= 1e-6

+    @pytest.mark.unit
+    def test_predict_step(self, asr_model):
+        token_list = [" ", "a", "b", "c"]
+        asr_model = asr_model.eval()
+        cuts = DummyManifest(CutSet, begin_id=0, end_id=1, with_data=True)
+        dataset = LhotseSpeechToTextBpeDataset(tokenizer=make_parser(labels=token_list), return_cuts=True)
+        batch = dataset[cuts]
+        outputs = asr_model.predict_step(batch, 0)
+        assert len(outputs) == 1
+        assert len(outputs[0]) == 2
+        assert isinstance(outputs[0][0], MonoCut)
+        assert isinstance(outputs[0][1], str)
+
     @pytest.mark.skipif(
         not NUMBA_RNNT_LOSS_AVAILABLE,
         reason='RNNTLoss has not been compiled with appropriate numba version.',
diff --git a/tests/collections/asr/test_asr_rnnt_encoder_model_bpe.py b/tests/collections/asr/test_asr_rnnt_encoder_model_bpe.py
index 960445061e24..aba364868e88 100644
--- a/tests/collections/asr/test_asr_rnnt_encoder_model_bpe.py
+++ b/tests/collections/asr/test_asr_rnnt_encoder_model_bpe.py
@@ -18,8 +18,11 @@

 import pytest
 import torch
+from lhotse import CutSet, MonoCut
+from lhotse.testing.dummies import DummyManifest
 from omegaconf import DictConfig

+from nemo.collections.asr.data.audio_to_text_lhotse import LhotseSpeechToTextBpeDataset
 from nemo.collections.asr.models import ASRModel
 from nemo.collections.asr.models.rnnt_bpe_models import EncDecRNNTBPEModel
 from nemo.collections.asr.parts.submodules import rnnt_beam_decoding as beam_decode
@@ -64,12 +67,18 @@ def asr_model(test_data_dir):

     decoder = {
         '_target_': 'nemo.collections.asr.modules.RNNTDecoder',
-        'prednet': {'pred_hidden': model_defaults['pred_hidden'], 'pred_rnn_layers': 1,},
+        'prednet': {
+            'pred_hidden': model_defaults['pred_hidden'],
+            'pred_rnn_layers': 1,
+        },
     }

     joint = {
         '_target_': 'nemo.collections.asr.modules.RNNTJoint',
-        'jointnet': {'joint_hidden': 32, 'activation': 'relu',},
+        'jointnet': {
+            'joint_hidden': 32,
+            'activation': 'relu',
+        },
     }

     decoding = {'strategy': 'greedy_batch', 'greedy': {'max_symbols': 30}}
@@ -123,7 +132,8 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None):

 class TestEncDecRNNTBPEModel:
     @pytest.mark.skipif(
-        not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.',
+        not NUMBA_RNNT_LOSS_AVAILABLE,
+        reason='RNNTLoss has not been compiled with appropriate numba version.',
     )
     @pytest.mark.with_downloads()
     @pytest.mark.unit
@@ -137,7 +147,8 @@ def test_constructor(self, asr_model):

     @pytest.mark.with_downloads()
     @pytest.mark.skipif(
-        not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.',
+        not NUMBA_RNNT_LOSS_AVAILABLE,
+        reason='RNNTLoss has not been compiled with appropriate numba version.',
     )
     @pytest.mark.unit
     def test_forward(self, asr_model):
@@ -170,9 +181,22 @@ def test_forward(self, asr_model):
         diff = torch.max(torch.abs(logits_instance - logprobs_batch))
         assert diff <= 1e-6

+    @pytest.mark.unit
+    def test_predict_step(self, asr_model):
+        asr_model = asr_model.eval()
+        cuts = DummyManifest(CutSet, begin_id=0, end_id=1, with_data=True)
+        dataset = LhotseSpeechToTextBpeDataset(tokenizer=asr_model.tokenizer, return_cuts=True)
+        batch = dataset[cuts]
+        outputs = asr_model.predict_step(batch, 0)
+        assert len(outputs) == 1
+        assert len(outputs[0]) == 2
+        assert isinstance(outputs[0][0], MonoCut)
+        assert isinstance(outputs[0][1], str)
+
     @pytest.mark.with_downloads()
     @pytest.mark.skipif(
-        not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.',
+        not NUMBA_RNNT_LOSS_AVAILABLE,
+        reason='RNNTLoss has not been compiled with appropriate numba version.',
     )
     @pytest.mark.unit
     def test_save_restore_artifact(self, asr_model):
@@ -190,7 +214,8 @@ def test_save_restore_artifact(self, asr_model):

     @pytest.mark.with_downloads()
     @pytest.mark.skipif(
-        not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.',
+        not NUMBA_RNNT_LOSS_AVAILABLE,
+        reason='RNNTLoss has not been compiled with appropriate numba version.',
     )
     @pytest.mark.unit
     def test_save_restore_artifact_spe(self, asr_model, test_data_dir):
@@ -236,7 +261,8 @@ def test_save_restore_artifact_agg(self, asr_model, test_data_dir):

     @pytest.mark.with_downloads()
     @pytest.mark.skipif(
-        not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.',
+        not NUMBA_RNNT_LOSS_AVAILABLE,
+        reason='RNNTLoss has not been compiled with appropriate numba version.',
     )
     @pytest.mark.unit
     def test_vocab_change(self, test_data_dir, asr_model):
@@ -266,7 +292,8 @@ def test_vocab_change(self, test_data_dir, asr_model):

     @pytest.mark.with_downloads()
     @pytest.mark.skipif(
-        not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.',
+        not NUMBA_RNNT_LOSS_AVAILABLE,
+        reason='RNNTLoss has not been compiled with appropriate numba version.',
     )
     @pytest.mark.unit
     def test_decoding_change(self, asr_model):
@@ -309,7 +336,8 @@ def test_decoding_change(self, asr_model):
     @pytest.mark.with_downloads()
     @pytest.mark.unit
     @pytest.mark.skipif(
-        not NUMBA_RNNT_LOSS_AVAILABLE, reason='RNNTLoss has not been compiled with appropriate numba version.',
+        not NUMBA_RNNT_LOSS_AVAILABLE,
+        reason='RNNTLoss has not been compiled with appropriate numba version.',
     )
     def test_save_restore_nested_model(self):
         with tempfile.TemporaryDirectory() as tmp_dir:
@@ -330,7 +358,7 @@ def test_save_restore_nested_model(self):

             # Check size of the checkpoint, which contains weights from pretrained model + linear layer
             fp_weights = os.path.join(tmp_dir, 'model_weights.ckpt')
-            assert os.path.getsize(fp_weights) > 50 * (2 ** 20)  # Assert the weights are more than 50 MB
+            assert os.path.getsize(fp_weights) > 50 * (2**20)  # Assert the weights are more than 50 MB

             # Check if param after restoration is exact match
             original_state_dict = model.inner_model.state_dict()