From 814d3ac702fd8bf995660658ba241c9ad41ce2ab Mon Sep 17 00:00:00 2001
From: root
Date: Wed, 14 Aug 2024 15:29:29 +0900
Subject: [PATCH] fix formatting issues

---
 .../ASR/zipformer/streaming_decode.py | 43 +++++++++----------
 1 file changed, 21 insertions(+), 22 deletions(-)

diff --git a/egs/reazonspeech/ASR/zipformer/streaming_decode.py b/egs/reazonspeech/ASR/zipformer/streaming_decode.py
index 81bdc4845b..9274f4dc4f 100755
--- a/egs/reazonspeech/ASR/zipformer/streaming_decode.py
+++ b/egs/reazonspeech/ASR/zipformer/streaming_decode.py
@@ -22,13 +22,15 @@
 """
 
 
-import pdb
 import argparse
 import logging
 import math
+import os
+import pdb
+
+# import subprocess as sp
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple
-from tokenizer import Tokenizer
 
 import k2
 import numpy as np
@@ -42,6 +44,7 @@
     greedy_search,
     modified_beam_search,
 )
+from tokenizer import Tokenizer
 from torch import Tensor, nn
 from torch.nn.utils.rnn import pad_sequence
 from train import add_model_arguments, get_model, get_params
@@ -61,9 +64,6 @@
     write_error_stats,
 )
 
-import subprocess as sp
-import os
-
 LOG_EPS = math.log(1e-10)
 
 
@@ -124,7 +124,7 @@ def get_parser():
         default="data/lang_bpe_500/bpe.model",
         help="Path to the BPE model",
     )
-    
+
     parser.add_argument(
         "--lang-dir",
         type=Path,
@@ -449,14 +449,14 @@ def decode_one_chunk(
     feature_lens = []
     states = []
     processed_lens = []  # Used in fast-beam-search
-    
+
     for stream in decode_streams:
         feat, feat_len = stream.get_feature_frames(chunk_size * 2)
         features.append(feat)
         feature_lens.append(feat_len)
         states.append(stream.states)
         processed_lens.append(stream.done_frames)
-    
+
     feature_lens = torch.tensor(feature_lens, device=model.device)
     features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS)
 
@@ -518,9 +518,9 @@
         decode_streams[i].states = states[i]
         decode_streams[i].done_frames += encoder_out_lens[i]
         # if decode_streams[i].done:
-        # finished_streams.append(i)
+        #     finished_streams.append(i)
         finished_streams.append(i)
-    
+
     return finished_streams
 
 
@@ -628,21 +628,20 @@ def decode_dataset(
         )
         # print('INSIDE FOR LOOP ')
         # print(finished_streams)
-        
+
         if not finished_streams:
             print("No finished streams, breaking the loop")
             break
-        
-        
+
         for i in sorted(finished_streams, reverse=True):
-            try: 
+            try:
                 decode_results.append(
                     (
                         decode_streams[i].id,
                         decode_streams[i].ground_truth.split(),
                         sp.decode(decode_streams[i].decoding_result()).split(),
                     )
-                ) 
+                )
                 del decode_streams[i]
             except IndexError as e:
                 print(f"IndexError: {e}")
@@ -650,7 +649,7 @@
                 print(f"finished_streams: {finished_streams}")
                 print(f"i: {i}")
                 continue
-    
+
     if params.decoding_method == "greedy_search":
         key = "greedy_search"
     elif params.decoding_method == "fast_beam_search":
@@ -663,7 +662,7 @@
         key = f"num_active_paths_{params.num_active_paths}"
     else:
         raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
-    torch.cuda.synchronize()    
+    torch.cuda.synchronize()
     return {key: decode_results}
 
 
@@ -854,11 +853,11 @@ def main():
 
     num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")
-    
+
     # we need cut ids to display recognition results.
     args.return_cuts = True
     reazonspeech_corpus = ReazonSpeechAsrDataModule(args)
-    
+
     valid_cuts = reazonspeech_corpus.valid_cuts()
     test_cuts = reazonspeech_corpus.test_cuts()
 
@@ -878,9 +877,9 @@
             test_set_name=test_set,
             results_dict=results_dict,
         )
-    
+
     # valid_cuts = reazonspeech_corpus.valid_cuts()
-    
+
     # for valid_cut in valid_cuts:
     #     results_dict = decode_dataset(
     #         cuts=valid_cut,
@@ -894,7 +893,7 @@
     #         test_set_name="valid",
     #         results_dict=results_dict,
     #     )
-    
+
     logging.info("Done!")
 
 
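Note on the change set: aside from regrouping the imports (os and pdb move into the sorted standard-library block at the top of streaming_decode.py, and "from tokenizer import Tokenizer" moves down next to the other first-party imports) and re-indenting one commented-out line near finished_streams.append, the hunks above only remove trailing whitespace from blank or code lines. The snippet below is not part of this patch or of icefall; it is a hypothetical, minimal sketch of the same kind of whitespace cleanup, useful for tidying a working copy before committing when the project's usual style tooling (for example pre-commit hooks, if configured) has not been run. The import reordering itself would still need to be done by hand or with a tool such as isort.

#!/usr/bin/env python3
"""Hypothetical helper (illustration only, not part of this patch or of icefall):
strip trailing whitespace from the files given on the command line, which is the
bulk of what the whitespace-only hunks in this commit do by hand."""

import sys
from pathlib import Path


def strip_trailing_whitespace(path: Path) -> int:
    """Rewrite path with trailing whitespace removed; return the number of changed lines."""
    original = path.read_text().splitlines()
    cleaned = [line.rstrip() for line in original]
    changed = sum(1 for old, new in zip(original, cleaned) if old != new)
    if changed:
        # Also normalise the file to end with a single newline.
        path.write_text("\n".join(cleaned) + "\n")
    return changed


if __name__ == "__main__":
    for name in sys.argv[1:]:
        print(f"{name}: {strip_trailing_whitespace(Path(name))} line(s) cleaned")

Running a cleanup like this over streaming_decode.py before regenerating the diff covers most of the whitespace hunks shown above (it does not collapse consecutive blank lines, which one hunk also does).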