Skip to content

Commit

Permalink
Speedup RNN-T greedy decoding (#7926)
Browse files Browse the repository at this point in the history
* Add structure for batched hypotheses

Signed-off-by: Vladimir Bataev <[email protected]>

* Add faster decoding algo

Signed-off-by: Vladimir Bataev <[email protected]>

* Simplify max_symbols support. More speedup

Signed-off-by: Vladimir Bataev <[email protected]>

* Clean up

Signed-off-by: Vladimir Bataev <[email protected]>

* Clean up

Signed-off-by: Vladimir Bataev <[email protected]>

* Filtering only when necessary

Signed-off-by: Vladimir Bataev <[email protected]>

* Move max_symbols check to the end of loop

Signed-off-by: Vladimir Bataev <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Support returning prediction network states

Signed-off-by: Vladimir Bataev <[email protected]>

* Support preserve_alignments flag

Signed-off-by: Vladimir Bataev <[email protected]>

* Support confidence

Signed-off-by: Vladimir Bataev <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Partial fix for jit compatibility

Signed-off-by: Vladimir Bataev <[email protected]>

* Support switching between decoding algorithms

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix switching algorithms

Signed-off-by: Vladimir Bataev <[email protected]>

* Clean up

Signed-off-by: Vladimir Bataev <[email protected]>

* Clean up

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix max symbols per step

Signed-off-by: Vladimir Bataev <[email protected]>

* Add tests. Preserve torch.jit compatibility for BatchedHyps

Signed-off-by: Vladimir Bataev <[email protected]>

* Separate projection from Joint calculation in decoding

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix config instantiation

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix after main merge

Signed-off-by: Vladimir Bataev <[email protected]>

* Add tests for batched hypotheses

Signed-off-by: Vladimir Bataev <[email protected]>

* Speedup alignments

Signed-off-by: Vladimir Bataev <[email protected]>

* Test alignments

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix alignments

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix tests for alignments

Signed-off-by: Vladimir Bataev <[email protected]>

* Add more tests

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix confidence tests

Signed-off-by: Vladimir Bataev <[email protected]>

* Avoid common package modification

Signed-off-by: Vladimir Bataev <[email protected]>

* Support Stateless prediction network

Signed-off-by: Vladimir Bataev <[email protected]>

* Improve stateless decoder support. Separate alignments and confidence

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix alignments for max_symbols_per_step

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix alignments for max_symbols_per_step=0

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix tests

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix test

Signed-off-by: Vladimir Bataev <[email protected]>

* Add comments

Signed-off-by: Vladimir Bataev <[email protected]>

* Batched Hyps/Alignments: lengths -> current_lengths

Signed-off-by: Vladimir Bataev <[email protected]>

* Simplify indexing

Signed-off-by: Vladimir Bataev <[email protected]>

* Improve type annotations

Signed-off-by: Vladimir Bataev <[email protected]>

* Rework test for greedy decoding

Signed-off-by: Vladimir Bataev <[email protected]>

* Document loop_labels

Signed-off-by: Vladimir Bataev <[email protected]>

* Raise ValueError if max_symbols_per_step <= 0

Signed-off-by: Vladimir Bataev <[email protected]>

* Add comments

Signed-off-by: Vladimir Bataev <[email protected]>

* Fix test

Signed-off-by: Vladimir Bataev <[email protected]>

---------

Signed-off-by: Vladimir Bataev <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: stevehuang52 <[email protected]>
  • Loading branch information
2 people authored and stevehuang52 committed Jan 31, 2024
1 parent 75cf584 commit ea496de
Show file tree
Hide file tree
Showing 11 changed files with 1,032 additions and 55 deletions.
14 changes: 4 additions & 10 deletions nemo/collections/asr/modules/hybrid_autoregressive_transducer.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,9 +138,9 @@ def return_hat_ilm(self):
def return_hat_ilm(self, hat_subtract_ilm):
self._return_hat_ilm = hat_subtract_ilm

def joint(self, f: torch.Tensor, g: torch.Tensor) -> Union[torch.Tensor, HATJointOutput]:
def joint_after_projection(self, f: torch.Tensor, g: torch.Tensor) -> Union[torch.Tensor, HATJointOutput]:
"""
Compute the joint step of the network.
Compute the joint step of the network after Encoder/Decoder projection.
Here,
B = Batch size
Expand Down Expand Up @@ -169,14 +169,8 @@ def joint(self, f: torch.Tensor, g: torch.Tensor) -> Union[torch.Tensor, HATJoin
Log softmaxed tensor of shape (B, T, U, V + 1).
Internal LM probability (B, 1, U, V) -- in case of return_ilm==True.
"""
# f = [B, T, H1]
f = self.enc(f)
f.unsqueeze_(dim=2) # (B, T, 1, H)

# g = [B, U, H2]
g = self.pred(g)
g.unsqueeze_(dim=1) # (B, 1, U, H)

f = f.unsqueeze(dim=2) # (B, T, 1, H)
g = g.unsqueeze(dim=1) # (B, 1, U, H)
inp = f + g # [B, T, U, H]

del f
Expand Down
71 changes: 60 additions & 11 deletions nemo/collections/asr/modules/rnnt.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,22 @@ def batch_copy_states(

return old_states

def mask_select_states(
self, states: Optional[List[torch.Tensor]], mask: torch.Tensor
) -> Optional[List[torch.Tensor]]:
"""
Return states by mask selection
Args:
states: states for the batch
mask: boolean mask for selecting states; batch dimension should be the same as for states
Returns:
states filtered by mask
"""
if states is None:
return None
return [states[0][mask]]

def batch_score_hypothesis(
self, hypotheses: List[rnnt_utils.Hypothesis], cache: Dict[Tuple[int], Any], batch_states: List[torch.Tensor]
) -> Tuple[torch.Tensor, List[torch.Tensor], torch.Tensor]:
Expand Down Expand Up @@ -1047,6 +1063,21 @@ def batch_copy_states(

return old_states

def mask_select_states(
self, states: Tuple[torch.Tensor, torch.Tensor], mask: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Return states by mask selection
Args:
states: states for the batch
mask: boolean mask for selecting states; batch dimension should be the same as for states
Returns:
states filtered by mask
"""
# LSTM in PyTorch returns a tuple of 2 tensors as a state
return states[0][:, mask], states[1][:, mask]

# Adapter method overrides
def add_adapter(self, name: str, cfg: DictConfig):
# Update the config with correct input dim
Expand Down Expand Up @@ -1382,9 +1413,33 @@ def forward(

return losses, wer, wer_num, wer_denom

def joint(self, f: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
def project_encoder(self, encoder_output: torch.Tensor) -> torch.Tensor:
"""
Project the encoder output to the joint hidden dimension.
Args:
encoder_output: A torch.Tensor of shape [B, T, D]
Returns:
A torch.Tensor of shape [B, T, H]
"""
return self.enc(encoder_output)

def project_prednet(self, prednet_output: torch.Tensor) -> torch.Tensor:
"""
Project the Prediction Network (Decoder) output to the joint hidden dimension.
Args:
prednet_output: A torch.Tensor of shape [B, U, D]
Returns:
A torch.Tensor of shape [B, U, H]
"""
return self.pred(prednet_output)

def joint_after_projection(self, f: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
"""
Compute the joint step of the network.
Compute the joint step of the network after projection.
Here,
B = Batch size
Expand Down Expand Up @@ -1412,14 +1467,8 @@ def joint(self, f: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
Returns:
Logits / log softmaxed tensor of shape (B, T, U, V + 1).
"""
# f = [B, T, H1]
f = self.enc(f)
f.unsqueeze_(dim=2) # (B, T, 1, H)

# g = [B, U, H2]
g = self.pred(g)
g.unsqueeze_(dim=1) # (B, 1, U, H)

f = f.unsqueeze(dim=2) # (B, T, 1, H)
g = g.unsqueeze(dim=1) # (B, 1, U, H)
inp = f + g # [B, T, U, H]

del f, g
Expand Down Expand Up @@ -1536,7 +1585,7 @@ def set_fuse_loss_wer(self, fuse_loss_wer, loss=None, metric=None):

@property
def fused_batch_size(self):
return self._fuse_loss_wer
return self._fused_batch_size

def set_fused_batch_size(self, fused_batch_size):
self._fused_batch_size = fused_batch_size
Expand Down
53 changes: 52 additions & 1 deletion nemo/collections/asr/modules/rnnt_abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,45 @@ class AbstractRNNTJoint(NeuralModule, ABC):
"""

@abstractmethod
def joint_after_projection(self, f: torch.Tensor, g: torch.Tensor) -> Any:
"""
Compute the joint step of the network after the projection step.
Args:
f: Output of the Encoder model after projection. A torch.Tensor of shape [B, T, H]
g: Output of the Decoder model (Prediction Network) after projection. A torch.Tensor of shape [B, U, H]
Returns:
Logits / log softmaxed tensor of shape (B, T, U, V + 1).
Arbitrary return type, preferably torch.Tensor, but not limited to (e.g., see HatJoint)
"""
raise NotImplementedError()

@abstractmethod
def project_encoder(self, encoder_output: torch.Tensor) -> torch.Tensor:
"""
Project the encoder output to the joint hidden dimension.
Args:
encoder_output: A torch.Tensor of shape [B, T, D]
Returns:
A torch.Tensor of shape [B, T, H]
"""
raise NotImplementedError()

@abstractmethod
def project_prednet(self, prednet_output: torch.Tensor) -> torch.Tensor:
"""
Project the Prediction Network (Decoder) output to the joint hidden dimension.
Args:
prednet_output: A torch.Tensor of shape [B, U, D]
Returns:
A torch.Tensor of shape [B, U, H]
"""
raise NotImplementedError()

def joint(self, f: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
"""
Compute the joint step of the network.
Expand Down Expand Up @@ -58,7 +97,7 @@ def joint(self, f: torch.Tensor, g: torch.Tensor) -> torch.Tensor:
Returns:
Logits / log softmaxed tensor of shape (B, T, U, V + 1).
"""
raise NotImplementedError()
return self.joint_after_projection(self.project_encoder(f), self.project_prednet(g))

@property
def num_classes_with_blank(self):
Expand Down Expand Up @@ -277,3 +316,15 @@ def batch_copy_states(
(L x B x H, L x B x H)
"""
raise NotImplementedError()

def mask_select_states(self, states: Any, mask: torch.Tensor) -> Any:
"""
Return states by mask selection
Args:
states: states for the batch (preferably a list of tensors, but not limited to)
mask: boolean mask for selecting states; batch dimension should be the same as for states
Returns:
states filtered by mask (same type as `states`)
"""
raise NotImplementedError()
5 changes: 3 additions & 2 deletions nemo/collections/asr/parts/submodules/rnnt_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,7 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int):
preserve_alignments=self.preserve_alignments,
preserve_frame_confidence=self.preserve_frame_confidence,
confidence_method_cfg=self.confidence_method_cfg,
loop_labels=self.cfg.greedy.get('loop_labels', True),
)
else:
self.decoding = rnnt_greedy_decoding.GreedyBatchedTDTInfer(
Expand Down Expand Up @@ -1495,8 +1496,8 @@ class RNNTDecodingConfig:
rnnt_timestamp_type: str = "all" # can be char, word or all for both

# greedy decoding config
greedy: rnnt_greedy_decoding.GreedyRNNTInferConfig = field(
default_factory=lambda: rnnt_greedy_decoding.GreedyRNNTInferConfig()
greedy: rnnt_greedy_decoding.GreedyBatchedRNNTInferConfig = field(
default_factory=rnnt_greedy_decoding.GreedyBatchedRNNTInferConfig
)

# beam decoding config
Expand Down
Loading

0 comments on commit ea496de

Please sign in to comment.