Beam search algorithm implementation for TDT models (#10903)
* initial commit

Signed-off-by: lilithgrigoryan <[email protected]>

* add: default beam search implementation

Signed-off-by: lilithgrigoryan <[email protected]>

* fix: changed to removing duplicate hypothesis in separate function

Signed-off-by: lilithgrigoryan <[email protected]>

* fix: changed to cartesian product in choosing best hyp

Signed-off-by: lilithgrigoryan <[email protected]>

* fix: minor fixes in comments

Signed-off-by: lilithgrigoryan <[email protected]>

* add: maes decoding strategy

Signed-off-by: lilithgrigoryan <[email protected]>

* add: durations filtering in maes, lm fusion in progress

Signed-off-by: lilithgrigoryan <[email protected]>

* fix: refactored, added comments, command line args, finalized

Signed-off-by: lilithgrigoryan <[email protected]>

* fix: removed prints

Signed-off-by: lilithgrigoryan <[email protected]>

* add: docs

Signed-off-by: lilithgrigoryan <[email protected]>

* Apply isort and black reformatting

Signed-off-by: lilithgrigoryan <[email protected]>

* fix: minor fix

Signed-off-by: lilithgrigoryan <[email protected]>

* fix: rm beam_size=1 exception, rm duplicates check, fix error handling

Signed-off-by: lilithgrigoryan <[email protected]>

* fix: error handling

Signed-off-by: lilithgrigoryan <[email protected]>

* Apply isort and black reformatting

Signed-off-by: lilithgrigoryan <[email protected]>

* fix: removed evaluations file

Signed-off-by: lilithgrigoryan <[email protected]>

* rn: blank scoring

Signed-off-by: lilithgrigoryan <[email protected]>

* clean up

Signed-off-by: lilithgrigoryan <[email protected]>

* rm: blank scoring and duration beam size

Signed-off-by: lilithgrigoryan <[email protected]>

* Apply isort and black reformatting

Signed-off-by: lilithgrigoryan <[email protected]>

* fix: removed durations_beam_size from default beam search

Signed-off-by: lilithgrigoryan <[email protected]>

* add: logaddexp

Signed-off-by: lilithgrigoryan <[email protected]>

* rm: prefix search

Signed-off-by: lilithgrigoryan <[email protected]>

* rn: nested loop over extensions

Signed-off-by: lilithgrigoryan <[email protected]>

* fix: bug with caching

Signed-off-by: lilithgrigoryan <[email protected]>

* rm: topk on durations

Signed-off-by: lilithgrigoryan <[email protected]>

* add: restored prefix search

Signed-off-by: lilithgrigoryan <[email protected]>

* Apply isort and black reformatting

Signed-off-by: lilithgrigoryan <[email protected]>

* clean up

Signed-off-by: lilithgrigoryan <[email protected]>

* fix: fixed comments

Signed-off-by: lilithgrigoryan <[email protected]>

* refactored duplicate merging

Signed-off-by: lilithgrigoryan <[email protected]>

* changes batch scoring

Signed-off-by: lilithgrigoryan <[email protected]>

* refactored rnnt batch scoring

Signed-off-by: lilithgrigoryan <[email protected]>

* alsd first working

Signed-off-by: lilithgrigoryan <[email protected]>

* refactored

Signed-off-by: lilithgrigoryan <[email protected]>

* clean up

Signed-off-by: lilithgrigoryan <[email protected]>

* remove stacking operations

Signed-off-by: lilithgrigoryan <[email protected]>

* fixes im base class

Signed-off-by: lilithgrigoryan <[email protected]>

* clean up

Signed-off-by: lilithgrigoryan <[email protected]>

* Apply isort and black reformatting

Signed-off-by: lilithgrigoryan <[email protected]>

* remove potentially uninitialized local variable

Signed-off-by: lilithgrigoryan <[email protected]>

* default beam search minor fixes

Signed-off-by: lilithgrigoryan <[email protected]>

* add test, fix maes timesteps

Signed-off-by: lilithgrigoryan <[email protected]>

* rm file

Signed-off-by: lilithgrigoryan <[email protected]>

* rm file

Signed-off-by: lilithgrigoryan <[email protected]>

* clean up

Signed-off-by: lilithgrigoryan <[email protected]>

* Apply isort and black reformatting

Signed-off-by: lilithgrigoryan <[email protected]>

* clean up

Signed-off-by: lilithgrigoryan <[email protected]>

* fix comments

Signed-off-by: lilithgrigoryan <[email protected]>

* add ngram lm test

Signed-off-by: lilithgrigoryan <[email protected]>

* Apply isort and black reformatting

Signed-off-by: lilithgrigoryan <[email protected]>

* fix maes_num_steps=1

Signed-off-by: lilithgrigoryan <[email protected]>

* fix kenlm model path

Signed-off-by: lilithgrigoryan <[email protected]>

* fix kenlm model full path

Signed-off-by: lilithgrigoryan <[email protected]>

* Apply isort and black reformatting

Signed-off-by: lilithgrigoryan <[email protected]>

* made requested changes

Signed-off-by: lilithgrigoryan <[email protected]>

* merge after isort

Signed-off-by: lilithgrigoryan <[email protected]>

* add prints to test

Signed-off-by: lilithgrigoryan <[email protected]>

* Apply isort and black reformatting

Signed-off-by: lilithgrigoryan <[email protected]>

* add Kenlm to asr requirements

Signed-off-by: lilithgrigoryan <[email protected]>

* remove prints in tests

Signed-off-by: lilithgrigoryan <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add kenlm to test requirements

Signed-off-by: lilithgrigoryan <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* rm kenlm from link, add package-name

Signed-off-by: lilithgrigoryan <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* rm second kenlm installation

Signed-off-by: lilithgrigoryan <[email protected]>

* rm kenlm from dependencies make test optional

Signed-off-by: lilithgrigoryan <[email protected]>

* Apply isort and black reformatting

Signed-off-by: lilithgrigoryan <[email protected]>

* fix in test

Signed-off-by: lilithgrigoryan <[email protected]>

* fix in test

Signed-off-by: lilithgrigoryan <[email protected]>

* Apply isort and black reformatting

Signed-off-by: lilithgrigoryan <[email protected]>

* fix comments

Signed-off-by: lilithgrigoryan <[email protected]>

* Apply isort and black reformatting

Signed-off-by: lilithgrigoryan <[email protected]>

* add comments

Signed-off-by: lilithgrigoryan <[email protected]>

* add comments

Signed-off-by: lilithgrigoryan <[email protected]>

* splitted docstrings

Signed-off-by: lilithgrigoryan <[email protected]>

* Apply isort and black reformatting

Signed-off-by: lilithgrigoryan <[email protected]>

* add comments

Signed-off-by: lilithgrigoryan <[email protected]>

* splitted docstrings

Signed-off-by: lilithgrigoryan <[email protected]>

* Apply isort and black reformatting

Signed-off-by: lilithgrigoryan <[email protected]>

* add comments

Signed-off-by: lilithgrigoryan <[email protected]>

* Apply isort and black reformatting

Signed-off-by: lilithgrigoryan <[email protected]>

* fixes to python3 type annotations

Signed-off-by: lilithgrigoryan <[email protected]>

* Apply isort and black reformatting

Signed-off-by: lilithgrigoryan <[email protected]>

* merging

Signed-off-by: lilithgrigoryan <[email protected]>

* merging

Signed-off-by: lilithgrigoryan <[email protected]>

* fix in return type

Signed-off-by: lilithgrigoryan <[email protected]>

* Apply isort and black reformatting

Signed-off-by: lilithgrigoryan <[email protected]>

* fix test

Signed-off-by: lilithgrigoryan <[email protected]>

* Apply isort and black reformatting

Signed-off-by: lilithgrigoryan <[email protected]>

* rm time_idx

Signed-off-by: lilithgrigoryan <[email protected]>

* fix comments to python3 style

Signed-off-by: lilithgrigoryan <[email protected]>

---------

Signed-off-by: lilithgrigoryan <[email protected]>
Signed-off-by: lilithgrigoryan <[email protected]>
Co-authored-by: lilithgrigoryan <[email protected]>
Co-authored-by: lilithgrigoryan <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
4 people authored Nov 13, 2024
1 parent f311b2e commit a2572a7
Showing 7 changed files with 1,155 additions and 97 deletions.
15 changes: 15 additions & 0 deletions docs/source/asr/api.rst
@@ -276,6 +276,21 @@ RNNT Decoding
:show-inheritance:
:members:

TDT Decoding
~~~~~~~~~~~~~

.. autoclass:: nemo.collections.asr.parts.submodules.rnnt_greedy_decoding.GreedyTDTInfer
:show-inheritance:
:members:

.. autoclass:: nemo.collections.asr.parts.submodules.rnnt_greedy_decoding.GreedyBatchedTDTInfer
:show-inheritance:
:members:

.. autoclass:: nemo.collections.asr.parts.submodules.tdt_beam_decoding.BeamTDTInfer
:show-inheritance:
:members:

Hypotheses
~~~~~~~~~~

56 changes: 46 additions & 10 deletions nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py
@@ -55,6 +55,20 @@


def pack_hypotheses(hypotheses: List[Hypothesis]) -> List[Hypothesis]:
"""
Packs a list of hypotheses into a tensor and prepares decoder states.
This function takes a list of token sequences (hypotheses) and converts
it into a tensor format. If any decoder states are on the GPU, they
are moved to the CPU. Additionally, the function removes any timesteps
with a value of -1 from the sequences.
Args:
hypotheses (list): A list of token sequences representing hypotheses.
Returns:
list: A list of packed hypotheses in tensor format.
"""
for idx, hyp in enumerate(hypotheses): # type: rnnt_utils.Hypothesis
hyp.y_sequence = torch.tensor(hyp.y_sequence, dtype=torch.long)

@@ -69,6 +83,18 @@ def pack_hypotheses(hypotheses: List[Hypothesis]) -> List[Hypothesis]:


def _states_to_device(dec_state, device='cpu'):
"""
Transfers decoder states to the specified device.
This function moves the provided decoder states to the specified device (e.g., 'cpu' or 'cuda').
Args:
dec_state (Tensor): The decoder states to be transferred.
device (str): The target device to which the decoder states should be moved. Defaults to 'cpu'.
Returns:
Tensor: The decoder states on the specified device.
"""
if torch.is_tensor(dec_state):
dec_state = dec_state.to(device)
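
The device transfer described in the `_states_to_device` docstring can be sketched as below. This is a minimal illustration rather than the code in this commit: only the tensor branch is visible in the hunk, so the handling of nested lists and tuples, and the name `states_to_device`, are assumptions made for the example.

```python
from typing import Any

import torch


def states_to_device(dec_state: Any, device: str = "cpu") -> Any:
    """Move a tensor, or a nested list/tuple of tensors, to `device`."""
    if torch.is_tensor(dec_state):
        # Same branch as shown in the hunk above.
        return dec_state.to(device)
    if isinstance(dec_state, (list, tuple)):
        # Assumption: decoder states may be nested containers of tensors,
        # so recurse and rebuild the same container type.
        return type(dec_state)(states_to_device(s, device) for s in dec_state)
    # Anything else (e.g. None) is returned unchanged.
    return dec_state


if __name__ == "__main__":
    state = (torch.zeros(1, 2), [torch.ones(3)])
    moved = states_to_device(state, "cpu")
    print(type(moved), moved[0].device)
```

Returning a new container instead of mutating in place keeps the sketch usable with tuples, which are immutable.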

@@ -106,15 +132,17 @@ class BeamRNNTInfer(Typing):
however the time required for the search also grows steadily.
`tsd` - time synchronous decoding. Please refer to the paper:
[Alignment-Length Synchronous Decoding for RNN Transducer](https://ieeexplore.ieee.org/document/9053040)
[Alignment-Length Synchronous Decoding for RNN Transducer]
(https://ieeexplore.ieee.org/document/9053040)
for details on the algorithm implemented.
Time synchronous decoding (TSD) execution time grows by the factor T * max_symmetric_expansions.
For longer sequences, T is greater, and can therefore take a long time for beams to obtain
good results. This also requires greater memory to execute.
`alsd` - alignment-length synchronous decoding. Please refer to the paper:
[Alignment-Length Synchronous Decoding for RNN Transducer](https://ieeexplore.ieee.org/document/9053040)
[Alignment-Length Synchronous Decoding for RNN Transducer]
(https://ieeexplore.ieee.org/document/9053040)
for details on the algorithm implemented.
Alignment-length synchronous decoding (ALSD) execution time is faster than TSD, with growth
@@ -127,7 +155,8 @@ class BeamRNNTInfer(Typing):
For a given decoding accuracy, it is possible to attain faster decoding via ALSD than TSD.
`maes` = modified adaptive expansion search. Please refer to the paper:
[Accelerating RNN Transducer Inference via Adaptive Expansion Search](https://ieeexplore.ieee.org/document/9250505)
[Accelerating RNN Transducer Inference via Adaptive Expansion Search]
(https://ieeexplore.ieee.org/document/9250505)
Modified Adaptive Expansion Search (mAES) execution time is adaptive w.r.t the
number of expansions (for tokens) required per timestep. The number of expansions can usually
@@ -169,10 +198,10 @@ class BeamRNNTInfer(Typing):
and affects the speed of inference since large values will perform large beam search in the next step.
maes_expansion_gamma: Float pruning threshold used in the prune-by-value step when computing the expansions.
The default (2.3) is selected from the paper. It performs a comparison (max_log_prob - gamma <= log_prob[v])
where v is all vocabulary indices in the Vocab set and max_log_prob is the "most" likely token to be
predicted. Gamma therefore provides a margin of additional tokens which can be potential candidates for
expansion apart from the "most likely" candidate.
The default (2.3) is selected from the paper. It performs a comparison
(max_log_prob - gamma <= log_prob[v]) where v is all vocabulary indices in the Vocab set and max_log_prob
is the "most" likely token to be predicted. Gamma therefore provides a margin of additional tokens which
can be potential candidates for expansion apart from the "most likely" candidate.
Lower values will reduce the number of expansions (by increasing pruning-by-value, thereby improving speed
but hurting accuracy). Higher values will increase the number of expansions (by reducing pruning-by-value,
thereby reducing speed but potentially improving accuracy). This is a hyper parameter to be experimentally
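
The prune-by-value comparison described for `maes_expansion_gamma` can be sketched in plain Python. This is a minimal illustration, not NeMo's implementation: the function name `prune_by_value` and the toy distribution are assumptions made for the example. Only the comparison `max_log_prob - gamma <= log_prob[v]` comes from the docstring.

```python
import math
from typing import List, Tuple


def prune_by_value(log_probs: List[float], gamma: float = 2.3) -> List[Tuple[int, float]]:
    """Keep every vocabulary index v with max_log_prob - gamma <= log_probs[v]."""
    max_log_prob = max(log_probs)
    return [(v, lp) for v, lp in enumerate(log_probs) if max_log_prob - gamma <= lp]


if __name__ == "__main__":
    # Hypothetical 4-token distribution, converted to log-probabilities.
    log_probs = [math.log(p) for p in (0.7, 0.2, 0.06, 0.04)]
    for gamma in (0.5, 2.3, 3.0):
        kept = [v for v, _ in prune_by_value(log_probs, gamma)]
        print(f"gamma={gamma}: kept token indices {kept}")
```

On this toy distribution a small gamma keeps only the most likely token, the default of 2.3 keeps the top two, and 3.0 keeps all four, which matches the speed versus accuracy trade-off the docstring describes.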
@@ -182,7 +211,7 @@ class BeamRNNTInfer(Typing):
preserve_alignments: Bool flag which preserves the history of alignments generated during
beam decoding (sample). When set to true, the Hypothesis will contain
the non-null value for `alignments` in it. Here, `alignments` is a List of List of Tensor (of length V + 1).
the non-null value for `alignments` in it. Here, `alignments` is a List of List of Tensor (of length V + 1)
The length of the list corresponds to the Acoustic Length (T).
Each value in the list (Ti) is a torch.Tensor (U), representing 1 or more targets from a vocabulary.
@@ -1456,8 +1485,11 @@ def compute_ngram_score(self, current_lm_state: "kenlm.State", label: int) -> Tu
return lm_score, next_state

def set_decoding_type(self, decoding_type: str):

# Please check train_kenlm.py in scripts/asr_language_modeling/ to find out why we need
"""
Sets decoding type. Please check train_kenlm.py in scripts/asr_language_modeling/ to find out why we need
TOKEN_OFFSET for BPE-based models.
Args:
decoding_type: decoding type
"""
# TOKEN_OFFSET for BPE-based models
if decoding_type == 'subword':
from nemo.collections.asr.parts.submodules.ctc_beam_decoding import DEFAULT_TOKEN_OFFSET
@@ -1467,6 +1499,10 @@ def set_decoding_type(self, decoding_type: str):

@dataclass
class BeamRNNTInferConfig:
"""
Beam RNNT Inference config.
"""

beam_size: int
search_type: str = 'default'
score_norm: bool = True