Commit 59cf948

Merge branch 'wenet-e2e:main' into UpdateONNXRuntimeVersion
yangzhengzhe authored Dec 27, 2024
2 parents 41aae43 + d00940f commit 59cf948
Showing 11 changed files with 247 additions and 73 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -9,6 +9,7 @@
| [**Runtime**](https://github.com/wenet-e2e/wenet/tree/main/runtime)
| [**Pretrained Models**](docs/pretrained_models.md)
| [**HuggingFace**](https://huggingface.co/spaces/wenet/wenet_demo)
| [**Ask WeNet Guru**](https://gurubase.io/g/wenet)

**We** share **Net** together.

23 changes: 14 additions & 9 deletions runtime/core/cmake/libtorch.cmake
@@ -1,5 +1,5 @@
if(TORCH)
set(TORCH_VERSION "2.1.0")
set(TORCH_VERSION "2.2.0")
add_definitions(-DUSE_TORCH)
if(NOT ANDROID)
if(GPU)
@@ -13,32 +13,37 @@ if(TORCH)
if(${CMAKE_SYSTEM_NAME} STREQUAL "Windows")
if(${CMAKE_BUILD_TYPE} MATCHES "Release")
set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cpu/libtorch-win-shared-with-deps-${TORCH_VERSION}%2Bcpu.zip")
set(URL_HASH "SHA256=77815aa799f15e91b6fbb0216ac78cc0479adb5cd0ca662072241484cf23f667")
set(URL_HASH "SHA256=96bc833184a7c13a088a2a83cab5a2be853c0c9d9f972740a50580173d0c796d")
else()
set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cpu/libtorch-win-shared-with-deps-debug-${TORCH_VERSION}%2Bcpu.zip")
set(URL_HASH "SHA256=5f887c02d9abf805c8b53fef89bf5a4dab9dd78771754344e73c98d9c484aa9d")
set(URL_HASH "SHA256=5b7dbabbecd86051b800ce0a244f15b89e9de0f8b5370e5fa65668aa37ecb878")
endif()
elseif(${CMAKE_SYSTEM_NAME} STREQUAL "Linux")
if(CXX11_ABI)
if(NOT GPU)
set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-${TORCH_VERSION}%2Bcpu.zip")
set(URL_HASH "SHA256=04f699d5181048b0062ef52de1df44b46859b8fbeeee12abdbcb9aac63e2a14b")
set(URL_HASH "SHA256=62cd3001a2886d2db125aabc3be5c4fb66b3e34b32727d84323968f507ee8e32")
else()
set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cu118/libtorch-cxx11-abi-shared-with-deps-${TORCH_VERSION}%2Bcu118.zip")
set(URL_HASH "SHA256=7796249faa9828a53b72d3f616fc97a1d9e87e6a35ac72b392ca1ddc7b125188")
set(URL_HASH "SHA256=a2b0f51ff59ef2787a82c36bba67f7380236a6384dbbd2459c558989af27184f")
endif()
else()
if(NOT GPU)
set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cpu/libtorch-shared-with-deps-${TORCH_VERSION}%2Bcpu.zip")
set(URL_HASH "SHA256=0e86d364d05b83c6c66c3bb32e7eee932847843e4085487eefd9b3bbde4e2c58")
set(URL_HASH "SHA256=e1f6bc48403022ff4680c7299cc8b160df146892c414b8a6b6f7d5aff65bcbce")
else()
set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cu118/libtorch-shared-with-deps-${TORCH_VERSION}%2Bcu118.zip")
set(URL_HASH "SHA256=f70cfae25b02ff419e1d51ad137a746941773d2c4b0155a44b4b6b50702d661a")
set(URL_HASH "SHA256=f9c887085207f9500357cae4324a53c3010b8890397db915d7dbefb9183c7964")
endif()
endif()
elseif(${CMAKE_SYSTEM_NAME} STREQUAL "Darwin")
set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cpu/libtorch-macos-${TORCH_VERSION}.zip")
set(URL_HASH "SHA256=ce744d2d27a96df8f34d4227e8b1179dad5a76612dc7230b61db65affce6e7bd")
if(CMAKE_SYSTEM_PROCESSOR MATCHES "arm64")
set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cpu/libtorch-macos-arm64-${TORCH_VERSION}.zip")
set(URL_HASH "SHA256=a2ac530e5db2f5be33fe7f7e3049b9a525ee60b110dbb1e08835e22002756ed1")
else()
set(LIBTORCH_URL "https://download.pytorch.org/libtorch/cpu/libtorch-macos-x86_64-${TORCH_VERSION}.zip")
set(URL_HASH "SHA256=300940c6b1d4402ece72d31cd5694d9579dcfb23b7cf6b05676006411f9b516c")
endif()
elseif(${CMAKE_SYSTEM_NAME} STREQUAL "iOS")
add_definitions(-DIOS)
else()
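The refreshed hashes above exist so the build fails fast on a corrupted or tampered archive when the version bumps to 2.2.0. A minimal Python sketch of the same integrity check that CMake's URL_HASH performs; the URL and digest are copied from the Linux CPU cxx11-abi branch of the diff, and the helper name is ours:

```python
import hashlib
import urllib.request

# URL and SHA256 for the Linux CPU cxx11-abi build of libtorch 2.2.0,
# copied from the updated libtorch.cmake above.
LIBTORCH_URL = ("https://download.pytorch.org/libtorch/cpu/"
                "libtorch-cxx11-abi-shared-with-deps-2.2.0%2Bcpu.zip")
EXPECTED_SHA256 = "62cd3001a2886d2db125aabc3be5c4fb66b3e34b32727d84323968f507ee8e32"


def download_and_verify(url: str, expected: str, out_path: str) -> None:
    """Download an archive and reject it if its SHA256 digest mismatches."""
    urllib.request.urlretrieve(url, out_path)
    digest = hashlib.sha256()
    with open(out_path, "rb") as f:
        # Hash in 1 MiB chunks to keep memory flat for large archives.
        for chunk in iter(lambda: f.read(1 << 20), b""):
            digest.update(chunk)
    if digest.hexdigest() != expected:
        raise RuntimeError(f"SHA256 mismatch: got {digest.hexdigest()}")
```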
1 change: 1 addition & 0 deletions runtime/horizonbpu/.gitignore
@@ -1,2 +1,3 @@
build/
fc_base/
wheels*
2 changes: 2 additions & 0 deletions runtime/horizonbpu/CMakeLists.txt
@@ -37,6 +37,8 @@ include_directories(
${CMAKE_CURRENT_SOURCE_DIR}/kaldi
)

include(wetextprocessing)

# Build all libraries
add_subdirectory(utils)
add_subdirectory(frontend)
17 changes: 10 additions & 7 deletions wenet/bin/export_onnx_gpu.py
@@ -15,14 +15,13 @@
from __future__ import print_function

import argparse
import logging
import os
import sys

import torch
import yaml
import logging

import torch.nn.functional as F
import yaml
from wenet.transformer.ctc import CTC
from wenet.transformer.decoder import TransformerDecoder
from wenet.transformer.encoder import BaseEncoder
@@ -169,15 +168,19 @@ def forward(self, chunk_xs, chunk_lens, offset, att_cache, cnn_cache,
r_att_cache = []
r_cnn_cache = []
for i, layer in enumerate(self.encoder.encoders):
xs, _, new_att_cache, new_cnn_cache = layer(
i_kv_cache = att_cache[i]
size = att_cache.size(-1) // 2
kv_cache = (i_kv_cache[:, :, :, :size], i_kv_cache[:, :, :, size:])
xs, _, new_kv_cache, new_cnn_cache = layer(
xs,
masks,
pos_emb,
att_cache=att_cache[i],
att_cache=kv_cache,
cnn_cache=cnn_cache[i],
)
# shape(new_att_cache) is (B, head, attention_key_size, d_k * 2),
# shape(new_cnn_cache) is (B, hidden-dim, cache_t2)
new_att_cache = torch.cat(new_kv_cache, dim=-1)
r_att_cache.append(
new_att_cache[:, :, next_cache_start:, :].unsqueeze(1))
if not self.transformer:
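The hunk above (truncated here by the diff view) stops passing the stacked per-layer cache straight through: it now splits att_cache[i] into a (key, value) tuple before the layer call and concatenates the returned pair back for the next chunk. A standalone sketch of that round trip, with made-up shapes (batch 1, 4 heads, 16 cached steps, d_k = 64):

```python
import torch

# Hypothetical cache layout, matching the comment in the diff:
# (B, head, attention_key_size, d_k * 2), keys and values stacked last.
att_cache_i = torch.zeros(1, 4, 16, 2 * 64)

# Split into the (key, value) tuple the layer call now expects...
size = att_cache_i.size(-1) // 2
kv_cache = (att_cache_i[..., :size], att_cache_i[..., size:])

# ...and re-stack the tuple the layer returns, so the ONNX graph keeps a
# single cache tensor per layer between chunks.
new_kv_cache = kv_cache  # stand-in for the layer's returned tuple
new_att_cache = torch.cat(new_kv_cache, dim=-1)
assert new_att_cache.shape == att_cache_i.shape
```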
@@ -1241,8 +1244,8 @@ def export_rescoring_decoder(model, configs, args, logger, decoder_onnx_path,
if args.fp16:
try:
import onnxmltools
from onnxmltools.utils.float16_converter import (
convert_float_to_float16, )
from onnxmltools.utils.float16_converter import \
convert_float_to_float16
except ImportError:
print("Please install onnxmltools!")
sys.exit(1)
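The reworked import above pulls in onnxmltools' float16 converter for the --fp16 path. A hedged sketch of how such an export step typically applies it; the file paths are placeholders, since the real script derives them from its own arguments:

```python
import sys

try:
    import onnxmltools
    from onnxmltools.utils.float16_converter import convert_float_to_float16
except ImportError:
    print("Please install onnxmltools!")
    sys.exit(1)

# Placeholder paths for an exported FP32 model and its FP16 counterpart.
fp32_path = "decoder.onnx"
fp16_path = "decoder_fp16.onnx"

# Load the FP32 ONNX graph, cast its float tensors to float16, and save.
onnx_model = onnxmltools.utils.load_model(fp32_path)
fp16_model = convert_float_to_float16(onnx_model)
onnxmltools.utils.save_model(fp16_model, fp16_path)
```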
6 changes: 4 additions & 2 deletions wenet/cli/hub.py
@@ -13,12 +13,12 @@
# limitations under the License.

import os
import requests
import sys
import tarfile
from pathlib import Path
from urllib.request import urlretrieve

import requests
import tqdm


@@ -77,7 +77,9 @@ class Hub(object):
# gigaspeech
"english": "gigaspeech_u2pp_conformer_libtorch.tar.gz",
# paraformer
"paraformer": "paraformer.tar.gz"
"paraformer": "paraformer.tar.gz",
# punc
"punc": "punc.tar.gz"
}

def __init__(self) -> None:
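With the new "punc" entry registered, the punctuation model can be fetched the same way as the ASR models. A small usage sketch, assuming Hub.get_model_by_lang downloads and unpacks the named archive and returns the local model directory, as its use in wenet/cli/punc_model.py below suggests:

```python
from wenet.cli.hub import Hub

# Fetch the newly registered punctuation model; the returned directory
# should contain final.zip and units.txt (see punc_model.py below).
model_dir = Hub.get_model_by_lang("punc")
print(model_dir)
```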
104 changes: 66 additions & 38 deletions wenet/cli/paraformer_model.py
@@ -1,14 +1,14 @@
import io
import os
from typing import Dict, List, Union

import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi

from wenet.cli.hub import Hub
from wenet.paraformer.search import (gen_timestamps_from_peak,
paraformer_greedy_search)
from wenet.text.paraformer_tokenizer import ParaformerTokenizer
from wenet.utils.common import TORCH_NPU_AVAILABLE # noqa just ensure to check torch-npu


class Paraformer:
@@ -22,46 +22,74 @@ def __init__(self, model_dir: str, resample_rate: int = 16000) -> None:
self.device = torch.device("cpu")
self.tokenizer = ParaformerTokenizer(symbol_table=units_path)

def transcribe(self, audio_file: str, tokens_info: bool = False) -> dict:
waveform, sample_rate = torchaudio.load(audio_file, normalize=False)
waveform = waveform.to(torch.float).to(self.device)
if sample_rate != self.resample_rate:
waveform = torchaudio.transforms.Resample(
orig_freq=sample_rate, new_freq=self.resample_rate)(waveform)
feats = kaldi.fbank(waveform,
num_mel_bins=80,
frame_length=25,
frame_shift=10,
energy_floor=0.0,
sample_frequency=self.resample_rate,
window_type="hamming")
feats = feats.unsqueeze(0)
feats_lens = torch.tensor([feats.size(1)],
dtype=torch.int64,
device=feats.device)
@torch.inference_mode()
def transcribe_batch(self,
audio_files: List[Union[str, bytes]],
tokens_info: bool = False) -> List[Dict]:
feats_lst = []
feats_lens_lst = []
for audio in audio_files:
if isinstance(audio, bytes):
with io.BytesIO(audio) as fobj:
waveform, sample_rate = torchaudio.load(fobj,
normalize=False)
else:
waveform, sample_rate = torchaudio.load(audio, normalize=False)
if sample_rate != self.resample_rate:
waveform = torchaudio.transforms.Resample(
orig_freq=sample_rate,
new_freq=self.resample_rate)(waveform)

waveform = waveform.to(torch.float)
feats = kaldi.fbank(waveform,
num_mel_bins=80,
frame_length=25,
frame_shift=10,
energy_floor=0.0,
sample_frequency=self.resample_rate,
window_type="hamming")
feats_lst.append(feats)
feats_lens_lst.append(
torch.tensor(feats.shape[0], dtype=torch.int64))
feats_tensor = torch.nn.utils.rnn.pad_sequence(
feats_lst, batch_first=True).to(device=self.device)
feats_lens_tensor = torch.tensor(feats_lens_lst, device=self.device)

decoder_out, token_num, tp_alphas = self.model.forward_paraformer(
feats, feats_lens)
decoder_out, token_num, tp_alphas, frames = self.model.forward_paraformer(
feats_tensor, feats_lens_tensor)
frames = frames.cpu().numpy()
cif_peaks = self.model.forward_cif_peaks(tp_alphas, token_num)
res = paraformer_greedy_search(decoder_out, token_num, cif_peaks)[0]
result = {}
result['confidence'] = res.confidence
result['text'] = self.tokenizer.detokenize(res.tokens)[0]
if tokens_info:
tokens_info = []
times = gen_timestamps_from_peak(res.times,
num_frames=tp_alphas.size(1),
frame_rate=0.02)

for i, x in enumerate(res.tokens):
tokens_info.append({
'token': self.tokenizer.char_dict[x],
'start': round(times[i][0], 3),
'end': round(times[i][1], 3),
'confidence': round(res.tokens_confidence[i], 2)
})
result['tokens'] = tokens_info
results = paraformer_greedy_search(decoder_out, token_num, cif_peaks)

r = []
for (i, res) in enumerate(results):
result = {}
result['confidence'] = res.confidence
result['text'] = self.tokenizer.detokenize(res.tokens)[0]
if tokens_info:
tokens_info_l = []
times = gen_timestamps_from_peak(res.times,
num_frames=frames[i],
frame_rate=0.02)

for i, x in enumerate(res.tokens[:len(times)]):
tokens_info_l.append({
'token':
self.tokenizer.char_dict[x],
'start':
round(times[i][0], 3),
'end':
round(times[i][1], 3),
'confidence':
round(res.tokens_confidence[i], 2)
})
result['tokens'] = tokens_info_l
r.append(result)
return r

def transcribe(self, audio_file: str, tokens_info: bool = False) -> dict:
result = self.transcribe_batch([audio_file], tokens_info)[0]
return result

def align(self, audio_file: str, label: str) -> dict:
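The core of the batch refactor above is padding variable-length fbank features into one tensor while carrying the true lengths separately, so the model can mask the padding. A minimal sketch of that step with two hypothetical utterances:

```python
import torch

# Two hypothetical per-utterance feature matrices of different lengths,
# shaped (num_frames, num_mel_bins) as returned by kaldi.fbank.
feats_lst = [torch.randn(120, 80), torch.randn(95, 80)]
feats_lens_lst = [torch.tensor(f.shape[0], dtype=torch.int64) for f in feats_lst]

# pad_sequence right-pads every utterance to the longest one in the batch,
# yielding (batch, max_frames, num_mel_bins).
feats_tensor = torch.nn.utils.rnn.pad_sequence(feats_lst, batch_first=True)
feats_lens_tensor = torch.tensor(feats_lens_lst)
print(feats_tensor.shape, feats_lens_tensor)  # torch.Size([2, 120, 80]) tensor([120, 95])
```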
115 changes: 115 additions & 0 deletions wenet/cli/punc_model.py
@@ -0,0 +1,115 @@
import os
from typing import List

import jieba
import torch
from wenet.cli.hub import Hub
from wenet.paraformer.search import _isAllAlpha
from wenet.text.char_tokenizer import CharTokenizer


class PuncModel:

def __init__(self, model_dir: str) -> None:
self.model_dir = model_dir
model_path = os.path.join(model_dir, 'final.zip')
units_path = os.path.join(model_dir, 'units.txt')

self.model = torch.jit.load(model_path)
self.tokenizer = CharTokenizer(units_path)
self.device = torch.device("cpu")
self.use_jieba = False

self.punc_table = ['<unk>', '', ',', '。', '?', '、']

def split_words(self, text: str):
if not self.use_jieba:
self.use_jieba = True
import logging

# Disable jieba's logger
logging.getLogger('jieba').disabled = True
jieba.load_userdict(os.path.join(self.model_dir, 'jieba_usr_dict'))

result_list = []
tokens = text.split()
current_language = None
buffer = []

for token in tokens:
is_english = token.isascii()
if is_english:
language = "English"
else:
language = "Chinese"

if current_language and language != current_language:
if current_language == "Chinese":
result_list.extend(jieba.cut(''.join(buffer), HMM=False))
else:
result_list.extend(buffer)
buffer = []

buffer.append(token)
current_language = language

if buffer:
if current_language == "Chinese":
result_list.extend(jieba.cut(''.join(buffer), HMM=False))
else:
result_list.extend(buffer)

return result_list

def add_punc_batch(self, texts: List[str]):
batch_text_words = []
batch_text_ids = []
batch_text_lens = []

for text in texts:
words = self.split_words(text)
ids = self.tokenizer.tokens2ids(words)
batch_text_words.append(words)
batch_text_ids.append(ids)
batch_text_lens.append(len(ids))

texts_tensor = torch.tensor(batch_text_ids,
device=self.device,
dtype=torch.int64)
texts_lens_tensor = torch.tensor(batch_text_lens,
device=self.device,
dtype=torch.int64)

log_probs, _ = self.model(texts_tensor, texts_lens_tensor)
result = []
outs = log_probs.argmax(-1).cpu().numpy()
for i, out in enumerate(outs):
punc_id = out[:batch_text_lens[i]]
sentence = ''
for j, word in enumerate(batch_text_words[i]):
if _isAllAlpha(word):
word = '▁' + word
word += self.punc_table[punc_id[j]]
sentence += word
result.append(sentence.replace('▁', ' '))
return result

def __call__(self, text: str):
if text != '':
r = self.add_punc_batch([text])[0]
return r
return ''


def load_model(model_dir: str = None,
gpu: int = -1,
device: str = "cpu") -> PuncModel:
if model_dir is None:
model_dir = Hub.get_model_by_lang('punc')
if gpu != -1:
# remain the original usage of gpu
device = "cuda"
punc = PuncModel(model_dir)
punc.device = torch.device(device)
punc.model.to(device)
return punc
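A short usage sketch for the new module, assuming the default download path via Hub; the input mirrors the space-separated, mixed Chinese/English transcripts that split_words is written to handle:

```python
from wenet.cli.punc_model import load_model

# Downloads via Hub.get_model_by_lang('punc') when model_dir is None.
punc = load_model()

# Restore punctuation for a space-separated ASR transcript.
print(punc("我们 一起 分享 wenet hello world"))
```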