From 905a1f25853e2e4adb3a19f37cdff1610b251fa7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?= Date: Thu, 2 Mar 2023 19:34:46 +0800 Subject: [PATCH 1/3] torchscripts --- funasr/runtime/python/libtorch/README.md | 70 +++++ .../{torchscripts => libtorch}/__init__.py | 0 funasr/runtime/python/libtorch/demo.py | 11 + funasr/runtime/python/libtorch/setup.py | 43 ++++ .../libtorch/torch_paraformer/__init__.py | 2 + .../torch_paraformer/paraformer_bin.py | 155 +++++++++++ .../torch_paraformer/utils}/__init__.py | 0 .../torch_paraformer/utils/frontend.py | 191 ++++++++++++++ .../utils/postprocess_utils.py | 240 ++++++++++++++++++ .../libtorch/torch_paraformer/utils/utils.py | 165 ++++++++++++ 10 files changed, 877 insertions(+) create mode 100644 funasr/runtime/python/libtorch/README.md rename funasr/runtime/python/{torchscripts => libtorch}/__init__.py (100%) create mode 100644 funasr/runtime/python/libtorch/demo.py create mode 100644 funasr/runtime/python/libtorch/setup.py create mode 100644 funasr/runtime/python/libtorch/torch_paraformer/__init__.py create mode 100644 funasr/runtime/python/libtorch/torch_paraformer/paraformer_bin.py rename funasr/runtime/python/{torchscripts/paraformer => libtorch/torch_paraformer/utils}/__init__.py (100%) create mode 100644 funasr/runtime/python/libtorch/torch_paraformer/utils/frontend.py create mode 100644 funasr/runtime/python/libtorch/torch_paraformer/utils/postprocess_utils.py create mode 100644 funasr/runtime/python/libtorch/torch_paraformer/utils/utils.py diff --git a/funasr/runtime/python/libtorch/README.md b/funasr/runtime/python/libtorch/README.md new file mode 100644 index 000000000..33a56afca --- /dev/null +++ b/funasr/runtime/python/libtorch/README.md @@ -0,0 +1,70 @@ +## Using paraformer with libtorch + + +### Introduction +- Model comes from [speech_paraformer](https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary). + +### Steps: +1. Export the model. 
+ - Command: (`Tips`: torch >= 1.11.0 is required.) + + ```shell + python -m funasr.export.export_model [model_name] [export_dir] [true] + ``` + `model_name`: the model to export. + + `export_dir`: the directory where the exported model is saved. + + For more details, refer to the [export docs](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/export). + + - `e.g.`, Export model from modelscope + ```shell + python -m funasr.export.export_model 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch' "./export" true + ``` + - `e.g.`, Export model from local path, the model'name must be `model.pb`. + ```shell + python -m funasr.export.export_model '/mnt/workspace/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch' "./export" true + ``` + + +2. Install the `torch_paraformer`. + - Build the torch_paraformer `whl` + ```shell + git clone https://github.com/alibaba/FunASR.git && cd FunASR + cd funasr/runtime/python/libtorch + python setup.py bdist_wheel + ``` + - Install the build `whl` + ```bash + pip install dist/torch_paraformer-0.0.1-py3-none-any.whl + ``` + +3. Run the demo. + - Model_dir: the model path, which contains `model.torchscripts`, `config.yaml`, `am.mvn`. + - Input: wav format file, supported formats: `str, np.ndarray, List[str]` + - Output: `List[str]`: recognition result. 
+ - Example: + ```python + from torch_paraformer import Paraformer + + model_dir = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" + model = Paraformer(model_dir, batch_size=1) + + wav_path = ['/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/example/asr_example.wav'] + + result = model(wav_path) + print(result) + ``` + +## Speed + +Environment:Intel(R) Xeon(R) Platinum 8163 CPU @ 2.50GHz + +Test [wav, 5.53s, 100 times avg.](https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav) + +| Backend | RTF | +|:-------:|:-----------------:| +| Pytorch | 0.110 | +| Onnx | 0.038 | + +## Acknowledge diff --git a/funasr/runtime/python/torchscripts/__init__.py b/funasr/runtime/python/libtorch/__init__.py similarity index 100% rename from funasr/runtime/python/torchscripts/__init__.py rename to funasr/runtime/python/libtorch/__init__.py diff --git a/funasr/runtime/python/libtorch/demo.py b/funasr/runtime/python/libtorch/demo.py new file mode 100644 index 000000000..71b2b855e --- /dev/null +++ b/funasr/runtime/python/libtorch/demo.py @@ -0,0 +1,11 @@ + +from torch_paraformer import Paraformer + +model_dir = "/Users/shixian/code/funasr2/export/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch" +model_dir = "/Users/shixian/code/funasr2/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" +model = Paraformer(model_dir, batch_size=1) + +wav_path = ['/Users/shixian/code/funasr2/export/damo/speech_paraformer-tiny-commandword_asr_nat-zh-cn-16k-vocab544-pytorch/example/asr_example.wav'] + +result = model(wav_path) +print(result) \ No newline at end of file diff --git a/funasr/runtime/python/libtorch/setup.py b/funasr/runtime/python/libtorch/setup.py new file mode 100644 index 000000000..99d8b5201 --- /dev/null +++ b/funasr/runtime/python/libtorch/setup.py @@ -0,0 +1,43 @@ +# -*- encoding: utf-8 -*- +from 
pathlib import Path +import setuptools + + +def get_readme(): + root_dir = Path(__file__).resolve().parent + readme_path = str(root_dir / 'README.md') + print(readme_path) + with open(readme_path, 'r', encoding='utf-8') as f: + readme = f.read() + return readme + + + +setuptools.setup( + name='torch_paraformer', + version='0.0.1', + platforms="Any", + url="https://github.com/alibaba-damo-academy/FunASR.git", + author="Speech Lab, Alibaba Group, China", + author_email="funasr@list.alibaba-inc.com", + description="FunASR: A Fundamental End-to-End Speech Recognition Toolkit", + license="The MIT License", + long_description=get_readme(), + long_description_content_type='text/markdown', + include_package_data=True, + install_requires=["librosa", "onnxruntime>=1.7.0", + "scipy", "numpy>=1.19.3", + "typeguard", "kaldi-native-fbank", + "PyYAML>=5.1.2"], + packages=['torch_paraformer'], + keywords=[ + 'funasr,paraformer' + ], + classifiers=[ + 'Programming Language :: Python :: 3.6', + 'Programming Language :: Python :: 3.7', + 'Programming Language :: Python :: 3.8', + 'Programming Language :: Python :: 3.9', + 'Programming Language :: Python :: 3.10', + ], +) diff --git a/funasr/runtime/python/libtorch/torch_paraformer/__init__.py b/funasr/runtime/python/libtorch/torch_paraformer/__init__.py new file mode 100644 index 000000000..647f9fadc --- /dev/null +++ b/funasr/runtime/python/libtorch/torch_paraformer/__init__.py @@ -0,0 +1,2 @@ +# -*- encoding: utf-8 -*- +from .paraformer_bin import Paraformer diff --git a/funasr/runtime/python/libtorch/torch_paraformer/paraformer_bin.py b/funasr/runtime/python/libtorch/torch_paraformer/paraformer_bin.py new file mode 100644 index 000000000..ca9055887 --- /dev/null +++ b/funasr/runtime/python/libtorch/torch_paraformer/paraformer_bin.py @@ -0,0 +1,155 @@ +# -*- encoding: utf-8 -*- +import os.path +from pathlib import Path +from typing import List, Union, Tuple + +import copy +import librosa +import numpy as np + +from .utils.utils 
import (CharTokenizer, Hypothesis, + TokenIDConverter, get_logger, + read_yaml) +from .utils.postprocess_utils import sentence_postprocess +from .utils.frontend import WavFrontend +from funasr.utils.timestamp_tools import time_stamp_lfr6_pl +logging = get_logger() + +import torch + + +class Paraformer(): + def __init__(self, model_dir: Union[str, Path] = None, + batch_size: int = 1, + device_id: Union[str, int] = "-1", + ): + + if not Path(model_dir).exists(): + raise FileNotFoundError(f'{model_dir} does not exist.') + + model_file = os.path.join(model_dir, 'model.onnx') + config_file = os.path.join(model_dir, 'config.yaml') + cmvn_file = os.path.join(model_dir, 'am.mvn') + config = read_yaml(config_file) + + self.converter = TokenIDConverter(config['token_list']) + self.tokenizer = CharTokenizer() + self.frontend = WavFrontend( + cmvn_file=cmvn_file, + **config['frontend_conf'] + ) + self.ort_infer = torch.jit.load(model_file) + self.batch_size = batch_size + + def __call__(self, wav_content: Union[str, np.ndarray, List[str]], **kwargs) -> List: + waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq) + waveform_nums = len(waveform_list) + + asr_res = [] + for beg_idx in range(0, waveform_nums, self.batch_size): + res = {} + end_idx = min(waveform_nums, beg_idx + self.batch_size) + feats, feats_len = self.extract_feat(waveform_list[beg_idx:end_idx]) + + try: + outputs = self.infer(feats, feats_len) + outs = outputs[0], outputs[1] + am_scores, valid_token_lens = outs[0], outs[1] + if len(outputs) == 4: + # for BiCifParaformer Inference + us_alphas, us_cif_peak = outputs[2], outputs[3] + else: + us_alphas, us_cif_peak = None, None + except: + #logging.warning(traceback.format_exc()) + logging.warning("input wav is silence or noise") + preds = [''] + else: + am_scores, valid_token_lens = am_scores.cpu().numpy(), valid_token_lens.cpu().numpy() + preds, raw_token = self.decode(am_scores, valid_token_lens)[0] + res['preds'] = preds + if 
us_cif_peak is not None: + us_alphas, us_cif_peak = us_alphas.cpu().numpy(), us_cif_peak.cpu().numpy() + timestamp = time_stamp_lfr6_pl(us_alphas, us_cif_peak, copy.copy(raw_token), log=False) + res['timestamp'] = timestamp + asr_res.append(res) + return asr_res + + def load_data(self, + wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List: + def load_wav(path: str) -> np.ndarray: + waveform, _ = librosa.load(path, sr=fs) + return waveform + + if isinstance(wav_content, np.ndarray): + return [wav_content] + + if isinstance(wav_content, str): + return [load_wav(wav_content)] + + if isinstance(wav_content, list): + return [load_wav(path) for path in wav_content] + + raise TypeError( + f'The type of {wav_content} is not in [str, np.ndarray, list]') + + def extract_feat(self, + waveform_list: List[np.ndarray] + ) -> Tuple[np.ndarray, np.ndarray]: + feats, feats_len = [], [] + for waveform in waveform_list: + speech, _ = self.frontend.fbank(waveform) + feat, feat_len = self.frontend.lfr_cmvn(speech) + feats.append(feat) + feats_len.append(feat_len) + + feats = self.pad_feats(feats, np.max(feats_len)) + feats_len = np.array(feats_len).astype(np.int32) + return feats, feats_len + + @staticmethod + def pad_feats(feats: List[np.ndarray], max_feat_len: int) -> np.ndarray: + def pad_feat(feat: np.ndarray, cur_len: int) -> np.ndarray: + pad_width = ((0, max_feat_len - cur_len), (0, 0)) + return np.pad(feat, pad_width, 'constant', constant_values=0) + + feat_res = [pad_feat(feat, feat.shape[0]) for feat in feats] + feats = np.array(feat_res).astype(np.float32) + return feats + + def infer(self, feats: np.ndarray, + feats_len: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + outputs = self.ort_infer([feats, feats_len]) + return outputs + + def decode(self, am_scores: np.ndarray, token_nums: int) -> List[str]: + return [self.decode_one(am_score, token_num) + for am_score, token_num in zip(am_scores, token_nums)] + + def decode_one(self, + am_score: np.ndarray, + 
valid_token_num: int) -> List[str]: + yseq = am_score.argmax(axis=-1) + score = am_score.max(axis=-1) + score = np.sum(score, axis=-1) + + # pad with mask tokens to ensure compatibility with sos/eos tokens + # asr_model.sos:1 asr_model.eos:2 + yseq = np.array([1] + yseq.tolist() + [2]) + hyp = Hypothesis(yseq=yseq, score=score) + + # remove sos/eos and get results + last_pos = -1 + token_int = hyp.yseq[1:last_pos].tolist() + + # remove blank symbol id, which is assumed to be 0 + token_int = list(filter(lambda x: x not in (0, 2), token_int)) + + # Change integer-ids to tokens + token = self.converter.ids2tokens(token_int) + # token = token[:valid_token_num-1] + texts = sentence_postprocess(token) + text = texts[0] + # text = self.tokenizer.tokens2text(token) + return text, token + diff --git a/funasr/runtime/python/torchscripts/paraformer/__init__.py b/funasr/runtime/python/libtorch/torch_paraformer/utils/__init__.py similarity index 100% rename from funasr/runtime/python/torchscripts/paraformer/__init__.py rename to funasr/runtime/python/libtorch/torch_paraformer/utils/__init__.py diff --git a/funasr/runtime/python/libtorch/torch_paraformer/utils/frontend.py b/funasr/runtime/python/libtorch/torch_paraformer/utils/frontend.py new file mode 100644 index 000000000..11a86445d --- /dev/null +++ b/funasr/runtime/python/libtorch/torch_paraformer/utils/frontend.py @@ -0,0 +1,191 @@ +# -*- encoding: utf-8 -*- +from pathlib import Path +from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union + +import numpy as np +from typeguard import check_argument_types +import kaldi_native_fbank as knf + +root_dir = Path(__file__).resolve().parent + +logger_initialized = {} + + +class WavFrontend(): + """Conventional frontend structure for ASR. 
+ """ + + def __init__( + self, + cmvn_file: str = None, + fs: int = 16000, + window: str = 'hamming', + n_mels: int = 80, + frame_length: int = 25, + frame_shift: int = 10, + lfr_m: int = 1, + lfr_n: int = 1, + dither: float = 1.0, + **kwargs, + ) -> None: + check_argument_types() + + opts = knf.FbankOptions() + opts.frame_opts.samp_freq = fs + opts.frame_opts.dither = dither + opts.frame_opts.window_type = window + opts.frame_opts.frame_shift_ms = float(frame_shift) + opts.frame_opts.frame_length_ms = float(frame_length) + opts.mel_opts.num_bins = n_mels + opts.energy_floor = 0 + opts.frame_opts.snip_edges = True + opts.mel_opts.debug_mel = False + self.opts = opts + + self.lfr_m = lfr_m + self.lfr_n = lfr_n + self.cmvn_file = cmvn_file + + if self.cmvn_file: + self.cmvn = self.load_cmvn() + self.fbank_fn = None + self.fbank_beg_idx = 0 + self.reset_status() + + def fbank(self, + waveform: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + waveform = waveform * (1 << 15) + self.fbank_fn = knf.OnlineFbank(self.opts) + self.fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist()) + frames = self.fbank_fn.num_frames_ready + mat = np.empty([frames, self.opts.mel_opts.num_bins]) + for i in range(frames): + mat[i, :] = self.fbank_fn.get_frame(i) + feat = mat.astype(np.float32) + feat_len = np.array(mat.shape[0]).astype(np.int32) + return feat, feat_len + + def fbank_online(self, + waveform: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + waveform = waveform * (1 << 15) + # self.fbank_fn = knf.OnlineFbank(self.opts) + self.fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist()) + frames = self.fbank_fn.num_frames_ready + mat = np.empty([frames, self.opts.mel_opts.num_bins]) + for i in range(self.fbank_beg_idx, frames): + mat[i, :] = self.fbank_fn.get_frame(i) + # self.fbank_beg_idx += (frames-self.fbank_beg_idx) + feat = mat.astype(np.float32) + feat_len = np.array(mat.shape[0]).astype(np.int32) + return feat, feat_len + + def 
reset_status(self): + self.fbank_fn = knf.OnlineFbank(self.opts) + self.fbank_beg_idx = 0 + + def lfr_cmvn(self, feat: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + if self.lfr_m != 1 or self.lfr_n != 1: + feat = self.apply_lfr(feat, self.lfr_m, self.lfr_n) + + if self.cmvn_file: + feat = self.apply_cmvn(feat) + + feat_len = np.array(feat.shape[0]).astype(np.int32) + return feat, feat_len + + @staticmethod + def apply_lfr(inputs: np.ndarray, lfr_m: int, lfr_n: int) -> np.ndarray: + LFR_inputs = [] + + T = inputs.shape[0] + T_lfr = int(np.ceil(T / lfr_n)) + left_padding = np.tile(inputs[0], ((lfr_m - 1) // 2, 1)) + inputs = np.vstack((left_padding, inputs)) + T = T + (lfr_m - 1) // 2 + for i in range(T_lfr): + if lfr_m <= T - i * lfr_n: + LFR_inputs.append( + (inputs[i * lfr_n:i * lfr_n + lfr_m]).reshape(1, -1)) + else: + # process last LFR frame + num_padding = lfr_m - (T - i * lfr_n) + frame = inputs[i * lfr_n:].reshape(-1) + for _ in range(num_padding): + frame = np.hstack((frame, inputs[-1])) + + LFR_inputs.append(frame) + LFR_outputs = np.vstack(LFR_inputs).astype(np.float32) + return LFR_outputs + + def apply_cmvn(self, inputs: np.ndarray) -> np.ndarray: + """ + Apply CMVN with mvn data + """ + frame, dim = inputs.shape + means = np.tile(self.cmvn[0:1, :dim], (frame, 1)) + vars = np.tile(self.cmvn[1:2, :dim], (frame, 1)) + inputs = (inputs + means) * vars + return inputs + + def load_cmvn(self,) -> np.ndarray: + with open(self.cmvn_file, 'r', encoding='utf-8') as f: + lines = f.readlines() + + means_list = [] + vars_list = [] + for i in range(len(lines)): + line_item = lines[i].split() + if line_item[0] == '': + line_item = lines[i + 1].split() + if line_item[0] == '': + add_shift_line = line_item[3:(len(line_item) - 1)] + means_list = list(add_shift_line) + continue + elif line_item[0] == '': + line_item = lines[i + 1].split() + if line_item[0] == '': + rescale_line = line_item[3:(len(line_item) - 1)] + vars_list = list(rescale_line) + continue + + means = 
np.array(means_list).astype(np.float64) + vars = np.array(vars_list).astype(np.float64) + cmvn = np.array([means, vars]) + return cmvn + +def load_bytes(input): + middle_data = np.frombuffer(input, dtype=np.int16) + middle_data = np.asarray(middle_data) + if middle_data.dtype.kind not in 'iu': + raise TypeError("'middle_data' must be an array of integers") + dtype = np.dtype('float32') + if dtype.kind != 'f': + raise TypeError("'dtype' must be a floating point type") + + i = np.iinfo(middle_data.dtype) + abs_max = 2 ** (i.bits - 1) + offset = i.min + abs_max + array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32) + return array + + +def test(): + path = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/example/asr_example.wav" + import librosa + cmvn_file = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/am.mvn" + config_file = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/config.yaml" + from funasr.runtime.python.onnxruntime.rapid_paraformer.utils.utils import read_yaml + config = read_yaml(config_file) + waveform, _ = librosa.load(path, sr=None) + frontend = WavFrontend( + cmvn_file=cmvn_file, + **config['frontend_conf'], + ) + speech, _ = frontend.fbank_online(waveform) #1d, (sample,), numpy + feat, feat_len = frontend.lfr_cmvn(speech) # 2d, (frame, 450), np.float32 -> torch, torch.from_numpy(), dtype, (1, frame, 450) + + frontend.reset_status() # clear cache + return feat, feat_len + +if __name__ == '__main__': + test() \ No newline at end of file diff --git a/funasr/runtime/python/libtorch/torch_paraformer/utils/postprocess_utils.py b/funasr/runtime/python/libtorch/torch_paraformer/utils/postprocess_utils.py new file mode 100644 index 000000000..575fb90dd --- /dev/null +++ b/funasr/runtime/python/libtorch/torch_paraformer/utils/postprocess_utils.py @@ -0,0 +1,240 @@ +# Copyright (c) 
Alibaba, Inc. and its affiliates. + +import string +import logging +from typing import Any, List, Union + + +def isChinese(ch: str): + if '\u4e00' <= ch <= '\u9fff' or '\u0030' <= ch <= '\u0039': + return True + return False + + +def isAllChinese(word: Union[List[Any], str]): + word_lists = [] + for i in word: + cur = i.replace(' ', '') + cur = cur.replace('', '') + cur = cur.replace('', '') + word_lists.append(cur) + + if len(word_lists) == 0: + return False + + for ch in word_lists: + if isChinese(ch) is False: + return False + return True + + +def isAllAlpha(word: Union[List[Any], str]): + word_lists = [] + for i in word: + cur = i.replace(' ', '') + cur = cur.replace('', '') + cur = cur.replace('', '') + word_lists.append(cur) + + if len(word_lists) == 0: + return False + + for ch in word_lists: + if ch.isalpha() is False and ch != "'": + return False + elif ch.isalpha() is True and isChinese(ch) is True: + return False + + return True + + +# def abbr_dispose(words: List[Any]) -> List[Any]: +def abbr_dispose(words: List[Any], time_stamp: List[List] = None) -> List[Any]: + words_size = len(words) + word_lists = [] + abbr_begin = [] + abbr_end = [] + last_num = -1 + ts_lists = [] + ts_nums = [] + ts_index = 0 + for num in range(words_size): + if num <= last_num: + continue + + if len(words[num]) == 1 and words[num].encode('utf-8').isalpha(): + if num + 1 < words_size and words[ + num + 1] == ' ' and num + 2 < words_size and len( + words[num + + 2]) == 1 and words[num + + 2].encode('utf-8').isalpha(): + # found the begin of abbr + abbr_begin.append(num) + num += 2 + abbr_end.append(num) + # to find the end of abbr + while True: + num += 1 + if num < words_size and words[num] == ' ': + num += 1 + if num < words_size and len( + words[num]) == 1 and words[num].encode( + 'utf-8').isalpha(): + abbr_end.pop() + abbr_end.append(num) + last_num = num + else: + break + else: + break + + for num in range(words_size): + if words[num] == ' ': + ts_nums.append(ts_index) + 
else: + ts_nums.append(ts_index) + ts_index += 1 + last_num = -1 + for num in range(words_size): + if num <= last_num: + continue + + if num in abbr_begin: + if time_stamp is not None: + begin = time_stamp[ts_nums[num]][0] + word_lists.append(words[num].upper()) + num += 1 + while num < words_size: + if num in abbr_end: + word_lists.append(words[num].upper()) + last_num = num + break + else: + if words[num].encode('utf-8').isalpha(): + word_lists.append(words[num].upper()) + num += 1 + if time_stamp is not None: + end = time_stamp[ts_nums[num]][1] + ts_lists.append([begin, end]) + else: + word_lists.append(words[num]) + if time_stamp is not None and words[num] != ' ': + begin = time_stamp[ts_nums[num]][0] + end = time_stamp[ts_nums[num]][1] + ts_lists.append([begin, end]) + begin = end + + if time_stamp is not None: + return word_lists, ts_lists + else: + return word_lists + + +def sentence_postprocess(words: List[Any], time_stamp: List[List] = None): + middle_lists = [] + word_lists = [] + word_item = '' + ts_lists = [] + + # wash words lists + for i in words: + word = '' + if isinstance(i, str): + word = i + else: + word = i.decode('utf-8') + + if word in ['', '', '']: + continue + else: + middle_lists.append(word) + + # all chinese characters + if isAllChinese(middle_lists): + for i, ch in enumerate(middle_lists): + word_lists.append(ch.replace(' ', '')) + if time_stamp is not None: + ts_lists = time_stamp + + # all alpha characters + elif isAllAlpha(middle_lists): + ts_flag = True + for i, ch in enumerate(middle_lists): + if ts_flag and time_stamp is not None: + begin = time_stamp[i][0] + end = time_stamp[i][1] + word = '' + if '@@' in ch: + word = ch.replace('@@', '') + word_item += word + if time_stamp is not None: + ts_flag = False + end = time_stamp[i][1] + else: + word_item += ch + word_lists.append(word_item) + word_lists.append(' ') + word_item = '' + if time_stamp is not None: + ts_flag = True + end = time_stamp[i][1] + ts_lists.append([begin, end]) + 
begin = end + + # mix characters + else: + alpha_blank = False + ts_flag = True + begin = -1 + end = -1 + for i, ch in enumerate(middle_lists): + if ts_flag and time_stamp is not None: + begin = time_stamp[i][0] + end = time_stamp[i][1] + word = '' + if isAllChinese(ch): + if alpha_blank is True: + word_lists.pop() + word_lists.append(ch) + alpha_blank = False + if time_stamp is not None: + ts_flag = True + ts_lists.append([begin, end]) + begin = end + elif '@@' in ch: + word = ch.replace('@@', '') + word_item += word + alpha_blank = False + if time_stamp is not None: + ts_flag = False + end = time_stamp[i][1] + elif isAllAlpha(ch): + word_item += ch + word_lists.append(word_item) + word_lists.append(' ') + word_item = '' + alpha_blank = True + if time_stamp is not None: + ts_flag = True + end = time_stamp[i][1] + ts_lists.append([begin, end]) + begin = end + else: + raise ValueError('invalid character: {}'.format(ch)) + + if time_stamp is not None: + word_lists, ts_lists = abbr_dispose(word_lists, ts_lists) + real_word_lists = [] + for ch in word_lists: + if ch != ' ': + real_word_lists.append(ch) + sentence = ' '.join(real_word_lists).strip() + return sentence, ts_lists, real_word_lists + else: + word_lists = abbr_dispose(word_lists) + real_word_lists = [] + for ch in word_lists: + if ch != ' ': + real_word_lists.append(ch) + sentence = ''.join(word_lists).strip() + return sentence, real_word_lists diff --git a/funasr/runtime/python/libtorch/torch_paraformer/utils/utils.py b/funasr/runtime/python/libtorch/torch_paraformer/utils/utils.py new file mode 100644 index 000000000..2f09de8c9 --- /dev/null +++ b/funasr/runtime/python/libtorch/torch_paraformer/utils/utils.py @@ -0,0 +1,165 @@ +# -*- encoding: utf-8 -*- + +import functools +import logging +import pickle +from pathlib import Path +from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union + +import numpy as np +import yaml + +from typeguard import check_argument_types + +import warnings + 
+root_dir = Path(__file__).resolve().parent + +logger_initialized = {} + + +class TokenIDConverter(): + def __init__(self, token_list: Union[List, str], + ): + check_argument_types() + + # self.token_list = self.load_token(token_path) + self.token_list = token_list + self.unk_symbol = token_list[-1] + + def get_num_vocabulary_size(self) -> int: + return len(self.token_list) + + def ids2tokens(self, + integers: Union[np.ndarray, Iterable[int]]) -> List[str]: + if isinstance(integers, np.ndarray) and integers.ndim != 1: + raise TokenIDConverterError( + f"Must be 1 dim ndarray, but got {integers.ndim}") + return [self.token_list[i] for i in integers] + + def tokens2ids(self, tokens: Iterable[str]) -> List[int]: + token2id = {v: i for i, v in enumerate(self.token_list)} + if self.unk_symbol not in token2id: + raise TokenIDConverterError( + f"Unknown symbol '{self.unk_symbol}' doesn't exist in the token_list" + ) + unk_id = token2id[self.unk_symbol] + return [token2id.get(i, unk_id) for i in tokens] + + +class CharTokenizer(): + def __init__( + self, + symbol_value: Union[Path, str, Iterable[str]] = None, + space_symbol: str = "", + remove_non_linguistic_symbols: bool = False, + ): + check_argument_types() + + self.space_symbol = space_symbol + self.non_linguistic_symbols = self.load_symbols(symbol_value) + self.remove_non_linguistic_symbols = remove_non_linguistic_symbols + + @staticmethod + def load_symbols(value: Union[Path, str, Iterable[str]] = None) -> Set: + if value is None: + return set() + + if isinstance(value, Iterable[str]): + return set(value) + + file_path = Path(value) + if not file_path.exists(): + logging.warning("%s doesn't exist.", file_path) + return set() + + with file_path.open("r", encoding="utf-8") as f: + return set(line.rstrip() for line in f) + + def text2tokens(self, line: Union[str, list]) -> List[str]: + tokens = [] + while len(line) != 0: + for w in self.non_linguistic_symbols: + if line.startswith(w): + if not 
self.remove_non_linguistic_symbols: + tokens.append(line[: len(w)]) + line = line[len(w):] + break + else: + t = line[0] + if t == " ": + t = "" + tokens.append(t) + line = line[1:] + return tokens + + def tokens2text(self, tokens: Iterable[str]) -> str: + tokens = [t if t != self.space_symbol else " " for t in tokens] + return "".join(tokens) + + def __repr__(self): + return ( + f"{self.__class__.__name__}(" + f'space_symbol="{self.space_symbol}"' + f'non_linguistic_symbols="{self.non_linguistic_symbols}"' + f")" + ) + + + +class Hypothesis(NamedTuple): + """Hypothesis data type.""" + + yseq: np.ndarray + score: Union[float, np.ndarray] = 0 + scores: Dict[str, Union[float, np.ndarray]] = dict() + states: Dict[str, Any] = dict() + + def asdict(self) -> dict: + """Convert data to JSON-friendly dict.""" + return self._replace( + yseq=self.yseq.tolist(), + score=float(self.score), + scores={k: float(v) for k, v in self.scores.items()}, + )._asdict() + + +def read_yaml(yaml_path: Union[str, Path]) -> Dict: + if not Path(yaml_path).exists(): + raise FileExistsError(f'The {yaml_path} does not exist.') + + with open(str(yaml_path), 'rb') as f: + data = yaml.load(f, Loader=yaml.Loader) + return data + + +@functools.lru_cache() +def get_logger(name='torch_paraformer'): + """Initialize and get a logger by name. + If the logger has not been initialized, this method will initialize the + logger by adding one or two handlers, otherwise the initialized logger will + be directly returned. During initialization, a StreamHandler will always be + added. + Args: + name (str): Logger name. + Returns: + logging.Logger: The expected logger. 
+ """ + logger = logging.getLogger(name) + if name in logger_initialized: + return logger + + for logger_name in logger_initialized: + if name.startswith(logger_name): + return logger + + formatter = logging.Formatter( + '[%(asctime)s] %(name)s %(levelname)s: %(message)s', + datefmt="%Y/%m/%d %H:%M:%S") + + sh = logging.StreamHandler() + sh.setFormatter(formatter) + logger.addHandler(sh) + logger_initialized[name] = True + logger.propagate = False + return logger From 548153260b27b28bfdc880472e382e3418a05be3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?= Date: Thu, 2 Mar 2023 20:20:44 +0800 Subject: [PATCH 2/3] torchscripts --- funasr/export/test_torchscripts.py | 2 +- funasr/runtime/python/libtorch/setup.py | 4 ++-- .../libtorch/torch_paraformer/paraformer_bin.py | 11 ++++++----- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/funasr/export/test_torchscripts.py b/funasr/export/test_torchscripts.py index 11be76325..9afec745d 100644 --- a/funasr/export/test_torchscripts.py +++ b/funasr/export/test_torchscripts.py @@ -2,7 +2,7 @@ import torch import numpy as np if __name__ == '__main__': - onnx_path = "/mnt/workspace/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/model.torchscripts" + onnx_path = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/model.torchscripts" loaded = torch.jit.load(onnx_path) x = torch.rand([2, 21, 560]) diff --git a/funasr/runtime/python/libtorch/setup.py b/funasr/runtime/python/libtorch/setup.py index 99d8b5201..0f9e40d24 100644 --- a/funasr/runtime/python/libtorch/setup.py +++ b/funasr/runtime/python/libtorch/setup.py @@ -1,7 +1,7 @@ # -*- encoding: utf-8 -*- from pathlib import Path import setuptools - +from setuptools import find_packages def get_readme(): root_dir = Path(__file__).resolve().parent @@ -29,7 +29,7 @@ setuptools.setup( "scipy", "numpy>=1.19.3", "typeguard", "kaldi-native-fbank", "PyYAML>=5.1.2"], - 
packages=['torch_paraformer'], + packages=find_packages(include=["torch_paraformer*"]), keywords=[ 'funasr,paraformer' ], diff --git a/funasr/runtime/python/libtorch/torch_paraformer/paraformer_bin.py b/funasr/runtime/python/libtorch/torch_paraformer/paraformer_bin.py index ca9055887..159e3944a 100644 --- a/funasr/runtime/python/libtorch/torch_paraformer/paraformer_bin.py +++ b/funasr/runtime/python/libtorch/torch_paraformer/paraformer_bin.py @@ -27,7 +27,7 @@ class Paraformer(): if not Path(model_dir).exists(): raise FileNotFoundError(f'{model_dir} does not exist.') - model_file = os.path.join(model_dir, 'model.onnx') + model_file = os.path.join(model_dir, 'model.torchscripts') config_file = os.path.join(model_dir, 'config.yaml') cmvn_file = os.path.join(model_dir, 'am.mvn') config = read_yaml(config_file) @@ -52,9 +52,8 @@ class Paraformer(): feats, feats_len = self.extract_feat(waveform_list[beg_idx:end_idx]) try: - outputs = self.infer(feats, feats_len) - outs = outputs[0], outputs[1] - am_scores, valid_token_lens = outs[0], outs[1] + outputs = self.ort_infer(feats, feats_len) + am_scores, valid_token_lens = outputs[0], outputs[1] if len(outputs) == 4: # for BiCifParaformer Inference us_alphas, us_cif_peak = outputs[2], outputs[3] @@ -65,7 +64,7 @@ class Paraformer(): logging.warning("input wav is silence or noise") preds = [''] else: - am_scores, valid_token_lens = am_scores.cpu().numpy(), valid_token_lens.cpu().numpy() + am_scores, valid_token_lens = am_scores.detach().cpu().numpy(), valid_token_lens.detach().cpu().numpy() preds, raw_token = self.decode(am_scores, valid_token_lens)[0] res['preds'] = preds if us_cif_peak is not None: @@ -105,6 +104,8 @@ class Paraformer(): feats = self.pad_feats(feats, np.max(feats_len)) feats_len = np.array(feats_len).astype(np.int32) + feats = torch.from_numpy(feats).type(torch.float32) + feats_len = torch.from_numpy(feats_len).type(torch.int32) return feats, feats_len @staticmethod From 
ec3ccbea9ff1d869becaa2b13255d0da1e4bf3ca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?= Date: Thu, 2 Mar 2023 20:23:39 +0800 Subject: [PATCH 3/3] torchscripts --- funasr/runtime/python/libtorch/README.md | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/funasr/runtime/python/libtorch/README.md b/funasr/runtime/python/libtorch/README.md index 33a56afca..b3d31110b 100644 --- a/funasr/runtime/python/libtorch/README.md +++ b/funasr/runtime/python/libtorch/README.md @@ -19,25 +19,21 @@ - `e.g.`, Export model from modelscope ```shell - python -m funasr.export.export_model 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch' "./export" true + python -m funasr.export.export_model 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch' "./export" false ``` - `e.g.`, Export model from local path, the model'name must be `model.pb`. ```shell - python -m funasr.export.export_model '/mnt/workspace/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch' "./export" true + python -m funasr.export.export_model '/mnt/workspace/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch' "./export" false ``` 2. Install the `torch_paraformer`. - - Build the torch_paraformer `whl` - ```shell - git clone https://github.com/alibaba/FunASR.git && cd FunASR - cd funasr/runtime/python/libtorch - python setup.py bdist_wheel - ``` - - Install the build `whl` - ```bash - pip install dist/torch_paraformer-0.0.1-py3-none-any.whl - ``` + ```shell + git clone https://github.com/alibaba/FunASR.git && cd FunASR + cd funasr/runtime/python/libtorch + python setup.py install + ``` + 3. Run the demo. - Model_dir: the model path, which contains `model.torchscripts`, `config.yaml`, `am.mvn`.