Skip to content

Commit

Permalink
fix formatting issues
Browse files Browse the repository at this point in the history
  • Loading branch information
root authored and root committed Aug 14, 2024
1 parent 5632925 commit 814d3ac
Showing 1 changed file with 21 additions and 22 deletions.
43 changes: 21 additions & 22 deletions egs/reazonspeech/ASR/zipformer/streaming_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,15 @@
"""

import pdb
import argparse
import logging
import math
import os
import pdb

# import subprocess as sp
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from tokenizer import Tokenizer

import k2
import numpy as np
Expand All @@ -42,6 +44,7 @@
greedy_search,
modified_beam_search,
)
from tokenizer import Tokenizer
from torch import Tensor, nn
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_model, get_params
Expand All @@ -61,9 +64,6 @@
write_error_stats,
)

import subprocess as sp
import os

LOG_EPS = math.log(1e-10)


Expand Down Expand Up @@ -124,7 +124,7 @@ def get_parser():
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)

parser.add_argument(
"--lang-dir",
type=Path,
Expand Down Expand Up @@ -449,14 +449,14 @@ def decode_one_chunk(
feature_lens = []
states = []
processed_lens = [] # Used in fast-beam-search

for stream in decode_streams:
feat, feat_len = stream.get_feature_frames(chunk_size * 2)
features.append(feat)
feature_lens.append(feat_len)
states.append(stream.states)
processed_lens.append(stream.done_frames)

feature_lens = torch.tensor(feature_lens, device=model.device)
features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS)

Expand Down Expand Up @@ -518,9 +518,9 @@ def decode_one_chunk(
decode_streams[i].states = states[i]
decode_streams[i].done_frames += encoder_out_lens[i]
# if decode_streams[i].done:
# finished_streams.append(i)
# finished_streams.append(i)
finished_streams.append(i)

return finished_streams


Expand Down Expand Up @@ -628,29 +628,28 @@ def decode_dataset(
)
# print('INSIDE FOR LOOP ')
# print(finished_streams)

if not finished_streams:
print("No finished streams, breaking the loop")
break



for i in sorted(finished_streams, reverse=True):
try:
try:
decode_results.append(
(
decode_streams[i].id,
decode_streams[i].ground_truth.split(),
sp.decode(decode_streams[i].decoding_result()).split(),
)
)
)
del decode_streams[i]
except IndexError as e:
print(f"IndexError: {e}")
print(f"decode_streams length: {len(decode_streams)}")
print(f"finished_streams: {finished_streams}")
print(f"i: {i}")
continue

if params.decoding_method == "greedy_search":
key = "greedy_search"
elif params.decoding_method == "fast_beam_search":
Expand All @@ -663,7 +662,7 @@ def decode_dataset(
key = f"num_active_paths_{params.num_active_paths}"
else:
raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
torch.cuda.synchronize()
torch.cuda.synchronize()
return {key: decode_results}


Expand Down Expand Up @@ -854,11 +853,11 @@ def main():

num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")

# we need cut ids to display recognition results.
args.return_cuts = True
reazonspeech_corpus = ReazonSpeechAsrDataModule(args)

valid_cuts = reazonspeech_corpus.valid_cuts()
test_cuts = reazonspeech_corpus.test_cuts()

Expand All @@ -878,9 +877,9 @@ def main():
test_set_name=test_set,
results_dict=results_dict,
)

# valid_cuts = reazonspeech_corpus.valid_cuts()

# for valid_cut in valid_cuts:
# results_dict = decode_dataset(
# cuts=valid_cut,
Expand All @@ -894,7 +893,7 @@ def main():
# test_set_name="valid",
# results_dict=results_dict,
# )

logging.info("Done!")


Expand Down

0 comments on commit 814d3ac

Please sign in to comment.