From 1fe749f3c54ef68e29b9e49f2397ea09b7d74134 Mon Sep 17 00:00:00 2001 From: pengzhendong <275331498@qq.com> Date: Wed, 7 Aug 2024 16:15:59 +0800 Subject: [PATCH 1/3] [utils] use force_align of torchaudio --- wenet/utils/ctc_utils.py | 52 ++++------------------------------------ 1 file changed, 5 insertions(+), 47 deletions(-) diff --git a/wenet/utils/ctc_utils.py b/wenet/utils/ctc_utils.py index 718d42926e..99751f34cc 100644 --- a/wenet/utils/ctc_utils.py +++ b/wenet/utils/ctc_utils.py @@ -17,6 +17,7 @@ import numpy as np import torch +import torchaudio.functional as F def remove_duplicates_and_blank(hyp: List[int], @@ -112,53 +113,10 @@ def force_align(ctc_probs: torch.Tensor, y: torch.Tensor, blank_id=0) -> list: Returns: torch.Tensor: alignment result """ - ctc_probs = ctc_probs.cpu() - y = y.cpu() - y_insert_blank = insert_blank(y, blank_id) - - log_alpha = torch.zeros((ctc_probs.size(0), len(y_insert_blank))) - log_alpha = log_alpha - float('inf') # log of zero - state_path = torch.zeros((ctc_probs.size(0), len(y_insert_blank)), - dtype=torch.int16) - 1 # state path - - # init start state - log_alpha[0, 0] = ctc_probs[0][y_insert_blank[0]] - log_alpha[0, 1] = ctc_probs[0][y_insert_blank[1]] - - for t in range(1, ctc_probs.size(0)): - for s in range(len(y_insert_blank)): - if y_insert_blank[s] == blank_id or s < 2 or y_insert_blank[ - s] == y_insert_blank[s - 2]: - candidates = torch.tensor( - [log_alpha[t - 1, s], log_alpha[t - 1, s - 1]]) - prev_state = [s, s - 1] - else: - candidates = torch.tensor([ - log_alpha[t - 1, s], - log_alpha[t - 1, s - 1], - log_alpha[t - 1, s - 2], - ]) - prev_state = [s, s - 1, s - 2] - log_alpha[ - t, s] = torch.max(candidates) + ctc_probs[t][y_insert_blank[s]] - state_path[t, s] = prev_state[torch.argmax(candidates)] - - state_seq = -1 * torch.ones((ctc_probs.size(0), 1), dtype=torch.int16) - - candidates = torch.tensor([ - log_alpha[-1, len(y_insert_blank) - 1], - log_alpha[-1, len(y_insert_blank) - 2] - ]) - final_state = [len(y_insert_blank) - 1, len(y_insert_blank) - 2] - state_seq[-1] = final_state[torch.argmax(candidates)] - for t in range(ctc_probs.size(0) - 2, -1, -1): - state_seq[t] = state_path[t + 1, state_seq[t + 1, 0]] - - output_alignment = [] - for t in range(0, ctc_probs.size(0)): - output_alignment.append(y_insert_blank[state_seq[t, 0]]) - - return output_alignment + ctc_probs = ctc_probs[None].cpu() + y = y[None].cpu() + alignments, _ = F.forced_align(ctc_probs, y, blank=blank_id) + return alignments[0] def get_blank_id(configs, symbol_table): From ff0e57e7d1b1534ae869c8cf0dc6024e5d2cb2ab Mon Sep 17 00:00:00 2001 From: pengzhendong <275331498@qq.com> Date: Wed, 7 Aug 2024 16:16:45 +0800 Subject: [PATCH 2/3] rounded outputs of force_align --- wenet/cli/model.py | 7 ++++--- wenet/cli/paraformer_model.py | 6 +++--- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/wenet/cli/model.py b/wenet/cli/model.py index d564870a53..0acab4987a 100644 --- a/wenet/cli/model.py +++ b/wenet/cli/model.py @@ -1,3 +1,4 @@ + # Copyright (c) 2023 Binbin Zhang (binbzha@qq.com) # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -128,9 +129,9 @@ def _decode(self, for i, x in enumerate(res.tokens): tokens_info.append({ 'token': self.char_dict[x], - 'start': times[i][0], - 'end': times[i][1], - 'confidence': res.tokens_confidence[i] + 'start': round(times[i][0], 3), + 'end': round(times[i][1], 3), + 'confidence': round(res.tokens_confidence[i], 2) }) result['tokens'] = tokens_info return result diff --git a/wenet/cli/paraformer_model.py b/wenet/cli/paraformer_model.py index 4f77758f2d..a4f834ab25 100644 --- a/wenet/cli/paraformer_model.py +++ b/wenet/cli/paraformer_model.py @@ -56,9 +56,9 @@ def transcribe(self, audio_file: str, tokens_info: bool = False) -> dict: for i, x in enumerate(res.tokens): tokens_info.append({ 'token': self.tokenizer.char_dict[x], - 'start': times[i][0], - 'end': times[i][1], - 'confidence': res.tokens_confidence[i] + 'start': round(times[i][0], 3), + 'end': round(times[i][1], 3), + 'confidence': round(res.tokens_confidence[i], 2) }) result['tokens'] = tokens_info From 10b4ae83cec9017ed509adf2b9c57fb3c1ea599a Mon Sep 17 00:00:00 2001 From: pengzhendong <275331498@qq.com> Date: Wed, 7 Aug 2024 16:28:33 +0800 Subject: [PATCH 3/3] remove empty line --- wenet/cli/model.py | 1 - 1 file changed, 1 deletion(-) diff --git a/wenet/cli/model.py b/wenet/cli/model.py index 0acab4987a..bb24bdb337 100644 --- a/wenet/cli/model.py +++ b/wenet/cli/model.py @@ -1,4 +1,3 @@ - # Copyright (c) 2023 Binbin Zhang (binbzha@qq.com) # # Licensed under the Apache License, Version 2.0 (the "License");