From a2572a7bd57e6d6d7b95090a792455638740203f Mon Sep 17 00:00:00 2001 From: lilithgrigoryan <38436437+lilithgrigoryan@users.noreply.github.com> Date: Wed, 13 Nov 2024 15:24:51 +0400 Subject: [PATCH] Beam search algorithm implementation for TDT models (#10903) * initial commit Signed-off-by: lilithgrigoryan * add: default beam search implementation Signed-off-by: lilithgrigoryan * fix: changed to removing duplicate hypothesis in separate function Signed-off-by: lilithgrigoryan * fix: changed to cartesian product in choosing best hyp Signed-off-by: lilithgrigoryan * fix: minor fixes in comments Signed-off-by: lilithgrigoryan * add: maes decoding strategy Signed-off-by: lilithgrigoryan * add: durations filtering in maes, lm fusion in progress Signed-off-by: lilithgrigoryan * fix: refactored, added comments, command line args, finalized Signed-off-by: lilithgrigoryan * fix: removed prints Signed-off-by: lilithgrigoryan * add: docs Signed-off-by: lilithgrigoryan * Apply isort and black reformatting Signed-off-by: lilithgrigoryan * fix: minor fix Signed-off-by: lilithgrigoryan * fix: rm beam_size=1 exception, rm duplicates check, fix error handling Signed-off-by: lilithgrigoryan * fix: error handling Signed-off-by: lilithgrigoryan * Apply isort and black reformatting Signed-off-by: lilithgrigoryan * fix: removed evaluations file Signed-off-by: lilithgrigoryan * rn: blank scoring Signed-off-by: lilithgrigoryan * clean up Signed-off-by: lilithgrigoryan * rm: blank scoring and duration beam size Signed-off-by: lilithgrigoryan * Apply isort and black reformatting Signed-off-by: lilithgrigoryan * fix: removed durations_beam_size from default beam search Signed-off-by: lilithgrigoryan * add: logaddexp Signed-off-by: lilithgrigoryan * rm: prefix search Signed-off-by: lilithgrigoryan * rn: nested loop over extensions Signed-off-by: lilithgrigoryan * fix: bug with caching Signed-off-by: lilithgrigoryan * rm: topk on durations Signed-off-by: lilithgrigoryan * add: restored prefix search Signed-off-by: lilithgrigoryan * Apply isort and black reformatting Signed-off-by: lilithgrigoryan * clean up Signed-off-by: lilithgrigoryan * fix: fixed comments Signed-off-by: lilithgrigoryan * refactored duplicate merging Signed-off-by: lilithgrigoryan * changes batch scoring Signed-off-by: lilithgrigoryan * refactored rnnt batch scoring Signed-off-by: lilithgrigoryan * alsd first working Signed-off-by: lilithgrigoryan * refactored Signed-off-by: lilithgrigoryan * clean up Signed-off-by: lilithgrigoryan * remove stacking operations Signed-off-by: lilithgrigoryan * fixes im base class Signed-off-by: lilithgrigoryan * clean up Signed-off-by: lilithgrigoryan * Apply isort and black reformatting Signed-off-by: lilithgrigoryan * remove potentially uninitialized local variable Signed-off-by: lilithgrigoryan * default beam search minor fixes Signed-off-by: lilithgrigoryan * add test, fix maes timesteps Signed-off-by: lilithgrigoryan * rm file Signed-off-by: lilithgrigoryan * rm file Signed-off-by: lilithgrigoryan * clean up Signed-off-by: lilithgrigoryan * Apply isort and black reformatting Signed-off-by: lilithgrigoryan * clean up Signed-off-by: lilithgrigoryan * fix comments Signed-off-by: lilithgrigoryan * add ngram lm test Signed-off-by: lilithgrigoryan * Apply isort and black reformatting Signed-off-by: lilithgrigoryan * fix maes_num_steps=1 Signed-off-by: lilithgrigoryan * fix kenlm model path Signed-off-by: lilithgrigoryan * fix kenlm model full path Signed-off-by: lilithgrigoryan * Apply isort and black reformatting 
Signed-off-by: lilithgrigoryan * made requested changes Signed-off-by: lilithgrigoryan * merge after isort Signed-off-by: lilithgrigoryan * add prints to test Signed-off-by: lilithgrigoryan * Apply isort and black reformatting Signed-off-by: lilithgrigoryan * add Kenlm to asr requirements Signed-off-by: lilithgrigoryan * remove prints in tests Signed-off-by: lilithgrigoryan * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add kenlm to test requirements Signed-off-by: lilithgrigoryan * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * rm kenlm from link, add package-name Signed-off-by: lilithgrigoryan * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * rm second kenlm installation Signed-off-by: lilithgrigoryan * rm kenlm from dependencies make test optional Signed-off-by: lilithgrigoryan * Apply isort and black reformatting Signed-off-by: lilithgrigoryan * fix in test Signed-off-by: lilithgrigoryan * fix in test Signed-off-by: lilithgrigoryan * Apply isort and black reformatting Signed-off-by: lilithgrigoryan * fix comments Signed-off-by: lilithgrigoryan * Apply isort and black reformatting Signed-off-by: lilithgrigoryan * add comments Signed-off-by: lilithgrigoryan * add comments Signed-off-by: lilithgrigoryan * splitted docstrings Signed-off-by: lilithgrigoryan * Apply isort and black reformatting Signed-off-by: lilithgrigoryan * add comments Signed-off-by: lilithgrigoryan * splitted docstrings Signed-off-by: lilithgrigoryan * Apply isort and black reformatting Signed-off-by: lilithgrigoryan * add comments Signed-off-by: lilithgrigoryan * Apply isort and black reformatting Signed-off-by: lilithgrigoryan * fixes to python3 type annotations Signed-off-by: lilithgrigoryan * Apply isort and black reformatting Signed-off-by: lilithgrigoryan * merging Signed-off-by: lilithgrigoryan * merging Signed-off-by: lilithgrigoryan * fix in return type Signed-off-by: lilithgrigoryan * Apply isort and black reformatting Signed-off-by: lilithgrigoryan * fix test Signed-off-by: lilithgrigoryan * Apply isort and black reformatting Signed-off-by: lilithgrigoryan * rm time_idx Signed-off-by: lilithgrigoryan * fix comments to python3 style Signed-off-by: lilithgrigoryan --------- Signed-off-by: lilithgrigoryan Signed-off-by: lilithgrigoryan Co-authored-by: lilithgrigoryan Co-authored-by: lilithgrigoryan Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- docs/source/asr/api.rst | 15 + .../parts/submodules/rnnt_beam_decoding.py | 56 +- .../asr/parts/submodules/rnnt_decoding.py | 227 +++-- .../parts/submodules/rnnt_greedy_decoding.py | 54 +- .../asr/parts/submodules/tdt_beam_decoding.py | 800 ++++++++++++++++++ .../collections/asr/parts/utils/rnnt_utils.py | 12 +- .../asr/decoding/test_rnnt_decoding.py | 88 +- 7 files changed, 1155 insertions(+), 97 deletions(-) create mode 100644 nemo/collections/asr/parts/submodules/tdt_beam_decoding.py diff --git a/docs/source/asr/api.rst b/docs/source/asr/api.rst index c99d92c0371a..a35ea49ea2c4 100644 --- a/docs/source/asr/api.rst +++ b/docs/source/asr/api.rst @@ -276,6 +276,21 @@ RNNT Decoding :show-inheritance: :members: +TDT Decoding +~~~~~~~~~~~~~ + +.. autoclass:: nemo.collections.asr.parts.submodules.rnnt_greedy_decoding.GreedyTDTInfer + :show-inheritance: + :members: + +.. 
autoclass:: nemo.collections.asr.parts.submodules.rnnt_greedy_decoding.GreedyBatchedTDTInfer + :show-inheritance: + :members: + +.. autoclass:: nemo.collections.asr.parts.submodules.tdt_beam_decoding.BeamTDTInfer + :show-inheritance: + :members: + Hypotheses ~~~~~~~~~~ diff --git a/nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py index c01f2363db75..e0bd47bb8ce0 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py +++ b/nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py @@ -55,6 +55,20 @@ def pack_hypotheses(hypotheses: List[Hypothesis]) -> List[Hypothesis]: + """ + Packs a list of hypotheses into a tensor and prepares decoder states. + + This function takes a list of token sequences (hypotheses) and converts + it into a tensor format. If any decoder states are on the GPU, they + are moved to the CPU. Additionally, the function removes any timesteps + with a value of -1 from the sequences. + + Args: + hypotheses (list): A list of token sequences representing hypotheses. + + Returns: + list: A list of packed hypotheses in tensor format. + """ for idx, hyp in enumerate(hypotheses): # type: rnnt_utils.Hypothesis hyp.y_sequence = torch.tensor(hyp.y_sequence, dtype=torch.long) @@ -69,6 +83,18 @@ def pack_hypotheses(hypotheses: List[Hypothesis]) -> List[Hypothesis]: def _states_to_device(dec_state, device='cpu'): + """ + Transfers decoder states to the specified device. + + This function moves the provided decoder states to the specified device (e.g., 'cpu' or 'cuda'). + + Args: + dec_state (Tensor): The decoder states to be transferred. + device (str): The target device to which the decoder states should be moved. Defaults to 'cpu'. + + Returns: + Tensor: The decoder states on the specified device. + """ if torch.is_tensor(dec_state): dec_state = dec_state.to(device) @@ -106,7 +132,8 @@ class BeamRNNTInfer(Typing): however the time required for the search also grows steadily. `tsd` - time synchronous decoding. Please refer to the paper: - [Alignment-Length Synchronous Decoding for RNN Transducer](https://ieeexplore.ieee.org/document/9053040) + [Alignment-Length Synchronous Decoding for RNN Transducer] + (https://ieeexplore.ieee.org/document/9053040) for details on the algorithm implemented. Time synchronous decoding (TSD) execution time grows by the factor T * max_symmetric_expansions. @@ -114,7 +141,8 @@ class BeamRNNTInfer(Typing): good results. This also requires greater memory to execute. `alsd` - alignment-length synchronous decoding. Please refer to the paper: - [Alignment-Length Synchronous Decoding for RNN Transducer](https://ieeexplore.ieee.org/document/9053040) + [Alignment-Length Synchronous Decoding for RNN Transducer] + (https://ieeexplore.ieee.org/document/9053040) for details on the algorithm implemented. Alignment-length synchronous decoding (ALSD) execution time is faster than TSD, with growth @@ -127,7 +155,8 @@ class BeamRNNTInfer(Typing): For a given decoding accuracy, it is possible to attain faster decoding via ALSD than TSD. `maes` = modified adaptive expansion searcn. Please refer to the paper: - [Accelerating RNN Transducer Inference via Adaptive Expansion Search](https://ieeexplore.ieee.org/document/9250505) + [Accelerating RNN Transducer Inference via Adaptive Expansion Search] + (https://ieeexplore.ieee.org/document/9250505) Modified Adaptive Synchronous Decoding (mAES) execution time is adaptive w.r.t the number of expansions (for tokens) required per timestep. 
The number of expansions can usually @@ -169,10 +198,10 @@ class BeamRNNTInfer(Typing): and affects the speed of inference since large values will perform large beam search in the next step. maes_expansion_gamma: Float pruning threshold used in the prune-by-value step when computing the expansions. - The default (2.3) is selected from the paper. It performs a comparison (max_log_prob - gamma <= log_prob[v]) - where v is all vocabulary indices in the Vocab set and max_log_prob is the "most" likely token to be - predicted. Gamma therefore provides a margin of additional tokens which can be potential candidates for - expansion apart from the "most likely" candidate. + The default (2.3) is selected from the paper. It performs a comparison + (max_log_prob - gamma <= log_prob[v]) where v is all vocabulary indices in the Vocab set and max_log_prob + is the "most" likely token to be predicted. Gamma therefore provides a margin of additional tokens which + can be potential candidates for expansion apart from the "most likely" candidate. Lower values will reduce the number of expansions (by increasing pruning-by-value, thereby improving speed but hurting accuracy). Higher values will increase the number of expansions (by reducing pruning-by-value, thereby reducing speed but potentially improving accuracy). This is a hyper parameter to be experimentally @@ -182,7 +211,7 @@ class BeamRNNTInfer(Typing): preserve_alignments: Bool flag which preserves the history of alignments generated during beam decoding (sample). When set to true, the Hypothesis will contain - the non-null value for `alignments` in it. Here, `alignments` is a List of List of Tensor (of length V + 1). + the non-null value for `alignments` in it. Here, `alignments` is a List of List of Tensor (of length V + 1) The length of the list corresponds to the Acoustic Length (T). Each value in the list (Ti) is a torch.Tensor (U), representing 1 or more targets from a vocabulary. @@ -1456,8 +1485,11 @@ def compute_ngram_score(self, current_lm_state: "kenlm.State", label: int) -> Tu return lm_score, next_state def set_decoding_type(self, decoding_type: str): - - # Please check train_kenlm.py in scripts/asr_language_modeling/ to find out why we need + """ + Sets the decoding type. Please check train_kenlm.py in scripts/asr_language_modeling/ to find out why we need the TOKEN_OFFSET for BPE-based models. + Args: + decoding_type: decoding type ('char' or 'subword') + """ # TOKEN_OFFSET for BPE-based models if decoding_type == 'subword': from nemo.collections.asr.parts.submodules.ctc_beam_decoding import DEFAULT_TOKEN_OFFSET @@ -1467,6 +1499,10 @@ def set_decoding_type(self, decoding_type: str): @dataclass class BeamRNNTInferConfig: + """ + Beam RNNT Inference config.
+ """ + beam_size: int search_type: str = 'default' score_norm: bool = True diff --git a/nemo/collections/asr/parts/submodules/rnnt_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_decoding.py index da280a0c6b3c..d3a63467c485 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_decoding.py +++ b/nemo/collections/asr/parts/submodules/rnnt_decoding.py @@ -23,7 +23,7 @@ import torch from omegaconf import OmegaConf -from nemo.collections.asr.parts.submodules import rnnt_beam_decoding, rnnt_greedy_decoding +from nemo.collections.asr.parts.submodules import rnnt_beam_decoding, rnnt_greedy_decoding, tdt_beam_decoding from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceConfig, ConfidenceMixin from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis, NBestHypotheses from nemo.collections.common.tokenizers.aggregate_tokenizer import AggregateTokenizer @@ -67,15 +67,15 @@ class AbstractRNNTDecoding(ConfidenceMixin): rnnt_timestamp_type: A str value, which represents the types of timestamps that should be calculated. Can take the following values - "char" for character/subword time stamps, "word" for word level - time stamps, "segment" for segment level time stamps and "all" (default), for character, - word and segment level time stamps. + time stamps, "segment" for segment level time stamps and "all" (default), for character, word and + segment level time stamps. word_seperator: Str token representing the seperator between words. segment_seperators: List containing tokens representing the seperator(s) between segments. - segment_gap_threshold: The threshold (in frames) that caps the gap between two words necessary - for forming the segments. + segment_gap_threshold: The threshold (in frames) that caps the gap between two words necessary for forming + the segments. preserve_frame_confidence: Bool flag which preserves the history of per-frame confidence scores generated during decoding (sample / batched). When set to true, the Hypothesis will contain @@ -106,8 +106,8 @@ class AbstractRNNTDecoding(ConfidenceMixin): from the `token_confidence`. aggregation: Which aggregation type to use for collapsing per-token confidence into per-word confidence. Valid options are `mean`, `min`, `max`, `prod`. - tdt_include_duration: Bool flag indicating that the duration confidence scores are to be calculated - and attached to the regular frame confidence, + tdt_include_duration: Bool flag indicating that the duration confidence scores are to be calculated and + attached to the regular frame confidence, making TDT frame confidence element a pair: (`prediction_confidence`, `duration_confidence`). method_cfg: A dict-like object which contains the method name and settings to compute per-frame confidence scores. @@ -179,23 +179,23 @@ class AbstractRNNTDecoding(ConfidenceMixin): maes_num_steps: Number of adaptive steps to take. From the paper, 2 steps is generally sufficient, and can be reduced to 1 to improve decoding speed while sacrificing some accuracy. int > 0. - maes_prefix_alpha: Maximum prefix length in prefix search. Must be an integer, and is advised to - keep this as 1 in order to reduce expensive beam search cost later. int >= 0. + maes_prefix_alpha: Maximum prefix length in prefix search. Must be an integer, and is advised to keep + this as 1 in order to reduce expensive beam search cost later. int >= 0. maes_expansion_beta: Maximum number of prefix expansions allowed, in addition to the beam size. 
Effectively, the number of hypothesis = beam_size + maes_expansion_beta. Must be an int >= 0, - and affects the speed of inference since large values will perform large beam search in the - next step. + and affects the speed of inference since large values will perform large beam search in the next + step. maes_expansion_gamma: Float pruning threshold used in the prune-by-value step when computing the expansions. The default (2.3) is selected from the paper. It performs a comparison - (max_log_prob - gamma <= log_prob[v]) where v is all vocabulary indices in the Vocab set - and max_log_prob is the "most" likely token to be predicted. Gamma therefore provides a margin - of additional tokens which can be potential candidates for expansion apart from the "most likely" + (max_log_prob - gamma <= log_prob[v]) where v is all vocabulary indices in the Vocab set and + max_log_prob is the "most" likely token to be predicted. Gamma therefore provides a margin of + additional tokens which can be potential candidates for expansion apart from the "most likely" candidate. Lower values will reduce the number of expansions (by increasing pruning-by-value, thereby improving speed but hurting accuracy). Higher values will increase the number of expansions - (by reducing pruning-by-value, thereby reducing speed but potentially improving accuracy). - This is a hyper parameter to be experimentally tuned on a validation set. + (by reducing pruning-by-value, thereby reducing speed but potentially improving accuracy). This is + a hyper parameter to be experimentally tuned on a validation set. softmax_temperature: Scales the logits of the joint prior to computing log_softmax. @@ -234,8 +234,10 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int, supported_punctu raise ValueError("blank_id must equal len(non_blank_vocabs) for TDT models") if self.big_blank_durations is not None and self.big_blank_durations != []: raise ValueError("duration and big_blank_durations can't both be not None") - if self.cfg.strategy not in ['greedy', 'greedy_batch']: - raise ValueError("currently only greedy and greedy_batch inference is supported for TDT models") + if self.cfg.strategy not in ['greedy', 'greedy_batch', 'beam', 'maes']: + raise ValueError( + "currently only greedy, greedy_batch, beam and maes inference is supported for TDT models" + ) if ( self.big_blank_durations is not None and self.big_blank_durations != [] @@ -386,20 +388,32 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int, supported_punctu ) elif self.cfg.strategy == 'beam': - - self.decoding = rnnt_beam_decoding.BeamRNNTInfer( - decoder_model=decoder, - joint_model=joint, - beam_size=self.cfg.beam.beam_size, - return_best_hypothesis=decoding_cfg.beam.get('return_best_hypothesis', True), - search_type='default', - score_norm=self.cfg.beam.get('score_norm', True), - softmax_temperature=self.cfg.beam.get('softmax_temperature', 1.0), - preserve_alignments=self.preserve_alignments, - ) + if self.big_blank_durations is None or self.big_blank_durations == []: + if not self._is_tdt: + self.decoding = rnnt_beam_decoding.BeamRNNTInfer( + decoder_model=decoder, + joint_model=joint, + beam_size=self.cfg.beam.beam_size, + return_best_hypothesis=decoding_cfg.beam.get('return_best_hypothesis', True), + search_type='default', + score_norm=self.cfg.beam.get('score_norm', True), + softmax_temperature=self.cfg.beam.get('softmax_temperature', 1.0), + preserve_alignments=self.preserve_alignments, + ) + else: + self.decoding = 
tdt_beam_decoding.BeamTDTInfer( + decoder_model=decoder, + joint_model=joint, + durations=self.durations, + beam_size=self.cfg.beam.beam_size, + return_best_hypothesis=decoding_cfg.beam.get('return_best_hypothesis', True), + search_type='default', + score_norm=self.cfg.beam.get('score_norm', True), + softmax_temperature=self.cfg.beam.get('softmax_temperature', 1.0), + preserve_alignments=self.preserve_alignments, + ) elif self.cfg.strategy == 'tsd': - self.decoding = rnnt_beam_decoding.BeamRNNTInfer( decoder_model=decoder, joint_model=joint, @@ -413,7 +427,6 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int, supported_punctu ) elif self.cfg.strategy == 'alsd': - self.decoding = rnnt_beam_decoding.BeamRNNTInfer( decoder_model=decoder, joint_model=joint, @@ -427,26 +440,44 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int, supported_punctu ) elif self.cfg.strategy == 'maes': - - self.decoding = rnnt_beam_decoding.BeamRNNTInfer( - decoder_model=decoder, - joint_model=joint, - beam_size=self.cfg.beam.beam_size, - return_best_hypothesis=decoding_cfg.beam.get('return_best_hypothesis', True), - search_type='maes', - score_norm=self.cfg.beam.get('score_norm', True), - maes_num_steps=self.cfg.beam.get('maes_num_steps', 2), - maes_prefix_alpha=self.cfg.beam.get('maes_prefix_alpha', 1), - maes_expansion_gamma=self.cfg.beam.get('maes_expansion_gamma', 2.3), - maes_expansion_beta=self.cfg.beam.get('maes_expansion_beta', 2.0), - softmax_temperature=self.cfg.beam.get('softmax_temperature', 1.0), - preserve_alignments=self.preserve_alignments, - ngram_lm_model=self.cfg.beam.get('ngram_lm_model', None), - ngram_lm_alpha=self.cfg.beam.get('ngram_lm_alpha', 0.0), - hat_subtract_ilm=self.cfg.beam.get('hat_subtract_ilm', False), - hat_ilm_weight=self.cfg.beam.get('hat_ilm_weight', 0.0), - ) - + if self.big_blank_durations is None or self.big_blank_durations == []: + if not self._is_tdt: + self.decoding = rnnt_beam_decoding.BeamRNNTInfer( + decoder_model=decoder, + joint_model=joint, + beam_size=self.cfg.beam.beam_size, + return_best_hypothesis=decoding_cfg.beam.get('return_best_hypothesis', True), + search_type='maes', + score_norm=self.cfg.beam.get('score_norm', True), + maes_num_steps=self.cfg.beam.get('maes_num_steps', 2), + maes_prefix_alpha=self.cfg.beam.get('maes_prefix_alpha', 1), + maes_expansion_gamma=self.cfg.beam.get('maes_expansion_gamma', 2.3), + maes_expansion_beta=self.cfg.beam.get('maes_expansion_beta', 2.0), + softmax_temperature=self.cfg.beam.get('softmax_temperature', 1.0), + preserve_alignments=self.preserve_alignments, + ngram_lm_model=self.cfg.beam.get('ngram_lm_model', None), + ngram_lm_alpha=self.cfg.beam.get('ngram_lm_alpha', 0.0), + hat_subtract_ilm=self.cfg.beam.get('hat_subtract_ilm', False), + hat_ilm_weight=self.cfg.beam.get('hat_ilm_weight', 0.0), + ) + else: + self.decoding = tdt_beam_decoding.BeamTDTInfer( + decoder_model=decoder, + joint_model=joint, + durations=self.durations, + beam_size=self.cfg.beam.beam_size, + return_best_hypothesis=decoding_cfg.beam.get('return_best_hypothesis', True), + search_type='maes', + score_norm=self.cfg.beam.get('score_norm', True), + maes_num_steps=self.cfg.beam.get('maes_num_steps', 2), + maes_prefix_alpha=self.cfg.beam.get('maes_prefix_alpha', 1), + maes_expansion_gamma=self.cfg.beam.get('maes_expansion_gamma', 2.3), + maes_expansion_beta=self.cfg.beam.get('maes_expansion_beta', 2.0), + softmax_temperature=self.cfg.beam.get('softmax_temperature', 1.0), + preserve_alignments=self.preserve_alignments, + 
ngram_lm_model=self.cfg.beam.get('ngram_lm_model', None), + ngram_lm_alpha=self.cfg.beam.get('ngram_lm_alpha', 0.3), + ) else: raise ValueError( @@ -728,6 +759,15 @@ def decode_ids_to_langs(self, tokens: List[int]) -> List[str]: raise NotImplementedError() def update_joint_fused_batch_size(self): + """ " + Updates the fused batch size for the joint module if applicable. + + If `joint_fused_batch_size` is set, verifies that the joint module has + the required `set_fused_batch_size` and `set_fuse_loss_wer` functions. + If present, updates the batch size; otherwise, logs a warning. + + If `joint_fused_batch_size` is <= 0, disables fused batch processing. + """ if self.joint_fused_batch_size is None: # do nothing and let the Joint itself handle setting up of the fused batch return @@ -754,6 +794,21 @@ def update_joint_fused_batch_size(self): self.decoding.joint.set_fuse_loss_wer(False) def compute_rnnt_timestamps(self, hypothesis: Hypothesis, timestamp_type: str = "all"): + """ + Computes character, word, and segment timestamps for an RNN-T hypothesis. + + This function generates timestamps for characters, words, and segments within + a hypothesis sequence. The type of timestamps computed depends on `timestamp_type`, + which can be 'char', 'word', 'segment', or 'all'. + + Args: + hypothesis (Hypothesis): Hypothesis. + timestamp_type (str): Type of timestamps to compute. Options are 'char', 'word', 'segment', or 'all'. + Defaults to 'all'. + + Returns: + Hypothesis: The updated hypothesis with computed timestamps for characters, words, and/or segments. + """ assert timestamp_type in ['char', 'word', 'segment', 'all'] # Unpack the temporary storage @@ -890,7 +945,7 @@ def _compute_offsets( # Construct the start and end indices brackets end_indices = np.asarray(token_repetitions).cumsum() - start_indices = np.concatenate(([int(start_index)], end_indices[:-1])) + start_indices = np.concatenate(([start_index], end_indices[:-1])) # Process the TxU dangling alignment tensor, containing pairs of (logits, label) alignment_labels = [al_logits_labels for al_logits_labels in hypothesis.text[1]] @@ -953,8 +1008,8 @@ def _refine_timestamps_tdt( # Check if token is a punctuation mark # If so, set its start and end offset as start and end of the previous token - # This is done because there was observed a behaviour, when punctuation marks are predicted long - # after preceding token (i.e. after silence) + # This is done because there was observed a behaviour, when punctuation marks are + # predicted long after preceding token (i.e. after silence) if offset['char'][0] in supported_punctuation and i > 0: encoded_char_offsets[i]['start_offset'] = offset['start_offset'] = char_offsets[i - 1]['end_offset'] encoded_char_offsets[i]['end_offset'] = offset['end_offset'] = offset['start_offset'] @@ -1114,7 +1169,8 @@ def _get_segment_offsets( offsets: A list of dictionaries, each containing "word", "start_offset" and "end_offset". segments_delimiter_tokens: List containing tokens representing the seperator(s) between segments. supported_punctuation: Set containing punctuation marks in the vocabulary. - segment_gap_threshold: Number of frames between 2 consecutive words necessary to form segments out of plain text. + segment_gap_threshold: Number of frames between 2 consecutive words necessary to form segments out of plain + text. Returns: A list of dictionaries containing the segment offsets. Each item contains "segment", "start_offset" and "end_offset". 
@@ -1242,9 +1298,10 @@ class RNNTDecoding(AbstractRNNTDecoding): exclude_blank: Bool flag indicating that blank token confidence scores are to be excluded from the `token_confidence`. aggregation: Which aggregation type to use for collapsing per-token confidence into per-word - confidence. Valid options are `mean`, `min`, `max`, `prod`. - tdt_include_duration: Bool flag indicating that the duration confidence scores are to be calculated - and attached to the regular frame confidence, + confidence. + Valid options are `mean`, `min`, `max`, `prod`. + tdt_include_duration: Bool flag indicating that the duration confidence scores are to be calculated and + attached to the regular frame confidence, making TDT frame confidence element a pair: (`prediction_confidence`, `duration_confidence`). method_cfg: A dict-like object which contains the method name and settings to compute per-frame confidence scores. @@ -1331,7 +1388,7 @@ class RNNTDecoding(AbstractRNNTDecoding): and can be reduced to 1 to improve decoding speed while sacrificing some accuracy. int > 0. maes_prefix_alpha: Maximum prefix length in prefix search. Must be an integer, and is advised to - keep this as 1 in order to reduce expensive beam search cost later. int >= 0. + keep this as 1 in order to reduce expensive beam search cost later. int >= 0. maes_expansion_beta: Maximum number of prefix expansions allowed, in addition to the beam size. Effectively, the number of hypothesis = beam_size + maes_expansion_beta. Must be an int >= 0, @@ -1339,8 +1396,7 @@ class RNNTDecoding(AbstractRNNTDecoding): next step. maes_expansion_gamma: Float pruning threshold used in the prune-by-value step when computing the - expansions. - The default (2.3) is selected from the paper. It performs a comparison + expansions. The default (2.3) is selected from the paper. It performs a comparison (max_log_prob - gamma <= log_prob[v]) where v is all vocabulary indices in the Vocab set and max_log_prob is the "most" likely token to be predicted. Gamma therefore provides a margin of additional tokens which can be potential candidates for expansion apart from the "most likely" @@ -1382,7 +1438,9 @@ def __init__( supported_punctuation=supported_punctuation, ) - if isinstance(self.decoding, rnnt_beam_decoding.BeamRNNTInfer): + if isinstance(self.decoding, rnnt_beam_decoding.BeamRNNTInfer) or isinstance( + self.decoding, tdt_beam_decoding.BeamTDTInfer + ): self.decoding.set_decoding_type('char') def _aggregate_token_confidence(self, hypothesis: Hypothesis) -> List[float]: @@ -1498,8 +1556,8 @@ class RNNTBPEDecoding(AbstractRNNTDecoding): segment_seperators: List containing tokens representing the seperator(s) between segments. - segment_gap_threshold: The threshold (in frames) that caps the gap between two words necessary for - forming the segments. + segment_gap_threshold: The threshold (in frames) that caps the gap between two words necessary for forming + the segments. preserve_frame_confidence: Bool flag which preserves the history of per-frame confidence scores generated during decoding (sample / batched). When set to true, the Hypothesis will contain @@ -1530,8 +1588,8 @@ class RNNTBPEDecoding(AbstractRNNTDecoding): from the `token_confidence`. aggregation: Which aggregation type to use for collapsing per-token confidence into per-word confidence. Valid options are `mean`, `min`, `max`, `prod`. 
- tdt_include_duration: Bool flag indicating that the duration confidence scores are to be - calculated and attached to the regular frame confidence, + tdt_include_duration: Bool flag indicating that the duration confidence scores are to be calculated and + attached to the regular frame confidence, making TDT frame confidence element a pair: (`prediction_confidence`, `duration_confidence`). method_cfg: A dict-like object which contains the method name and settings to compute per-frame confidence scores. @@ -1602,7 +1660,7 @@ class RNNTBPEDecoding(AbstractRNNTDecoding): at increased cost to execution time. alsd_max_target_len: optional int or float, determines the potential maximum target sequence - length. If an integer is provided, it can decode sequences of that particular maximum length. + length.If an integer is provided, it can decode sequences of that particular maximum length. If a float is provided, it can decode sequences of int(alsd_max_target_len * seq_len), where seq_len is the length of the acoustic model output (T). @@ -1622,16 +1680,15 @@ class RNNTBPEDecoding(AbstractRNNTDecoding): and affects the speed of inference since large values will perform large beam search in the next step. - maes_expansion_gamma: Float pruning threshold used in the prune-by-value step when - computing the expansions. The default (2.3) is selected from the paper. It performs a - comparison (max_log_prob - gamma <= log_prob[v]) where v is all vocabulary indices in the - Vocab set and max_log_prob is the "most" likely token to be predicted. Gamma therefore - provides a margin of additional tokens which can be potential candidates for expansion - apart from the "most likely" candidate. Lower values will reduce the number of expansions - (by increasing pruning-by-value, thereby improving speed but hurting accuracy). Higher - values will increase the number of expansions (by reducing pruning-by-value, thereby - reducing speed but potentially improving accuracy). This is a hyper parameter to be - experimentally tuned on a validation set. + maes_expansion_gamma: Float pruning threshold used in the prune-by-value step when computing the + expansions. The default (2.3) is selected from the paper. It performs a comparison + (max_log_prob - gamma <= log_prob[v]) where v is all vocabulary indices in the Vocab set and + max_log_prob is the "most" likely token to be predicted. Gamma therefore provides a margin of + additional tokens which can be potential candidates for expansion apart from the "most likely" + candidate. Lower values will reduce the number of expansions (by increasing pruning-by-value, + thereby improving speed but hurting accuracy). Higher values will increase the number of + expansions (by reducing pruning-by-value, thereby reducing speed but potentially improving + accuracy). This is a hyper parameter to be experimentally tuned on a validation set. softmax_temperature: Scales the logits of the joint prior to computing log_softmax. 
@@ -1658,7 +1715,9 @@ def __init__(self, decoding_cfg, decoder, joint, tokenizer: TokenizerSpec): supported_punctuation=supported_punctuation, ) - if isinstance(self.decoding, rnnt_beam_decoding.BeamRNNTInfer): + if isinstance(self.decoding, rnnt_beam_decoding.BeamRNNTInfer) or isinstance( + self.decoding, tdt_beam_decoding.BeamTDTInfer + ): self.decoding.set_decoding_type('subword') def _aggregate_token_confidence(self, hypothesis: Hypothesis) -> List[float]: @@ -1759,8 +1818,8 @@ def decode_hypothesis(self, hypotheses_list: List[Hypothesis]) -> List[Union[Hyp hypotheses[ind].langs_chars = self.decode_ids_to_langs(prediction) else: logging.warning( - "Ignoring request for lang output in hypotheses since the model does not use an aggregate\ - tokenizer" + "Ignoring request for lang output in hypotheses since the model does not use an aggregate \ + tokenizer" ) return hypotheses @@ -1768,6 +1827,10 @@ def decode_hypothesis(self, hypotheses_list: List[Hypothesis]) -> List[Union[Hyp @dataclass class RNNTDecodingConfig: + """ + RNNT Decoding config + """ + model_type: str = "rnnt" # one of "rnnt", "multiblank" or "tdt" strategy: str = "greedy_batch" @@ -1825,4 +1888,8 @@ class RNNTDecodingConfig: @dataclass class RNNTBPEDecodingConfig(RNNTDecodingConfig): + """ + RNNT BPE Decoding Config + """ + pass diff --git a/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py index f9cf368fe405..bd169d0d224e 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py +++ b/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py @@ -49,7 +49,20 @@ def pack_hypotheses( hypotheses: List[rnnt_utils.Hypothesis], logitlen: torch.Tensor, ) -> List[rnnt_utils.Hypothesis]: + """ + Packs a list of hypotheses into a tensor and prepares decoder states. + + This function takes a list of token sequences (hypotheses) and converts + it into a tensor format. If any decoder states are on the GPU, they + are moved to the CPU. Additionally, the function removes any timesteps + with a value of -1 from the sequences. + + Args: + hypotheses (list): A list of token sequences representing hypotheses. + Returns: + list: A list of packed hypotheses in tensor format. + """ if hasattr(logitlen, 'cpu'): logitlen_cpu = logitlen.to('cpu') else: @@ -578,7 +591,8 @@ class GreedyBatchedRNNTInfer(_GreedyRNNTInfer, WithOptionalCudaGraphs): (evaluating Joint multiple times in inner loop); It uses a minimal possible amount of calls to prediction network (with maximum possible batch size), which makes it especially useful for scaling the prediction network. 
- use_cuda_graph_decoder: if CUDA graphs should be enabled for decoding (currently recommended only for inference) + use_cuda_graph_decoder: if CUDA graphs should be enabled for decoding + (currently recommended only for inference) """ def __init__( @@ -1169,6 +1183,10 @@ def _greedy_decode_masked( class ExportedModelGreedyBatchedRNNTInfer: + """ + Exported Model Greedy Batched RNNT Infer class + """ + def __init__(self, encoder_model: str, decoder_joint_model: str, max_symbols_per_step: Optional[int] = None): self.encoder_model_path = encoder_model self.decoder_joint_model_path = decoder_joint_model @@ -1344,9 +1362,25 @@ def _setup_blank_index(self): raise NotImplementedError() def run_encoder(self, audio_signal, length): + """ + Runs encoder network: + + Args: + audio_signal: audio signal + length: audio length + """ raise NotImplementedError() def run_decoder_joint(self, enc_logits, targets, target_length, *states): + """ + Runs decoder joint networks. + + Args: + enc_logits: encoder logits + targets: targets + target_length: target length + states: states + """ raise NotImplementedError() def _get_initial_states(self, batchsize): @@ -1354,6 +1388,10 @@ def _get_initial_states(self, batchsize): class ONNXGreedyBatchedRNNTInfer(ExportedModelGreedyBatchedRNNTInfer): + """ + ONNX Greedy Batched RNNT Infer class + """ + def __init__(self, encoder_model: str, decoder_joint_model: str, max_symbols_per_step: Optional[int] = 10): super().__init__( encoder_model=encoder_model, @@ -1433,7 +1471,8 @@ def _setup_blank_index(self): self._blank_index = log_probs.shape[-1] - 1 # last token of vocab size is blank token logging.info( - f"Enc-Dec-Joint step was evaluated, blank token id = {self._blank_index}; vocab size = {log_probs.shape[-1]}" + f"Enc-Dec-Joint step was evaluated, \ + blank token id = {self._blank_index}; vocab size = {log_probs.shape[-1]}" ) def run_encoder(self, audio_signal, length): @@ -1512,6 +1551,10 @@ def _get_initial_states(self, batchsize): class TorchscriptGreedyBatchedRNNTInfer(ExportedModelGreedyBatchedRNNTInfer): + """ + Torchscript Greedy Batched RNNT Infer + """ + def __init__( self, encoder_model: str, @@ -2336,6 +2379,8 @@ def _greedy_decode_masked( @dataclass class GreedyRNNTInferConfig: + """Greedy RNNT Infer Config""" + max_symbols_per_step: Optional[int] = 10 preserve_alignments: bool = False preserve_frame_confidence: bool = False @@ -2354,6 +2399,8 @@ def __post_init__(self): @dataclass class GreedyBatchedRNNTInferConfig: + """Greedy Batched RNNT Infer Config""" + max_symbols_per_step: Optional[int] = 10 preserve_alignments: bool = False preserve_frame_confidence: bool = False @@ -2708,7 +2755,8 @@ class GreedyBatchedTDTInfer(_GreedyRNNTInfer, WithOptionalCudaGraphs): - 'lin' for using the linear mapping. - 'exp' for using exponential mapping with linear shift. - use_cuda_graph_decoder: if CUDA graphs should be enabled for decoding (currently recommended only for inference) + use_cuda_graph_decoder: if CUDA graphs should be enabled for decoding + (currently recommended only for inference) """ def __init__( diff --git a/nemo/collections/asr/parts/submodules/tdt_beam_decoding.py b/nemo/collections/asr/parts/submodules/tdt_beam_decoding.py new file mode 100644 index 000000000000..908fc1c13d19 --- /dev/null +++ b/nemo/collections/asr/parts/submodules/tdt_beam_decoding.py @@ -0,0 +1,800 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright 2017 Johns Hopkins University (Shinji Watanabe) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional, Tuple + +import numpy as np +import torch +from tqdm import tqdm + +from nemo.collections.asr.modules import rnnt_abstract +from nemo.collections.asr.parts.submodules.rnnt_beam_decoding import pack_hypotheses +from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis, NBestHypotheses, is_prefix +from nemo.core.classes import Typing, typecheck +from nemo.core.neural_types import AcousticEncodedRepresentation, HypothesisType, LengthsType, NeuralType +from nemo.utils import logging + +try: + import kenlm + + KENLM_AVAILABLE = True +except (ImportError, ModuleNotFoundError): + KENLM_AVAILABLE = False + + +class BeamTDTInfer(Typing): + """ + Beam search implementation for Token-and-Duration Transducer (TDT) models. + + Sequence level beam decoding or batched-beam decoding, performed auto-regressively + depending on the search type chosen. + + Args: + decoder_model: rnnt_utils.AbstractRNNTDecoder implementation. + joint_model: rnnt_utils.AbstractRNNTJoint implementation. + durations: list of duration values from TDT model. + + beam_size: number of beams for beam search. Must be a positive integer >= 1. + If beam size is 1, defaults to stateful greedy search. + For accurate greedy results, please use GreedyTDTInfer or GreedyBatchedTDTInfer. + + search_type: str representing the type of beam search to perform. + Must be one of ['default', 'maes']. + + Algorithm used: + + `default` - basic beam search strategy. Larger beams generally result in better decoding, + however the time required for the search also grows steadily. + + `maes` = modified adaptive expansion search. Please refer to the paper: + [Accelerating RNN Transducer Inference via Adaptive Expansion Search] + (https://ieeexplore.ieee.org/document/9250505) + + Modified Adaptive Synchronous Decoding (mAES) execution time is adaptive w.r.t the + number of expansions (for tokens) required per timestep. The number of expansions can usually + be constrained to 1 or 2, and in most cases 2 is sufficient. + + This beam search technique can possibly obtain superior WER while sacrificing some evaluation time. + + score_norm: bool, whether to normalize the scores of the log probabilities. + + return_best_hypothesis: bool, decides whether to return a single hypothesis (the best out of N), + or return all N hypotheses (sorted with best score first).
The container class changes based + this flag - + When set to True (default), returns a single Hypothesis. + When set to False, returns a NBestHypotheses container, which contains a list of Hypothesis. + + # The following arguments are specific to the chosen `search_type` + + # mAES flags + maes_num_steps: Number of adaptive steps to take. From the paper, 2 steps is generally sufficient. int > 1. + + maes_prefix_alpha: Maximum prefix length in prefix search. Must be an integer, and is advised to keep this as 1 + in order to reduce expensive beam search cost later. int >= 0. + + maes_expansion_beta: Maximum number of prefix expansions allowed, in addition to the beam size. + Effectively, the number of hypothesis = beam_size + maes_expansion_beta. Must be an int >= 0, + and affects the speed of inference since large values will perform large beam search in the next step. + + maes_expansion_gamma: Float pruning threshold used in the prune-by-value step when computing the expansions. + The default (2.3) is selected from the paper. It performs a comparison + (max_log_prob - gamma <= log_prob[v]) where v is all vocabulary indices in the Vocab set and max_log_prob + is the "most" likely token to be predicted. Gamma therefore provides a margin of additional tokens which + can be potential candidates for expansion apart from the "most likely" candidate. + Lower values will reduce the number of expansions (by increasing pruning-by-value, thereby improving speed + but hurting accuracy). Higher values will increase the number of expansions (by reducing pruning-by-value, + thereby reducing speed but potentially improving accuracy). This is a hyper parameter to be experimentally + tuned on a validation set. + + softmax_temperature: Scales the logits of the joint prior to computing log_softmax. + + preserve_alignments: Bool flag which preserves the history of alignments generated during + beam decoding (sample). When set to true, the Hypothesis will contain + the non-null value for `alignments` in it. Here, `alignments` is a List of List of Tensor (of length V + 1) + + The length of the list corresponds to the Acoustic Length (T). + Each value in the list (Ti) is a torch.Tensor (U), representing 1 or more targets from a vocabulary. + U is the number of target tokens for the current timestep Ti. + + NOTE: `preserve_alignments` is an invalid argument for any `search_type` + other than basic beam search. + + ngram_lm_model: str + The path to the N-gram LM. + ngram_lm_alpha: float + Alpha weight of N-gram LM. 
+ """ + + @property + def input_types(self): + """Returns definitions of module input ports.""" + return { + "encoder_output": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()), + "encoded_lengths": NeuralType(tuple('B'), LengthsType()), + "partial_hypotheses": [NeuralType(elements_type=HypothesisType(), optional=True)], # must always be last + } + + @property + def output_types(self): + """Returns definitions of module output ports.""" + return {"predictions": [NeuralType(elements_type=HypothesisType())]} + + def __init__( + self, + decoder_model: rnnt_abstract.AbstractRNNTDecoder, + joint_model: rnnt_abstract.AbstractRNNTJoint, + durations: list, + beam_size: int, + search_type: str = 'default', + score_norm: bool = True, + return_best_hypothesis: bool = True, + maes_num_steps: int = 2, + maes_prefix_alpha: int = 1, + maes_expansion_gamma: float = 2.3, + maes_expansion_beta: int = 2, + softmax_temperature: float = 1.0, + preserve_alignments: bool = False, + ngram_lm_model: Optional[str] = None, + ngram_lm_alpha: float = 0.3, + ): + self.joint = joint_model + self.decoder = decoder_model + self.durations = durations + + self.token_offset = 0 + self.search_type = search_type + self.blank = decoder_model.blank_idx + self.vocab_size = decoder_model.vocab_size + self.return_best_hypothesis = return_best_hypothesis + + self.beam_size = beam_size + self.score_norm = score_norm + self.max_candidates = beam_size + self.softmax_temperature = softmax_temperature + self.preserve_alignments = preserve_alignments + + if preserve_alignments: + raise ValueError("Alignment preservation has not been implemented.") + if beam_size < 1: + raise ValueError("Beam search size cannot be less than 1!") + + if self.preserve_alignments: + raise NotImplementedError("Preserving alignments is not implemented.") + + if search_type == "default": + if self.beam_size == 1: + logging.info( + """If beam size is 1, defaults to stateful greedy search. + For accurate greedy results, please use GreedyTDTInfer or GreedyBatchedTDTInfer.""" + ) + self.search_algorithm = self.default_beam_search + elif search_type == "tsd": + raise NotImplementedError("`tsd` (Time Synchronous Decoding) has not been implemented.") + elif search_type == "alsd": + raise NotImplementedError("`alsd` (Alignment Length Synchronous Decoding) has not been implemented.") + elif search_type == "nsc": + raise NotImplementedError("`nsc` (Constrained Beam Search) has not been implemented.") + elif search_type == "maes": + self.search_algorithm = self.modified_adaptive_expansion_search + else: + raise NotImplementedError( + f"The search type ({search_type}) supplied is not supported!\n" f"Please use one of : (default, maes)" + ) + + if self.search_type == 'maes': + self.maes_num_steps = int(maes_num_steps) + self.maes_prefix_alpha = int(maes_prefix_alpha) + self.maes_expansion_beta = int(maes_expansion_beta) + self.maes_expansion_gamma = float(maes_expansion_gamma) + + self.max_candidates += maes_expansion_beta + + if self.maes_prefix_alpha < 0: + raise ValueError("`maes_prefix_alpha` must be a positive integer.") + + if self.vocab_size < beam_size + maes_expansion_beta: + raise ValueError( + f"beam_size ({beam_size}) + expansion_beta ({maes_expansion_beta}) " + f"should be smaller or equal to vocabulary size ({self.vocab_size})." 
+ ) + + if self.maes_num_steps < 1: + raise ValueError("`maes_num_steps` must be greater than 0.") + + try: + self.zero_duration_idx = self.durations.index(0) + except ValueError: + self.zero_duration_idx = None + self.min_non_zero_duration_idx = int( + np.argmin(np.ma.masked_where(np.array(self.durations) == 0, self.durations)) + ) + + if ngram_lm_model: + if search_type != "maes": + raise ValueError("For decoding with a language model, the `maes` decoding strategy must be chosen.") + + if KENLM_AVAILABLE: + self.ngram_lm = kenlm.Model(ngram_lm_model) + self.ngram_lm_alpha = ngram_lm_alpha + else: + raise ImportError( + "KenLM package (https://github.com/kpu/kenlm) is not installed. " "Use ngram_lm_model=None." + ) + else: + self.ngram_lm = None + + @typecheck() + def __call__( + self, + encoder_output: torch.Tensor, + encoded_lengths: torch.Tensor, + partial_hypotheses: tuple[list[Hypothesis | NBestHypotheses],] = None, + ) -> tuple[list[Hypothesis | NBestHypotheses],]: + """Perform general beam search. + + Args: + encoder_output: encoder outputs (batch, features, timesteps). + encoded_lengths: lengths of the encoder outputs. + + Returns: + Either a list containing a single Hypothesis (when `return_best_hypothesis=True`), + otherwise a list containing a single NBestHypotheses, which itself contains a list of + Hypothesis. This list is sorted such that the best hypothesis is the first element. + """ + # Preserve decoder and joint training state + decoder_training_state = self.decoder.training + joint_training_state = self.joint.training + + with torch.inference_mode(): + # Apply optional preprocessing + encoder_output = encoder_output.transpose(1, 2) # (B, T, D) + + self.decoder.eval() + self.joint.eval() + + hypotheses = [] + with tqdm( + range(encoder_output.size(0)), + desc='Beam search progress:', + total=encoder_output.size(0), + unit='sample', + ) as idx_gen: + + _p = next(self.joint.parameters()) + dtype = _p.dtype + + # Decode every sample in the batch independently. + for batch_idx in idx_gen: + inseq = encoder_output[batch_idx : batch_idx + 1, : encoded_lengths[batch_idx], :] # [1, T, D] + logitlen = encoded_lengths[batch_idx] + + if inseq.dtype != dtype: + inseq = inseq.to(dtype=dtype) + + # Extract partial hypothesis if exists + partial_hypothesis = partial_hypotheses[batch_idx] if partial_hypotheses is not None else None + + # Execute the specific search strategy + nbest_hyps = self.search_algorithm( + inseq, logitlen, partial_hypotheses=partial_hypothesis + ) # sorted list of hypotheses + + # Prepare the list of hypotheses + nbest_hyps = pack_hypotheses(nbest_hyps) + + # Pack the result + if self.return_best_hypothesis: + best_hypothesis: Hypothesis = nbest_hyps[0] + else: + best_hypothesis: NBestHypotheses = NBestHypotheses(nbest_hyps) + hypotheses.append(best_hypothesis) + + self.decoder.train(decoder_training_state) + self.joint.train(joint_training_state) + + return (hypotheses,) + + def default_beam_search( + self, + encoder_outputs: torch.Tensor, + encoded_lengths: torch.Tensor, + partial_hypotheses: Optional[Hypothesis] = None, + ) -> List[Hypothesis]: + """Default Beam search implementation for TDT models. + + Args: + encoder_outputs: encoder outputs (batch, features, timesteps). + encoded_lengths: lengths of the encoder outputs. + partial_hypotheses: partial hypotheses.
+ + Returns: + nbest_hyps: N-best decoding results + """ + if partial_hypotheses is not None: + raise NotImplementedError("Support for `partial_hypotheses` is not implemented.") + + beam = min(self.beam_size, self.vocab_size) + beam_k = min(beam, (self.vocab_size - 1)) + durations_beam_k = min(beam, len(self.durations)) + + # Initialize zero vector states. + decoder_state = self.decoder.initialize_state(encoder_outputs) + # Cache decoder results to avoid duplicate computations. + cache = {} + + # Initialize hypothesis array with blank hypothesis. + start_hyp = Hypothesis( + score=0.0, y_sequence=[self.blank], dec_state=decoder_state, timestep=[-1], length=0, last_frame=0 + ) + kept_hyps = [start_hyp] + + for time_idx in range(int(encoded_lengths)): + # Retrieve hypotheses for current and future frames + hyps = [hyp for hyp in kept_hyps if hyp.last_frame == time_idx] # hypotheses for current frame + kept_hyps = [hyp for hyp in kept_hyps if hyp.last_frame > time_idx] # hypotheses for future frames + + # Loop over hypotheses of current frame + while len(hyps) > 0: + max_hyp = max(hyps, key=lambda x: x.score) + hyps.remove(max_hyp) + + # Update decoder state and get probability distribution over vocabulary and durations. + encoder_output = encoder_outputs[:, time_idx : time_idx + 1, :] # [1, 1, D] + decoder_output, decoder_state, _ = self.decoder.score_hypothesis(max_hyp, cache) # [1, 1, D] + logits = ( + self.joint.joint(encoder_output, decoder_output) / self.softmax_temperature + ) # [1, 1, 1, V + NUM_DURATIONS + 1] + logp = torch.log_softmax(logits[0, 0, 0, : -len(self.durations)], dim=-1) # [V + 1] + durations_logp = torch.log_softmax(logits[0, 0, 0, -len(self.durations) :], dim=-1) # [NUM_DURATIONS] + + # Process non-blank tokens + # Retrieve the top `beam_k` most probable tokens and the top `durations_beam_k` most probable durations. + # Then, select the top `beam_k` pairs of (token, duration) based on the highest combined probabilities. + # Note that indices are obtained in the flattened array.
+ logp_topks, logp_topk_idxs = logp[:-1].topk(beam_k, dim=-1) # topk of tokens without blank token + durations_logp_topks, durations_logp_topk_idxs = durations_logp.topk(durations_beam_k, dim=-1) + total_logp_topks, total_logp_topk_idxs = ( + torch.cartesian_prod(durations_logp_topks, logp_topks).sum(dim=-1).topk(beam_k, dim=-1) + ) + + # Loop over pairs of (token, duration) with highest combined log prob + for total_logp_topk, total_logp_topk_idx in zip(total_logp_topks, total_logp_topk_idxs): + # Restore indices from flattened array indices + token_idx = int(logp_topk_idxs[total_logp_topk_idx % beam_k]) + duration_idx = int(durations_logp_topk_idxs[total_logp_topk_idx // beam_k]) + + duration = self.durations[duration_idx] + # Construct hypothesis for non-blank token + new_hyp = Hypothesis( + score=float(max_hyp.score + total_logp_topk), # update score + y_sequence=max_hyp.y_sequence + [token_idx], # update hypothesis sequence + dec_state=decoder_state, # update decoder state + timestep=max_hyp.timestep + [time_idx + duration], # update timesteps + length=encoded_lengths, + last_frame=max_hyp.last_frame + duration, + ) # update frame idx where last token appeared + + # Update current frame hypotheses if duration is zero and future frame hypotheses otherwise + if duration == 0: + hyps.append(new_hyp) + else: + kept_hyps.append(new_hyp) + + # Update future frames with blank tokens + # Note: blank token can have only non-zero duration + for duration_idx in durations_logp_topk_idxs: + duration_idx = int(duration_idx) + # If zero is the only duration in topk, switch to closest non-zero duration to continue + if duration_idx == self.zero_duration_idx: + if durations_logp_topk_idxs.shape[0] == 1: + duration_idx = self.min_non_zero_duration_idx + else: + continue + + duration = self.durations[duration_idx] + new_hyp = Hypothesis( + score=float(max_hyp.score + logp[self.blank] + durations_logp[duration_idx]), # update score + y_sequence=max_hyp.y_sequence[:], # no need to update sequence + dec_state=max_hyp.dec_state, # no need to update decoder state + timestep=max_hyp.timestep[:], # no need to update timesteps + length=encoded_lengths, + last_frame=max_hyp.last_frame + duration, + ) # update frame idx where last token appeared + kept_hyps.append(new_hyp) + + # Merge duplicate hypotheses. + # If two consecutive blank tokens are predicted and their duration values sum up to the same number, + # it will produce two hypotheses with the same token sequence but different scores. + kept_hyps = self.merge_duplicate_hypotheses(kept_hyps) + + if len(hyps) > 0: + # Keep those hypotheses that have scores greater than next search generation + hyps_max = float(max(hyps, key=lambda x: x.score).score) + kept_most_prob = sorted( + [hyp for hyp in kept_hyps if hyp.score > hyps_max], + key=lambda x: x.score, + ) + # If enough hypotheses have scores greater than next search generation, + # stop beam search. + if len(kept_most_prob) >= beam: + kept_hyps = kept_most_prob + break + else: + # If there are no hypotheses in the current frame, + # keep only `beam` best hypotheses for the next search generation. + kept_hyps = sorted(kept_hyps, key=lambda x: x.score, reverse=True)[:beam] + return self.sort_nbest(kept_hyps) + + def modified_adaptive_expansion_search( + self, + encoder_outputs: torch.Tensor, + encoded_lengths: torch.Tensor, + partial_hypotheses: Optional[Hypothesis] = None, + ) -> List[Hypothesis]: + """ + Modified Adaptive Expansion Search algorithm for TDT models.
+        Based on/modified from https://ieeexplore.ieee.org/document/9250505.
+        Supports N-gram language model shallow fusion.
+
+        Args:
+            encoder_outputs: encoder outputs (batch, features, timesteps).
+            encoded_lengths: lengths of the encoder outputs.
+            partial_hypotheses: partial hypotheses.
+
+        Returns:
+            nbest_hyps: N-best decoding results
+        """
+        if partial_hypotheses is not None:
+            raise NotImplementedError("Support for `partial_hypotheses` is not implemented.")
+
+        beam = min(self.beam_size, self.vocab_size)
+        beam_state = self.decoder.initialize_state(
+            torch.zeros(1, device=encoder_outputs.device, dtype=encoder_outputs.dtype)
+        )  # [L, B, H], [L, B, H] for LSTMs
+
+        # Initialize first hypothesis for the beam (blank).
+        start_hyp = Hypothesis(
+            y_sequence=[self.blank],
+            score=0.0,
+            dec_state=self.decoder.batch_select_state(beam_state, 0),
+            timestep=[-1],
+            length=0,
+            last_frame=0,
+        )
+        init_tokens = [start_hyp]
+
+        # Cache decoder results to avoid duplicate computations.
+        cache = {}
+
+        # Decode a batch of beam states and scores
+        beam_decoder_output, beam_state = self.decoder.batch_score_hypothesis(init_tokens, cache)
+        state = beam_state[0]
+
+        # Initialize first hypothesis for the beam (blank) for kept hypotheses
+        start_hyp_kept = Hypothesis(
+            y_sequence=[self.blank],
+            score=0.0,
+            dec_state=state,
+            dec_out=[beam_decoder_output[0]],
+            timestep=[-1],
+            length=0,
+            last_frame=0,
+        )
+
+        kept_hyps = [start_hyp_kept]
+
+        # Setup ngram LM:
+        if self.ngram_lm:
+            init_lm_state = kenlm.State()
+            self.ngram_lm.BeginSentenceWrite(init_lm_state)
+            start_hyp_kept.ngram_lm_state = init_lm_state
+
+        for time_idx in range(encoded_lengths):
+            # Select current iteration hypotheses
+            hyps = [x for x in kept_hyps if x.last_frame == time_idx]
+            kept_hyps = [x for x in kept_hyps if x.last_frame > time_idx]
+
+            if len(hyps) == 0:
+                continue
+
+            beam_encoder_output = encoder_outputs[:, time_idx : time_idx + 1]  # [1, 1, D]
+            # Perform prefix search to update hypothesis scores.
+            if self.zero_duration_idx is not None:
+                hyps = self.prefix_search(
+                    sorted(hyps, key=lambda x: len(x.y_sequence), reverse=True),
+                    beam_encoder_output,
+                    prefix_alpha=self.maes_prefix_alpha,
+                )
+
+            list_b = []  # List that contains the blank token emissions
+            list_nb = []  # List that contains the non-zero duration non-blank token emissions
+            # Repeat for number of mAES steps
+            for n in range(self.maes_num_steps):
+                # Pack the decoder logits for all current hypotheses
+                beam_decoder_output = torch.stack([h.dec_out[-1] for h in hyps])  # [H, 1, D]
+
+                # Extract the log probabilities
+                beam_logits = self.joint.joint(beam_encoder_output, beam_decoder_output) / self.softmax_temperature
+                beam_logp = torch.log_softmax(beam_logits[:, 0, 0, : -len(self.durations)], dim=-1)
+                beam_duration_logp = torch.log_softmax(beam_logits[:, 0, 0, -len(self.durations) :], dim=-1)
+
+                # Retrieve the top `max_candidates` most probable tokens.
+                # Then, select the top `max_candidates` pairs of (token, duration)
+                # based on the highest combined probabilities.
+                # Note that indices are obtained in the flattened array.
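+                # Illustrative example: a flattened index `idx` maps back to the token
+                # beam_idx_topks[hyp_idx][idx % max_candidates] and the duration
+                # durations[idx // max_candidates], which is how indices are restored below.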
+                beam_logp_topks, beam_idx_topks = beam_logp.topk(self.max_candidates, dim=-1)
+                beam_total_logp = (beam_duration_logp[:, :, None] + beam_logp_topks[:, None, :]).view(
+                    len(hyps), -1
+                )  # [B, MAX_CANDIDATES*DURATION_BEAM]
+                beam_total_logp_topks, beam_total_logp_topk_idxs = beam_total_logp.topk(
+                    self.max_candidates, dim=-1
+                )  # [B, MAX_CANDIDATES]
+
+                # Prune hypotheses to obtain k expansions
+                beam_best_expansion_scores = beam_total_logp_topks.max(dim=-1, keepdim=True).values
+                beam_masks = beam_total_logp_topks >= beam_best_expansion_scores - self.maes_expansion_gamma
+                beam_kexpansions_idxs = [
+                    sum_logp_topk_idxs[mask] for sum_logp_topk_idxs, mask in zip(beam_total_logp_topk_idxs, beam_masks)
+                ]
+
+                list_exp = []  # List that contains the zero-duration non-blank token expansions
+                list_nb_exp = []  # List that contains the non-zero duration non-blank token expansions
+                for hyp_idx, hyp in enumerate(hyps):  # For each hypothesis
+                    for idx in beam_kexpansions_idxs[hyp_idx]:  # For all expansions within this hypothesis
+                        # Restore indices in logp and durations_logp arrays from flattened indices.
+                        k = int(beam_idx_topks[hyp_idx][idx % self.max_candidates])
+                        duration = self.durations[int(idx // self.max_candidates)]
+                        total_logp = float(beam_total_logp[hyp_idx][idx])
+
+                        # Forcing blank token to have non-zero duration
+                        if k == self.blank and duration == 0:
+                            duration = self.durations[self.min_non_zero_duration_idx]
+
+                        new_hyp = Hypothesis(
+                            score=hyp.score + total_logp,
+                            y_sequence=hyp.y_sequence[:],
+                            dec_out=hyp.dec_out[:],
+                            dec_state=hyp.dec_state,
+                            timestep=hyp.timestep[:],
+                            length=time_idx,
+                            last_frame=hyp.last_frame + duration,
+                        )
+
+                        if self.ngram_lm:
+                            new_hyp.ngram_lm_state = hyp.ngram_lm_state
+
+                        # If the expansion was for blank
+                        if k == self.blank:
+                            list_b.append(new_hyp)
+                        else:
+                            new_hyp.y_sequence.append(k)
+                            new_hyp.timestep.append(time_idx + duration)
+
+                            if self.ngram_lm:
+                                lm_score, new_hyp.ngram_lm_state = self.compute_ngram_score(hyp.ngram_lm_state, int(k))
+                                new_hyp.score += self.ngram_lm_alpha * lm_score
+
+                            # If token duration is 0, add to expansions list
+                            if duration == 0:
+                                list_exp.append(new_hyp)
+                            else:
+                                list_nb_exp.append(new_hyp)
+
+                # Update states for hypotheses that do not end with blank
+                hyps_to_update = list_nb_exp + list_exp
+                if len(hyps_to_update) > 0:
+                    # Decode a batch of beam states and scores
+                    beam_decoder_output, beam_state = self.decoder.batch_score_hypothesis(
+                        hyps_to_update,
+                        cache,
+                    )
+                    for hyp_idx, hyp in enumerate(hyps_to_update):
+                        # Preserve the decoder logits for the current beam
+                        hyp.dec_out.append(beam_decoder_output[hyp_idx])
+                        hyp.dec_state = beam_state[hyp_idx]
+
+                # If there were no token expansions in any of the hypotheses,
+                # perform an early exit
+                list_nb += list_nb_exp
+                if not list_exp:
+                    kept_hyps = kept_hyps + list_b + list_nb
+                    kept_hyps = self.merge_duplicate_hypotheses(kept_hyps)
+                    kept_hyps = sorted(kept_hyps, key=lambda x: x.score, reverse=True)[:beam]
+
+                    break
+                else:
+                    # If this isn't the last mAES step
+                    if n < (self.maes_num_steps - 1):
+                        # Copy the expanded hypotheses for the next iteration
+                        hyps = self.merge_duplicate_hypotheses(list_exp)
+                    else:
+                        # If this is the last mAES step, add the probability of the blank token to each expansion.
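+                        # In other words, every remaining zero-duration expansion is closed with a forced
+                        # blank emission of non-zero duration, so it advances to a future frame.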
+                        # Extract the log probabilities
+                        beam_decoder_output = torch.stack([h.dec_out[-1] for h in list_exp])  # [H, 1, D]
+                        beam_logits = (
+                            self.joint.joint(beam_encoder_output, beam_decoder_output) / self.softmax_temperature
+                        )
+                        beam_logp = torch.log_softmax(beam_logits[:, 0, 0, : -len(self.durations)], dim=-1)
+
+                        # Get most probable durations
+                        beam_duration_logp = torch.log_softmax(beam_logits[:, 0, 0, -len(self.durations) :], dim=-1)
+                        _, beam_max_duration_idx = torch.max(beam_duration_logp, dim=-1)
+
+                        # For all expansions, add the score for the blank label
+                        for hyp_idx, hyp in enumerate(list_exp):
+                            # If zero duration was obtained, change to the closest non-zero duration
+                            duration_idx = int(beam_max_duration_idx[hyp_idx])
+                            if duration_idx == self.zero_duration_idx:
+                                duration_idx = self.min_non_zero_duration_idx
+
+                            total_logp = float(
+                                beam_logp[hyp_idx, self.blank] + beam_duration_logp[hyp_idx, duration_idx]
+                            )
+                            hyp.score += total_logp
+                            hyp.last_frame += self.durations[duration_idx]
+
+                        # Finally, update the kept hypotheses with the sorted top beam candidates
+                        kept_hyps = kept_hyps + list_b + list_exp + list_nb
+                        kept_hyps = self.merge_duplicate_hypotheses(kept_hyps)
+                        kept_hyps = sorted(kept_hyps, key=lambda x: x.score, reverse=True)[:beam]
+
+        # Sort the hypotheses by score
+        return self.sort_nbest(kept_hyps)
+
+    def merge_duplicate_hypotheses(self, hypotheses):
+        """
+        Merges hypotheses with identical token sequences and last frames.
+        The probability of the combined hypothesis is the sum of the probabilities of all duplicates.
+        Duplicate hypotheses occur when two consecutive blank tokens are predicted
+        and their duration values sum up to the same number.
+
+        Args:
+            hypotheses: list of hypotheses.
+
+        Returns:
+            hypotheses: list of hypotheses without duplicates.
+        """
+        sorted_hyps = sorted(hypotheses, key=lambda x: x.score, reverse=True)
+        kept_hyps = {}
+        for hyp in sorted_hyps:
+            hyp_key = (tuple(hyp.y_sequence), int(hyp.last_frame))
+            if hyp_key in kept_hyps:
+                kept_hyp = kept_hyps[hyp_key]
+                kept_hyp.score = float(torch.logaddexp(torch.tensor(kept_hyp.score), torch.tensor(hyp.score)))
+            else:
+                kept_hyps[hyp_key] = hyp
+        return list(kept_hyps.values())
+
+    def set_decoding_type(self, decoding_type: str):
+        """
+        Sets decoding type. Please check train_kenlm.py in scripts/asr_language_modeling/ to find out why this is needed.
+        Args:
+            decoding_type: decoding type
+        """
+        # TOKEN_OFFSET for BPE-based models
+        if decoding_type == 'subword':
+            from nemo.collections.asr.parts.submodules.ctc_beam_decoding import DEFAULT_TOKEN_OFFSET
+
+            self.token_offset = DEFAULT_TOKEN_OFFSET
+
+    def prefix_search(
+        self, hypotheses: List[Hypothesis], encoder_output: torch.Tensor, prefix_alpha: int
+    ) -> List[Hypothesis]:
+        """
+        Performs a prefix search and updates the scores of the hypotheses in place.
+        Based on https://arxiv.org/pdf/1211.3711.pdf.
+
+        Args:
+            hypotheses: a list of hypotheses sorted by the length from the longest to the shortest.
+            encoder_output: encoder output.
+            prefix_alpha: maximum allowable length difference between a hypothesis and a prefix.
+
+        Returns:
+            hypotheses: list of hypotheses with updated scores.
+        """
+        # Iterate over hypotheses.
+        for curr_idx, curr_hyp in enumerate(hypotheses[:-1]):
+            # For each hypothesis, iterate over the subsequent hypotheses.
+            # If a hypothesis is a prefix of the current one, update the current score.
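+            # Example: if [A, B] is a prefix of [A, B, C, D] within `prefix_alpha`, the probability of
+            # emitting C and D from the prefix at this frame with zero duration is added
+            # to the longer hypothesis's score via logaddexp.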
+            for pref_hyp in hypotheses[(curr_idx + 1) :]:
+                curr_hyp_length = len(curr_hyp.y_sequence)
+                pref_hyp_length = len(pref_hyp.y_sequence)
+
+                if (
+                    is_prefix(curr_hyp.y_sequence, pref_hyp.y_sequence)
+                    and (curr_hyp_length - pref_hyp_length) <= prefix_alpha
+                ):
+                    # Compute the score of the first token
+                    # that follows the prefix hypothesis tokens in the current hypothesis.
+                    # Use the decoder output, which is stored in the prefix hypothesis.
+                    logits = self.joint.joint(encoder_output, pref_hyp.dec_out[-1]) / self.softmax_temperature
+                    logp = torch.log_softmax(logits[0, 0, 0, : -len(self.durations)], dim=-1)
+                    duration_logp = torch.log_softmax(logits[0, 0, 0, -len(self.durations) :], dim=-1)
+                    curr_score = pref_hyp.score + float(
+                        logp[curr_hyp.y_sequence[pref_hyp_length]] + duration_logp[self.zero_duration_idx]
+                    )
+
+                    if self.ngram_lm:
+                        lm_score, next_state = self.compute_ngram_score(
+                            pref_hyp.ngram_lm_state, int(curr_hyp.y_sequence[pref_hyp_length])
+                        )
+                        curr_score += self.ngram_lm_alpha * lm_score
+
+                    for k in range(pref_hyp_length, (curr_hyp_length - 1)):
+                        # Compute the score of the next token.
+                        # Approximate the decoder output with the one that is stored in the current hypothesis.
+                        logits = self.joint.joint(encoder_output, curr_hyp.dec_out[k]) / self.softmax_temperature
+                        logp = torch.log_softmax(logits[0, 0, 0, : -len(self.durations)], dim=-1)
+                        duration_logp = torch.log_softmax(logits[0, 0, 0, -len(self.durations) :], dim=-1)
+                        curr_score += float(logp[curr_hyp.y_sequence[k + 1]] + duration_logp[self.zero_duration_idx])
+
+                        if self.ngram_lm:
+                            lm_score, next_state = self.compute_ngram_score(
+                                next_state, int(curr_hyp.y_sequence[k + 1])
+                            )
+                            curr_score += self.ngram_lm_alpha * lm_score
+
+                    # Update current hypothesis score
+                    curr_hyp.score = np.logaddexp(curr_hyp.score, curr_score)
+        return hypotheses
+
+    def compute_ngram_score(self, current_lm_state: "kenlm.State", label: int) -> Tuple[float, "kenlm.State"]:
+        """
+        Computes the score for the KenLM N-gram language model.
+
+        Args:
+            current_lm_state: current state of the KenLM language model.
+            label: next label.
+
+        Returns:
+            lm_score: score for `label`.
+            next_state: updated KenLM language model state.
+        """
+        if self.token_offset:
+            label = chr(label + self.token_offset)
+        else:
+            label = str(label)
+
+        next_state = kenlm.State()
+        lm_score = self.ngram_lm.BaseScore(current_lm_state, label, next_state)
+        lm_score *= 1.0 / np.log10(np.e)
+
+        return lm_score, next_state
+
+    def sort_nbest(self, hyps: List[Hypothesis]) -> List[Hypothesis]:
+        """Sort hypotheses by score or score given sequence length.
+
+        Args:
+            hyps: list of hypotheses
+
+        Return:
+            hyps: sorted list of hypotheses
+        """
+        if self.score_norm:
+            return sorted(hyps, key=lambda x: x.score / len(x.y_sequence), reverse=True)
+        else:
+            return sorted(hyps, key=lambda x: x.score, reverse=True)
diff --git a/nemo/collections/asr/parts/utils/rnnt_utils.py b/nemo/collections/asr/parts/utils/rnnt_utils.py
index 76e9da6087ed..8d2755fcc0ae 100644
--- a/nemo/collections/asr/parts/utils/rnnt_utils.py
+++ b/nemo/collections/asr/parts/utils/rnnt_utils.py
@@ -85,6 +85,8 @@ class Hypothesis:
         tokens: (Optional) A list of decoded tokens (can be characters or word-pieces.
 
         last_token (Optional): A token or batch of tokens which was predicted in the last step.
+
+        last_frame (Optional): Index of the last decoding step at which the hypothesis was updated, including blank token predictions.
""" score: float @@ -105,6 +107,7 @@ class Hypothesis: tokens: Optional[Union[List[int], torch.Tensor]] = None last_token: Optional[torch.Tensor] = None token_duration: Optional[List[int]] = None + last_frame: Optional[int] = None @property def non_blank_frame_confidence(self) -> List[float]: @@ -244,7 +247,8 @@ def __init__( Args: batch_size: batch size for hypotheses - init_length: initial estimate for the length of hypotheses (if the real length is higher, tensors will be reallocated) + init_length: initial estimate for the length of hypotheses (if the real length is higher, + tensors will be reallocated) device: device for storing hypotheses float_dtype: float type for scores """ @@ -274,6 +278,9 @@ def __init__( self._ones_batch = torch.ones_like(self._batch_indices) def clear_(self): + """ + Clears batched hypotheses state. + """ self.current_lengths.fill_(0) self.transcript.fill_(0) self.timesteps.fill_(0) @@ -497,6 +504,9 @@ def __init__( self._batch_indices = torch.arange(batch_size, device=device) def clear_(self): + """ + Clears batched hypotheses state. + """ self.current_lengths.fill_(0) self.timesteps.fill_(0) self.logits.fill_(0.0) diff --git a/tests/collections/asr/decoding/test_rnnt_decoding.py b/tests/collections/asr/decoding/test_rnnt_decoding.py index 82b5d00bede6..b5250ad5f144 100644 --- a/tests/collections/asr/decoding/test_rnnt_decoding.py +++ b/tests/collections/asr/decoding/test_rnnt_decoding.py @@ -22,8 +22,9 @@ from nemo.collections.asr.models import ASRModel from nemo.collections.asr.modules import RNNTDecoder, RNNTJoint from nemo.collections.asr.parts.mixins import mixins -from nemo.collections.asr.parts.submodules import rnnt_beam_decoding as beam_decode +from nemo.collections.asr.parts.submodules import rnnt_beam_decoding from nemo.collections.asr.parts.submodules import rnnt_greedy_decoding as greedy_decode +from nemo.collections.asr.parts.submodules import tdt_beam_decoding from nemo.collections.asr.parts.submodules.rnnt_decoding import RNNTBPEDecoding, RNNTDecoding, RNNTDecodingConfig from nemo.collections.asr.parts.utils import rnnt_utils from nemo.core.utils import numba_utils @@ -166,6 +167,39 @@ def check_subword_timestamps(hyp: rnnt_utils.Hypothesis, decoding: RNNTBPEDecodi assert len(hyp.timestep['segment']) == segments_count +def check_beam_decoding(test_data_dir, beam_config): + beam_size = beam_config.pop("beam_size", 1) + model, encoded, encoded_len = get_model_encoder_output(test_data_dir, 'nvidia/parakeet-tdt_ctc-110m') + + model_config = model.to_config_dict() + durations = list(model_config["model_defaults"]["tdt_durations"]) + + beam = tdt_beam_decoding.BeamTDTInfer( + model.decoder, + model.joint, + beam_size=beam_size, + return_best_hypothesis=False, + durations=durations, + **beam_config, + ) + + enc_out = encoded + enc_len = encoded_len + + with torch.no_grad(): + hyps: rnnt_utils.Hypothesis = beam(encoder_output=enc_out, encoded_lengths=enc_len)[0] + _, all_hyps = decode_text_from_nbest_hypotheses(hyps, model.decoding) + all_hyps = all_hyps[0] + + print("Beam search algorithm :", beam_config['search_type']) + for idx, hyp_ in enumerate(all_hyps): + print("Hyp index", idx + 1, "text :", hyp_.text) + + assert len(hyp_.timestep) > 0 + print("Timesteps", hyp_.timestep) + print() + + class TestRNNTDecoding: @pytest.mark.unit def test_constructor(self): @@ -312,10 +346,10 @@ def test_batched_greedy_decoding_preserve_alignments(self, test_data_dir, loop_l {"search_type": "maes", "maes_num_steps": 3, "maes_expansion_beta": 1, "beam_size": 2}, 
], ) - def test_beam_decoding_preserve_alignments(self, test_data_dir, beam_config): + def test_rnnt_beam_decoding_preserve_alignments(self, test_data_dir, beam_config): beam_size = beam_config.pop("beam_size", 1) model, encoded, encoded_len = get_model_encoder_output(test_data_dir, 'stt_en_conformer_transducer_small') - beam = beam_decode.BeamRNNTInfer( + beam = rnnt_beam_decoding.BeamRNNTInfer( model.decoder, model.joint, beam_size=beam_size, @@ -442,3 +476,51 @@ def test_char_decoding_compute_timestamps(self, test_data_dir, decoding_strategy hyps, _ = decoding.rnnt_decoder_predictions_tensor(encoded, encoded_len, return_hypotheses=True) check_char_timestamps(hyps[0], decoding) + + @pytest.mark.skipif( + not NUMBA_RNNT_LOSS_AVAILABLE, + reason='RNNTLoss has not been compiled with appropriate numba version.', + ) + @pytest.mark.with_downloads + @pytest.mark.unit + @pytest.mark.parametrize( + "beam_config", + [ + { + "search_type": "default", + "beam_size": 2, + }, + {"search_type": "maes", "maes_num_steps": 2, "maes_expansion_beta": 2, "beam_size": 2}, + {"search_type": "maes", "maes_num_steps": 2, "maes_expansion_beta": 1, "beam_size": 4}, + ], + ) + def test_tdt_beam_decoding(self, test_data_dir, beam_config): + check_beam_decoding(test_data_dir, beam_config) + + @pytest.mark.skipif( + not NUMBA_RNNT_LOSS_AVAILABLE, + reason='RNNTLoss has not been compiled with appropriate numba version.', + ) + @pytest.mark.with_downloads + @pytest.mark.unit + @pytest.mark.parametrize( + "beam_config", + [ + { + "search_type": "maes", + "maes_num_steps": 2, + "maes_expansion_beta": 1, + "beam_size": 4, + "ngram_lm_alpha": 0.3, + }, + ], + ) + def test_tdt_beam_decoding_with_kenlm(self, test_data_dir, beam_config): + # skipping if kenlm is not installed + pytest.importorskip("kenlm", reason="Skipping test because 'kenlm' is not installed.") + + kenlm_model_path = os.path.join( + test_data_dir, "asr", "kenlm_ngram_lm", "parakeet-tdt_ctc-110m-libri-1024.kenlm.tmp.arpa" + ) + beam_config["ngram_lm_model"] = kenlm_model_path + check_beam_decoding(test_data_dir, beam_config)