Skip to content

Commit

Permalink
Apply isort and black reformatting
Browse files Browse the repository at this point in the history
Signed-off-by: hainan-xv <[email protected]>
  • Loading branch information
hainan-xv committed Oct 10, 2024
1 parent 6254186 commit 17b86cf
Showing 1 changed file with 31 additions and 26 deletions.
57 changes: 31 additions & 26 deletions nemo/collections/asr/modules/rnnt.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,7 @@ class StatelessTransducerDecoder(rnnt_abstract.AbstractRNNTDecoder, Exportable):

@property
def input_types(self):
"""Returns definitions of module input ports.
"""
"""Returns definitions of module input ports."""
return {
"targets": NeuralType(('B', 'T'), LabelsType()),
"target_length": NeuralType(tuple('B'), LengthsType()),
Expand All @@ -84,8 +83,7 @@ def input_types(self):

@property
def output_types(self):
"""Returns definitions of module output ports.
"""
"""Returns definitions of module output ports."""
return {
"outputs": NeuralType(('B', 'D', 'T'), EmbeddedTextType()),
"prednet_lengths": NeuralType(tuple('B'), LengthsType()),
Expand Down Expand Up @@ -382,15 +380,20 @@ def batch_concat_states(self, batch_states: List[List[torch.Tensor]]) -> List[to

@classmethod
def batch_replace_states_mask(
cls, src_states: list[torch.Tensor], dst_states: list[torch.Tensor], mask: torch.Tensor,
cls,
src_states: list[torch.Tensor],
dst_states: list[torch.Tensor],
mask: torch.Tensor,
):
"""Replace states in dst_states with states from src_states using the mask"""
# same as `dst_states[0][mask] = src_states[0][mask]`, but non-blocking
torch.where(mask.unsqueeze(-1), src_states[0], dst_states[0], out=dst_states[0])

@classmethod
def batch_replace_states_all(
cls, src_states: list[torch.Tensor], dst_states: list[torch.Tensor],
cls,
src_states: list[torch.Tensor],
dst_states: list[torch.Tensor],
):
"""Replace states in dst_states with states from src_states"""
dst_states[0].copy_(src_states[0])
Expand Down Expand Up @@ -591,8 +594,7 @@ class RNNTDecoder(rnnt_abstract.AbstractRNNTDecoder, Exportable, AdapterModuleMi

@property
def input_types(self):
"""Returns definitions of module input ports.
"""
"""Returns definitions of module input ports."""
return {
"targets": NeuralType(('B', 'T'), LabelsType()),
"target_length": NeuralType(tuple('B'), LengthsType()),
Expand All @@ -601,8 +603,7 @@ def input_types(self):

@property
def output_types(self):
"""Returns definitions of module output ports.
"""
"""Returns definitions of module output ports."""
return {
"outputs": NeuralType(('B', 'D', 'T'), EmbeddedTextType()),
"prednet_lengths": NeuralType(tuple('B'), LengthsType()),
Expand Down Expand Up @@ -1018,19 +1019,19 @@ def batch_score_hypothesis(

def batch_initialize_states(self, batch_states: List[torch.Tensor], decoder_states: List[List[torch.Tensor]]):
"""
Create batch of decoder states.
Create batch of decoder states.
Args:
batch_states (list): batch of decoder states
([L x (B, H)], [L x (B, H)])
Args:
batch_states (list): batch of decoder states
([L x (B, H)], [L x (B, H)])
decoder_states (list of list): list of decoder states
[B x ([L x (1, H)], [L x (1, H)])]
decoder_states (list of list): list of decoder states
[B x ([L x (1, H)], [L x (1, H)])]
Returns:
batch_states (tuple): batch of decoder states
([L x (B, H)], [L x (B, H)])
"""
Returns:
batch_states (tuple): batch of decoder states
([L x (B, H)], [L x (B, H)])
"""
# LSTM has 2 states
new_states = [[] for _ in range(len(decoder_states[0]))]
for layer in range(self.pred_rnn_layers):
Expand Down Expand Up @@ -1109,7 +1110,9 @@ def batch_replace_states_mask(

@classmethod
def batch_replace_states_all(
cls, src_states: Tuple[torch.Tensor, torch.Tensor], dst_states: Tuple[torch.Tensor, torch.Tensor],
cls,
src_states: Tuple[torch.Tensor, torch.Tensor],
dst_states: Tuple[torch.Tensor, torch.Tensor],
):
"""Replace states in dst_states with states from src_states"""
dst_states[0].copy_(src_states[0])
Expand Down Expand Up @@ -1257,8 +1260,7 @@ class RNNTJoint(rnnt_abstract.AbstractRNNTJoint, Exportable, AdapterModuleMixin)

@property
def input_types(self):
"""Returns definitions of module input ports.
"""
"""Returns definitions of module input ports."""
return {
"encoder_outputs": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()),
"decoder_outputs": NeuralType(('B', 'D', 'T'), EmbeddedTextType()),
Expand All @@ -1270,8 +1272,7 @@ def input_types(self):

@property
def output_types(self):
"""Returns definitions of module output ports.
"""
"""Returns definitions of module output ports."""
if not self._fuse_loss_wer:
return {
"outputs": NeuralType(('B', 'T', 'T', 'D'), LogprobsType()),
Expand Down Expand Up @@ -2063,7 +2064,11 @@ def forward(
return losses, wer, wer_num, wer_denom

def sampled_joint(
self, f: torch.Tensor, g: torch.Tensor, transcript: torch.Tensor, transcript_lengths: torch.Tensor,
self,
f: torch.Tensor,
g: torch.Tensor,
transcript: torch.Tensor,
transcript_lengths: torch.Tensor,
) -> torch.Tensor:
"""
Compute the sampled joint step of the network.
Expand Down

0 comments on commit 17b86cf

Please sign in to comment.