diff --git a/nemo/collections/asr/modules/rnnt.py b/nemo/collections/asr/modules/rnnt.py index cc6ec0452ece..2f0029765a2e 100644 --- a/nemo/collections/asr/modules/rnnt.py +++ b/nemo/collections/asr/modules/rnnt.py @@ -74,8 +74,7 @@ class StatelessTransducerDecoder(rnnt_abstract.AbstractRNNTDecoder, Exportable): @property def input_types(self): - """Returns definitions of module input ports. - """ + """Returns definitions of module input ports.""" return { "targets": NeuralType(('B', 'T'), LabelsType()), "target_length": NeuralType(tuple('B'), LengthsType()), @@ -84,8 +83,7 @@ def input_types(self): @property def output_types(self): - """Returns definitions of module output ports. - """ + """Returns definitions of module output ports.""" return { "outputs": NeuralType(('B', 'D', 'T'), EmbeddedTextType()), "prednet_lengths": NeuralType(tuple('B'), LengthsType()), @@ -382,7 +380,10 @@ def batch_concat_states(self, batch_states: List[List[torch.Tensor]]) -> List[to @classmethod def batch_replace_states_mask( - cls, src_states: list[torch.Tensor], dst_states: list[torch.Tensor], mask: torch.Tensor, + cls, + src_states: list[torch.Tensor], + dst_states: list[torch.Tensor], + mask: torch.Tensor, ): """Replace states in dst_states with states from src_states using the mask""" # same as `dst_states[0][mask] = src_states[0][mask]`, but non-blocking @@ -390,7 +391,9 @@ def batch_replace_states_mask( @classmethod def batch_replace_states_all( - cls, src_states: list[torch.Tensor], dst_states: list[torch.Tensor], + cls, + src_states: list[torch.Tensor], + dst_states: list[torch.Tensor], ): """Replace states in dst_states with states from src_states""" dst_states[0].copy_(src_states[0]) @@ -591,8 +594,7 @@ class RNNTDecoder(rnnt_abstract.AbstractRNNTDecoder, Exportable, AdapterModuleMi @property def input_types(self): - """Returns definitions of module input ports. - """ + """Returns definitions of module input ports.""" return { "targets": NeuralType(('B', 'T'), LabelsType()), "target_length": NeuralType(tuple('B'), LengthsType()), @@ -601,8 +603,7 @@ def input_types(self): @property def output_types(self): - """Returns definitions of module output ports. - """ + """Returns definitions of module output ports.""" return { "outputs": NeuralType(('B', 'D', 'T'), EmbeddedTextType()), "prednet_lengths": NeuralType(tuple('B'), LengthsType()), @@ -1018,19 +1019,19 @@ def batch_score_hypothesis( def batch_initialize_states(self, batch_states: List[torch.Tensor], decoder_states: List[List[torch.Tensor]]): """ - Create batch of decoder states. + Create batch of decoder states. - Args: - batch_states (list): batch of decoder states - ([L x (B, H)], [L x (B, H)]) + Args: + batch_states (list): batch of decoder states + ([L x (B, H)], [L x (B, H)]) - decoder_states (list of list): list of decoder states - [B x ([L x (1, H)], [L x (1, H)])] + decoder_states (list of list): list of decoder states + [B x ([L x (1, H)], [L x (1, H)])] - Returns: - batch_states (tuple): batch of decoder states - ([L x (B, H)], [L x (B, H)]) - """ + Returns: + batch_states (tuple): batch of decoder states + ([L x (B, H)], [L x (B, H)]) + """ # LSTM has 2 states new_states = [[] for _ in range(len(decoder_states[0]))] for layer in range(self.pred_rnn_layers): @@ -1109,7 +1110,9 @@ def batch_replace_states_mask( @classmethod def batch_replace_states_all( - cls, src_states: Tuple[torch.Tensor, torch.Tensor], dst_states: Tuple[torch.Tensor, torch.Tensor], + cls, + src_states: Tuple[torch.Tensor, torch.Tensor], + dst_states: Tuple[torch.Tensor, torch.Tensor], ): """Replace states in dst_states with states from src_states""" dst_states[0].copy_(src_states[0]) @@ -1257,8 +1260,7 @@ class RNNTJoint(rnnt_abstract.AbstractRNNTJoint, Exportable, AdapterModuleMixin) @property def input_types(self): - """Returns definitions of module input ports. - """ + """Returns definitions of module input ports.""" return { "encoder_outputs": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()), "decoder_outputs": NeuralType(('B', 'D', 'T'), EmbeddedTextType()), @@ -1270,8 +1272,7 @@ def input_types(self): @property def output_types(self): - """Returns definitions of module output ports. - """ + """Returns definitions of module output ports.""" if not self._fuse_loss_wer: return { "outputs": NeuralType(('B', 'T', 'T', 'D'), LogprobsType()), @@ -2063,7 +2064,11 @@ def forward( return losses, wer, wer_num, wer_denom def sampled_joint( - self, f: torch.Tensor, g: torch.Tensor, transcript: torch.Tensor, transcript_lengths: torch.Tensor, + self, + f: torch.Tensor, + g: torch.Tensor, + transcript: torch.Tensor, + transcript_lengths: torch.Tensor, ) -> torch.Tensor: """ Compute the sampled joint step of the network.