diff --git a/README.md b/README.md
index a09df9a7a0..27d7638956 100644
--- a/README.md
+++ b/README.md
@@ -9,6 +9,7 @@
 | [**Runtime**](https://github.com/wenet-e2e/wenet/tree/main/runtime)
 | [**Pretrained Models**](docs/pretrained_models.md)
 | [**HuggingFace**](https://huggingface.co/spaces/wenet/wenet_demo)
+| [**Ask WeNet Guru**](https://gurubase.io/g/wenet)

 **We** share **Net** together.

diff --git a/runtime/core/cmake/libtorch.cmake b/runtime/core/cmake/libtorch.cmake
index 07d0157600..a31f43ca6d 100644
--- a/runtime/core/cmake/libtorch.cmake
+++ b/runtime/core/cmake/libtorch.cmake
@@ -1,5 +1,5 @@
 if(TORCH)
-  set(TORCH_VERSION "2.1.0")
+  set(TORCH_VERSION "2.2.0")
   add_definitions(-DUSE_TORCH)
   if(NOT ANDROID)
     if(GPU)
@@ -13,32 +13,37 @@ if(TORCH)
     if(${CMAKE_SYSTEM_NAME} STREQUAL "Windows")
       if(${CMAKE_BUILD_TYPE} MATCHES "Release")
         set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cpu/libtorch-win-shared-with-deps-${TORCH_VERSION}%2Bcpu.zip")
-        set(URL_HASH "SHA256=77815aa799f15e91b6fbb0216ac78cc0479adb5cd0ca662072241484cf23f667")
+        set(URL_HASH "SHA256=96bc833184a7c13a088a2a83cab5a2be853c0c9d9f972740a50580173d0c796d")
       else()
         set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cpu/libtorch-win-shared-with-deps-debug-${TORCH_VERSION}%2Bcpu.zip")
-        set(URL_HASH "SHA256=5f887c02d9abf805c8b53fef89bf5a4dab9dd78771754344e73c98d9c484aa9d")
+        set(URL_HASH "SHA256=5b7dbabbecd86051b800ce0a244f15b89e9de0f8b5370e5fa65668aa37ecb878")
       endif()
     elseif(${CMAKE_SYSTEM_NAME} STREQUAL "Linux")
       if(CXX11_ABI)
         if(NOT GPU)
           set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-${TORCH_VERSION}%2Bcpu.zip")
-          set(URL_HASH "SHA256=04f699d5181048b0062ef52de1df44b46859b8fbeeee12abdbcb9aac63e2a14b")
+          set(URL_HASH "SHA256=62cd3001a2886d2db125aabc3be5c4fb66b3e34b32727d84323968f507ee8e32")
         else()
           set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cu118/libtorch-cxx11-abi-shared-with-deps-${TORCH_VERSION}%2Bcu118.zip")
-          set(URL_HASH "SHA256=7796249faa9828a53b72d3f616fc97a1d9e87e6a35ac72b392ca1ddc7b125188")
+          set(URL_HASH "SHA256=a2b0f51ff59ef2787a82c36bba67f7380236a6384dbbd2459c558989af27184f")
         endif()
       else()
         if(NOT GPU)
           set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cpu/libtorch-shared-with-deps-${TORCH_VERSION}%2Bcpu.zip")
-          set(URL_HASH "SHA256=0e86d364d05b83c6c66c3bb32e7eee932847843e4085487eefd9b3bbde4e2c58")
+          set(URL_HASH "SHA256=e1f6bc48403022ff4680c7299cc8b160df146892c414b8a6b6f7d5aff65bcbce")
         else()
           set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cu118/libtorch-shared-with-deps-${TORCH_VERSION}%2Bcu118.zip")
-          set(URL_HASH "SHA256=f70cfae25b02ff419e1d51ad137a746941773d2c4b0155a44b4b6b50702d661a")
+          set(URL_HASH "SHA256=f9c887085207f9500357cae4324a53c3010b8890397db915d7dbefb9183c7964")
         endif()
       endif()
     elseif(${CMAKE_SYSTEM_NAME} STREQUAL "Darwin")
-      set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cpu/libtorch-macos-${TORCH_VERSION}.zip")
-      set(URL_HASH "SHA256=ce744d2d27a96df8f34d4227e8b1179dad5a76612dc7230b61db65affce6e7bd")
+      if(CMAKE_SYSTEM_PROCESSOR MATCHES "arm64")
+        set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cpu/libtorch-macos-arm64-${TORCH_VERSION}.zip")
+        set(URL_HASH "SHA256=a2ac530e5db2f5be33fe7f7e3049b9a525ee60b110dbb1e08835e22002756ed1")
+      else()
+        set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cpu/libtorch-macos-x86_64-${TORCH_VERSION}.zip")
+        set(URL_HASH "SHA256=300940c6b1d4402ece72d31cd5694d9579dcfb23b7cf6b05676006411f9b516c")
+      endif()
     elseif(${CMAKE_SYSTEM_NAME} STREQUAL "iOS")
       add_definitions(-DIOS)
     else()
diff --git a/runtime/horizonbpu/.gitignore b/runtime/horizonbpu/.gitignore
index c6767241c3..9219231945 100644
--- a/runtime/horizonbpu/.gitignore
+++ b/runtime/horizonbpu/.gitignore
@@ -1,2 +1,3 @@
 build/
 fc_base/
+wheels*
diff --git a/runtime/horizonbpu/CMakeLists.txt b/runtime/horizonbpu/CMakeLists.txt
index 9e17900619..3d3ff62991 100644
--- a/runtime/horizonbpu/CMakeLists.txt
+++ b/runtime/horizonbpu/CMakeLists.txt
@@ -37,6 +37,8 @@ include_directories(
   ${CMAKE_CURRENT_SOURCE_DIR}/kaldi
 )

+include(wetextprocessing)
+
 # Build all libraries
 add_subdirectory(utils)
 add_subdirectory(frontend)
diff --git a/wenet/bin/export_onnx_gpu.py b/wenet/bin/export_onnx_gpu.py
index ab6c1dbe05..e8f3fbf76f 100644
--- a/wenet/bin/export_onnx_gpu.py
+++ b/wenet/bin/export_onnx_gpu.py
@@ -15,14 +15,13 @@
 from __future__ import print_function

 import argparse
+import logging
 import os
 import sys

 import torch
-import yaml
-import logging
-
 import torch.nn.functional as F
+import yaml
 from wenet.transformer.ctc import CTC
 from wenet.transformer.decoder import TransformerDecoder
 from wenet.transformer.encoder import BaseEncoder
@@ -169,15 +168,19 @@ def forward(self, chunk_xs, chunk_lens, offset, att_cache, cnn_cache,
         r_att_cache = []
         r_cnn_cache = []
         for i, layer in enumerate(self.encoder.encoders):
-            xs, _, new_att_cache, new_cnn_cache = layer(
+            i_kv_cache = att_cache[i]
+            size = att_cache.size(-1) // 2
+            kv_cache = (i_kv_cache[:, :, :, :size], i_kv_cache[:, :, :, size:])
+            xs, _, new_kv_cache, new_cnn_cache = layer(
                 xs,
                 masks,
                 pos_emb,
-                att_cache=att_cache[i],
+                att_cache=kv_cache,
                 cnn_cache=cnn_cache[i],
             )
             # shape(new_att_cache) is (B, head, attention_key_size, d_k * 2),
             # shape(new_cnn_cache) is (B, hidden-dim, cache_t2)
+            new_att_cache = torch.cat(new_kv_cache, dim=-1)
             r_att_cache.append(
                 new_att_cache[:, :, next_cache_start:, :].unsqueeze(1))
             if not self.transformer:
@@ -1241,8 +1244,8 @@ def export_rescoring_decoder(model, configs, args, logger, decoder_onnx_path,
     if args.fp16:
         try:
             import onnxmltools
-            from onnxmltools.utils.float16_converter import (
-                convert_float_to_float16, )
+            from onnxmltools.utils.float16_converter import \
+                convert_float_to_float16
         except ImportError:
             print("Please install onnxmltools!")
             sys.exit(1)
diff --git a/wenet/cli/hub.py b/wenet/cli/hub.py
index b8ca91ad5a..171e334faa 100644
--- a/wenet/cli/hub.py
+++ b/wenet/cli/hub.py
@@ -13,12 +13,12 @@
 # limitations under the License.
 import os
-import requests
 import sys
 import tarfile
 from pathlib import Path
 from urllib.request import urlretrieve

+import requests
 import tqdm
@@ -77,7 +77,9 @@ class Hub(object):
         # gigaspeech
         "english": "gigaspeech_u2pp_conformer_libtorch.tar.gz",
         # paraformer
-        "paraformer": "paraformer.tar.gz"
+        "paraformer": "paraformer.tar.gz",
+        # punc
+        "punc": "punc.tar.gz"
     }

     def __init__(self) -> None:
diff --git a/wenet/cli/paraformer_model.py b/wenet/cli/paraformer_model.py
index a4f834ab25..81d04091b6 100644
--- a/wenet/cli/paraformer_model.py
+++ b/wenet/cli/paraformer_model.py
@@ -1,14 +1,14 @@
+import io
 import os
+from typing import Dict, List, Union

 import torch
 import torchaudio
 import torchaudio.compliance.kaldi as kaldi
-
 from wenet.cli.hub import Hub
 from wenet.paraformer.search import (gen_timestamps_from_peak,
                                      paraformer_greedy_search)
 from wenet.text.paraformer_tokenizer import ParaformerTokenizer
-from wenet.utils.common import TORCH_NPU_AVAILABLE  # noqa just ensure to check torch-npu


 class Paraformer:
@@ -22,46 +22,74 @@ def __init__(self, model_dir: str, resample_rate: int = 16000) -> None:
         self.device = torch.device("cpu")
         self.tokenizer = ParaformerTokenizer(symbol_table=units_path)

-    def transcribe(self, audio_file: str, tokens_info: bool = False) -> dict:
-        waveform, sample_rate = torchaudio.load(audio_file, normalize=False)
-        waveform = waveform.to(torch.float).to(self.device)
-        if sample_rate != self.resample_rate:
-            waveform = torchaudio.transforms.Resample(
-                orig_freq=sample_rate, new_freq=self.resample_rate)(waveform)
-        feats = kaldi.fbank(waveform,
-                            num_mel_bins=80,
-                            frame_length=25,
-                            frame_shift=10,
-                            energy_floor=0.0,
-                            sample_frequency=self.resample_rate,
-                            window_type="hamming")
-        feats = feats.unsqueeze(0)
-        feats_lens = torch.tensor([feats.size(1)],
-                                  dtype=torch.int64,
-                                  device=feats.device)
+    @torch.inference_mode()
+    def transcribe_batch(self,
+                         audio_files: List[Union[str, bytes]],
+                         tokens_info: bool = False) -> List[Dict]:
+        feats_lst = []
+        feats_lens_lst = []
+        for audio in audio_files:
+            if isinstance(audio, bytes):
+                with io.BytesIO(audio) as fobj:
+                    waveform, sample_rate = torchaudio.load(fobj,
+                                                            normalize=False)
+            else:
+                waveform, sample_rate = torchaudio.load(audio, normalize=False)
+            if sample_rate != self.resample_rate:
+                waveform = torchaudio.transforms.Resample(
+                    orig_freq=sample_rate,
+                    new_freq=self.resample_rate)(waveform)
+
+            waveform = waveform.to(torch.float)
+            feats = kaldi.fbank(waveform,
+                                num_mel_bins=80,
+                                frame_length=25,
+                                frame_shift=10,
+                                energy_floor=0.0,
+                                sample_frequency=self.resample_rate,
+                                window_type="hamming")
+            feats_lst.append(feats)
+            feats_lens_lst.append(
+                torch.tensor(feats.shape[0], dtype=torch.int64))
+        feats_tensor = torch.nn.utils.rnn.pad_sequence(
+            feats_lst, batch_first=True).to(device=self.device)
+        feats_lens_tensor = torch.tensor(feats_lens_lst, device=self.device)

-        decoder_out, token_num, tp_alphas = self.model.forward_paraformer(
-            feats, feats_lens)
+        decoder_out, token_num, tp_alphas, frames = self.model.forward_paraformer(
+            feats_tensor, feats_lens_tensor)
+        frames = frames.cpu().numpy()
         cif_peaks = self.model.forward_cif_peaks(tp_alphas, token_num)

-        res = paraformer_greedy_search(decoder_out, token_num, cif_peaks)[0]
-        result = {}
-        result['confidence'] = res.confidence
-        result['text'] = self.tokenizer.detokenize(res.tokens)[0]
-        if tokens_info:
-            tokens_info = []
-            times = gen_timestamps_from_peak(res.times,
-                                             num_frames=tp_alphas.size(1),
-                                             frame_rate=0.02)
-            for i, x in enumerate(res.tokens):
-                tokens_info.append({
-                    'token': self.tokenizer.char_dict[x],
-                    'start': round(times[i][0], 3),
-                    'end': round(times[i][1], 3),
-                    'confidence': round(res.tokens_confidence[i], 2)
-                })
-            result['tokens'] = tokens_info
+        results = paraformer_greedy_search(decoder_out, token_num, cif_peaks)
+        r = []
+        for (i, res) in enumerate(results):
+            result = {}
+            result['confidence'] = res.confidence
+            result['text'] = self.tokenizer.detokenize(res.tokens)[0]
+            if tokens_info:
+                tokens_info_l = []
+                times = gen_timestamps_from_peak(res.times,
+                                                 num_frames=frames[i],
+                                                 frame_rate=0.02)
+
+                for i, x in enumerate(res.tokens[:len(times)]):
+                    tokens_info_l.append({
+                        'token':
+                        self.tokenizer.char_dict[x],
+                        'start':
+                        round(times[i][0], 3),
+                        'end':
+                        round(times[i][1], 3),
+                        'confidence':
+                        round(res.tokens_confidence[i], 2)
+                    })
+                result['tokens'] = tokens_info_l
+            r.append(result)
+        return r
+
+    def transcribe(self, audio_file: str, tokens_info: bool = False) -> dict:
+        result = self.transcribe_batch([audio_file], tokens_info)[0]
         return result

     def align(self, audio_file: str, label: str) -> dict:
diff --git a/wenet/cli/punc_model.py b/wenet/cli/punc_model.py
new file mode 100644
index 0000000000..cab6ca73ab
--- /dev/null
+++ b/wenet/cli/punc_model.py
@@ -0,0 +1,115 @@
+import os
+from typing import List
+
+import jieba
+import torch
+from wenet.cli.hub import Hub
+from wenet.paraformer.search import _isAllAlpha
+from wenet.text.char_tokenizer import CharTokenizer
+
+
+class PuncModel:
+
+    def __init__(self, model_dir: str) -> None:
+        self.model_dir = model_dir
+        model_path = os.path.join(model_dir, 'final.zip')
+        units_path = os.path.join(model_dir, 'units.txt')
+
+        self.model = torch.jit.load(model_path)
+        self.tokenizer = CharTokenizer(units_path)
+        self.device = torch.device("cpu")
+        self.use_jieba = False
+
+        self.punc_table = ['', '', ',', '。', '?', '、']
+
+    def split_words(self, text: str):
+        if not self.use_jieba:
+            self.use_jieba = True
+            import logging
+
+            # Disable jieba's logger
+            logging.getLogger('jieba').disabled = True
+            jieba.load_userdict(os.path.join(self.model_dir, 'jieba_usr_dict'))
+
+        result_list = []
+        tokens = text.split()
+        current_language = None
+        buffer = []
+
+        for token in tokens:
+            is_english = token.isascii()
+            if is_english:
+                language = "English"
+            else:
+                language = "Chinese"
+
+            if current_language and language != current_language:
+                if current_language == "Chinese":
+                    result_list.extend(jieba.cut(''.join(buffer), HMM=False))
+                else:
+                    result_list.extend(buffer)
+                buffer = []
+
+            buffer.append(token)
+            current_language = language
+
+        if buffer:
+            if current_language == "Chinese":
+                result_list.extend(jieba.cut(''.join(buffer), HMM=False))
+            else:
+                result_list.extend(buffer)
+
+        return result_list
+
+    def add_punc_batch(self, texts: List[str]):
+        batch_text_words = []
+        batch_text_ids = []
+        batch_text_lens = []
+
+        for text in texts:
+            words = self.split_words(text)
+            ids = self.tokenizer.tokens2ids(words)
+            batch_text_words.append(words)
+            batch_text_ids.append(ids)
+            batch_text_lens.append(len(ids))
+
+        texts_tensor = torch.tensor(batch_text_ids,
+                                    device=self.device,
+                                    dtype=torch.int64)
+        texts_lens_tensor = torch.tensor(batch_text_lens,
+                                         device=self.device,
+                                         dtype=torch.int64)
+
+        log_probs, _ = self.model(texts_tensor, texts_lens_tensor)
+        result = []
+        outs = log_probs.argmax(-1).cpu().numpy()
+        for i, out in enumerate(outs):
+            punc_id = out[:batch_text_lens[i]]
+            sentence = ''
+            for j, word in enumerate(batch_text_words[i]):
+                if _isAllAlpha(word):
+                    word = '▁' + word
+                word += self.punc_table[punc_id[j]]
+                sentence += word
+            result.append(sentence.replace('▁', ' '))
+        return result
+
+    def __call__(self, text: str):
+        if text != '':
+            r = self.add_punc_batch([text])[0]
+            return r
+        return ''
+
+
+def load_model(model_dir: str = None,
+               gpu: int = -1,
+               device: str = "cpu") -> PuncModel:
+    if model_dir is None:
+        model_dir = Hub.get_model_by_lang('punc')
+    if gpu != -1:
+        # retain the original gpu-argument behavior
+        device = "cuda"
+    punc = PuncModel(model_dir)
+    punc.device = torch.device(device)
+    punc.model.to(device)
+    return punc
diff --git a/wenet/cli/transcribe.py b/wenet/cli/transcribe.py
index 28bf279192..8d65447c05 100644
--- a/wenet/cli/transcribe.py
+++ b/wenet/cli/transcribe.py
@@ -14,8 +14,9 @@

 import argparse

-from wenet.cli.paraformer_model import load_model as load_paraformer
 from wenet.cli.model import load_model
+from wenet.cli.paraformer_model import load_model as load_paraformer
+from wenet.cli.punc_model import load_model as load_punc_model


 def get_args():
@@ -64,6 +65,13 @@
                         type=float,
                         default=6.0,
                         help='context score')
+    parser.add_argument('--punc', action='store_true', help='add punctuation')
+
+    parser.add_argument('-pm',
+                        '--punc_model_dir',
+                        default=None,
+                        help='specify your own punc model dir')
+
     args = parser.parse_args()
     return args
@@ -76,10 +84,17 @@ def main():
     else:
         model = load_model(args.language, args.model_dir, args.gpu, args.beam,
                            args.context_path, args.context_score, args.device)

+    punc_model = None
+    if args.punc:
+        punc_model = load_punc_model(args.punc_model_dir, args.gpu,
+                                     args.device)
     if args.align:
         result = model.align(args.audio_file, args.label)
     else:
         result = model.transcribe(args.audio_file, args.show_tokens_info)
+        if args.punc:
+            assert punc_model is not None
+            result['text_with_punc'] = punc_model(result['text'])

     print(result)
diff --git a/wenet/paraformer/layers.py b/wenet/paraformer/layers.py
index d17280d8a5..4c4a373906 100644
--- a/wenet/paraformer/layers.py
+++ b/wenet/paraformer/layers.py
@@ -3,18 +3,17 @@
 import math
 from typing import Optional, Tuple

-import torch
+import torch
 import torch.utils.checkpoint as ckpt
-
 from wenet.paraformer.attention import (DummyMultiHeadSANM,
                                         MultiHeadAttentionCross,
                                         MultiHeadedAttentionSANM)
 from wenet.paraformer.embedding import ParaformerPositinoalEncoding
 from wenet.paraformer.subsampling import IdentitySubsampling
-from wenet.transformer.encoder import BaseEncoder
 from wenet.transformer.decoder import TransformerDecoder
 from wenet.transformer.decoder_layer import DecoderLayer
+from wenet.transformer.encoder import BaseEncoder
 from wenet.transformer.encoder_layer import TransformerEncoderLayer
 from wenet.transformer.positionwise_feed_forward import PositionwiseFeedForward
 from wenet.utils.mask import make_non_pad_mask
@@ -190,7 +189,7 @@ def __init__(
         num_blocks: int = 6,
         dropout_rate: float = 0.1,
         positional_dropout_rate: float = 0.1,
-        attention_dropout_rate: float = 0,
+        attention_dropout_rate: float = 0.0,
         input_layer: str = "conv2d",
         pos_enc_layer_type: str = "abs_pos",
         normalize_before: bool = True,
@@ -389,8 +388,8 @@ def __init__(
         num_blocks: int = 6,
         dropout_rate: float = 0.1,
         positional_dropout_rate: float = 0.1,
-        self_attention_dropout_rate: float = 0,
-        src_attention_dropout_rate: float = 0,
+        self_attention_dropout_rate: float = 0.0,
+        src_attention_dropout_rate: float = 0.0,
         input_layer: str = "embed",
         use_output_layer: bool = True,
         normalize_before: bool = True,
diff --git a/wenet/paraformer/paraformer.py b/wenet/paraformer/paraformer.py
index be19f15b49..7824225640 100644
--- a/wenet/paraformer/paraformer.py
+++ b/wenet/paraformer/paraformer.py
@@ -19,9 +19,7 @@
 import torch

 from wenet.paraformer.cif import Cif, cif_without_hidden
-
-from wenet.paraformer.layers import SanmDecoder, SanmEncoder
-from wenet.paraformer.layers import LFR
+from wenet.paraformer.layers import LFR, SanmDecoder, SanmEncoder
 from wenet.paraformer.search import (paraformer_beam_search,
                                      paraformer_greedy_search)
 from wenet.transformer.asr_model import ASRModel
@@ -99,7 +97,8 @@ def forward(self,
         tp_alphas = tp_alphas.squeeze(-1)
         tp_token_num = tp_alphas.sum(-1)

-        return acoustic_embeds, token_num, alphas, cif_peak, tp_alphas, tp_token_num
+        return acoustic_embeds, token_num, alphas, cif_peak, tp_alphas, \
+            tp_token_num, mask


 class Paraformer(ASRModel):
@@ -170,7 +169,7 @@ def forward(
         if self.add_eos:
             _, ys_pad = add_sos_eos(text, self.sos, self.eos, self.ignore_id)
             ys_pad_lens = text_lengths + 1
-        acoustic_embd, token_num, _, _, _, tp_token_num = self.predictor(
+        acoustic_embd, token_num, _, _, _, tp_token_num, _ = self.predictor(
             encoder_out, ys_pad, encoder_out_mask, self.ignore_id)

         # 2 decoder with sampler
@@ -295,9 +294,10 @@ def forward_paraformer(
         self,
         speech: torch.Tensor,
         speech_lengths: torch.Tensor,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         res = self._forward_paraformer(speech, speech_lengths)
-        return res['decoder_out'], res['decoder_out_lens'], res['tp_alphas']
+        return res['decoder_out'], res['decoder_out_lens'], res[
+            'tp_alphas'], res['tp_mask'].sum(1).squeeze(-1)

     @torch.jit.export
     def forward_encoder_chunk(
@@ -336,8 +336,10 @@ def _forward_paraformer(
             num_decoding_left_chunks)

         # cif predictor
-        acoustic_embed, token_num, _, _, tp_alphas, _ = self.predictor(
-            encoder_out, mask=encoder_out_mask)
+        acoustic_embed, token_num, _, _, tp_alphas, _, tp_mask = self.predictor(
+            encoder_out,
+            mask=encoder_out_mask,
+        )
         token_num = token_num.floor().to(speech_lengths.dtype)

         # decoder
@@ -350,7 +352,8 @@ def _forward_paraformer(
             "encoder_out_mask": encoder_out_mask,
             "decoder_out": decoder_out,
             "tp_alphas": tp_alphas,
-            "decoder_out_lens": token_num
+            "decoder_out_lens": token_num,
+            "tp_mask": tp_mask
         }

     def decode(
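
For readers of this patch, a minimal usage sketch of the Python-level API it introduces, illustrative only and not part of the diff: the model directory and wav paths are placeholders, and the sketch assumes the "punc" asset wired into hub.py downloads correctly.

# Hypothetical usage sketch; paths below are placeholders, not from the patch.
from wenet.cli.paraformer_model import Paraformer
from wenet.cli.punc_model import load_model as load_punc_model

asr = Paraformer("exp/paraformer")              # local Paraformer model dir (placeholder)
punc = load_punc_model(model_dir=None, gpu=-1)  # None -> fetch the new "punc" model via Hub

# Batched decoding returns a List[Dict] with 'text', 'confidence' (and 'tokens' if requested).
results = asr.transcribe_batch(["a.wav", "b.wav"], tokens_info=True)
for r in results:
    # Mirrors the CLI wiring in transcribe.py: attach punctuated text to the result.
    r['text_with_punc'] = punc(r['text'])
    print(r)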