python runtime

游雁 2024-07-22 19:02:15 +08:00
parent 03911c82ec
commit aa3fe1a353
7 changed files with 253 additions and 70 deletions


@@ -0,0 +1,19 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/FunAudioLLM/SenseVoice). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
from pathlib import Path
from funasr_onnx import SenseVoiceSmall
from funasr_onnx.utils.postprocess_utils import rich_transcription_postprocess
model_dir = "iic/SenseVoiceSmall"
model = SenseVoiceSmall(model_dir, batch_size=10, quantize=False)
# inference
wav_or_scp = ["{}/.cache/modelscope/hub/{}/example/en.mp3".format(Path.home(), model_dir)]
res = model(wav_or_scp, language="auto", use_itn=True)
print([rich_transcription_postprocess(i) for i in res])
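For context, the raw hypotheses carry inline language/emotion/event/ITN tags that rich_transcription_postprocess strips or maps to emoji. A minimal sketch (the raw tag string is illustrative; actual tags depend on the audio):

from funasr_onnx.utils.postprocess_utils import rich_transcription_postprocess

# Illustrative raw hypothesis; real output tags depend on the audio clip.
raw = "<|en|><|HAPPY|><|Speech|><|withitn|>It is a nice day."
print(rich_transcription_postprocess(raw))  # -> "It is a nice day.😊"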


@@ -1,38 +0,0 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/FunAudioLLM/SenseVoice). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
import os
import torch
from pathlib import Path
from funasr import AutoModel
from funasr_onnx import SenseVoiceSmallONNX as SenseVoiceSmall
from funasr.utils.postprocess_utils import rich_transcription_postprocess
model_dir = "iic/SenseVoiceSmall"
model = AutoModel(
    model=model_dir,
    device="cuda:0",
)
res = model.export(type="onnx", quantize=False)
# export model init
model_path = "{}/.cache/modelscope/hub/{}".format(Path.home(), model_dir)
model_bin = SenseVoiceSmall(model_path)
# build tokenizer
try:
    from funasr.tokenizer.sentencepiece_tokenizer import SentencepiecesTokenizer

    tokenizer = SentencepiecesTokenizer(bpemodel=os.path.join(model_path, "chn_jpn_yue_eng_ko_spectok.bpe.model"))
except:
    tokenizer = None
# inference
wav_or_scp = "/Users/shixian/Downloads/asr_example_hotword.wav"
language_list = [0]
textnorm_list = [15]
res = model_bin(wav_or_scp, language_list, textnorm_list, tokenizer=tokenizer)
print([rich_transcription_postprocess(i) for i in res])


@@ -4,4 +4,4 @@ from .vad_bin import Fsmn_vad
from .vad_bin import Fsmn_vad_online
from .punc_bin import CT_Transformer
from .punc_bin import CT_Transformer_VadRealtime
-from .sensevoice_bin import SenseVoiceSmallONNX
+from .sensevoice_bin import SenseVoiceSmall


@@ -20,12 +20,13 @@ from .utils.utils import (
    get_logger,
    read_yaml,
)
+from .utils.sentencepiece_tokenizer import SentencepiecesTokenizer
from .utils.frontend import WavFrontend
logging = get_logger()
-class SenseVoiceSmallONNX:
+class SenseVoiceSmall:
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
@@ -43,45 +44,72 @@ class SenseVoiceSmallONNX:
        cache_dir: str = None,
        **kwargs,
    ):
+        if not Path(model_dir).exists():
+            try:
+                from modelscope.hub.snapshot_download import snapshot_download
+            except ImportError:
+                raise ImportError(
+                    "You are exporting model from modelscope, please install modelscope and try it again. "
+                    "To install modelscope, you could:\n\npip3 install -U modelscope\n"
+                    "For the users in China, you could install with the command:\n"
+                    "\npip3 install -U modelscope -i https://mirror.sjtu.edu.cn/pypi/web/simple"
+                )
+            try:
+                model_dir = snapshot_download(model_dir, cache_dir=cache_dir)
+            except Exception:
+                raise ValueError(
+                    "model_dir must be model_name in modelscope or local path downloaded from modelscope, "
+                    "but is {}".format(model_dir)
+                )
-        model_file = os.path.join(model_dir, "model.onnx")
        if quantize:
            model_file = os.path.join(model_dir, "model_quant.onnx")
+        else:
+            model_file = os.path.join(model_dir, "model.onnx")
+        if not os.path.exists(model_file):
+            print(".onnx does not exist, begin to export onnx")
+            try:
+                from funasr import AutoModel
+            except ImportError:
+                raise ImportError(
+                    "You are exporting onnx, please install funasr and try it again. "
+                    "To install funasr, you could:\n\npip3 install -U funasr\n"
+                    "For the users in China, you could install with the command:\n"
+                    "\npip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple"
+                )
+            model = AutoModel(model=model_dir)
+            model_dir = model.export(type="onnx", quantize=quantize, **kwargs)
        config_file = os.path.join(model_dir, "config.yaml")
        cmvn_file = os.path.join(model_dir, "am.mvn")
        config = read_yaml(config_file)
        # token_list = os.path.join(model_dir, "tokens.json")
        # with open(token_list, "r", encoding="utf-8") as f:
        #     token_list = json.load(f)
        # self.converter = TokenIDConverter(token_list)
-        self.tokenizer = CharTokenizer()
-        config["frontend_conf"]['cmvn_file'] = cmvn_file
+        self.tokenizer = SentencepiecesTokenizer(
+            bpemodel=os.path.join(model_dir, "chn_jpn_yue_eng_ko_spectok.bpe.model")
+        )
+        config["frontend_conf"]["cmvn_file"] = cmvn_file
        self.frontend = WavFrontend(**config["frontend_conf"])
        self.ort_infer = OrtInferSession(
            model_file, device_id, intra_op_num_threads=intra_op_num_threads
        )
        self.batch_size = batch_size
        self.blank_id = 0
+        self.lid_dict = {"auto": 0, "zh": 3, "en": 4, "yue": 7, "ja": 11, "ko": 12, "nospeech": 13}
+        self.lid_int_dict = {24884: 3, 24885: 4, 24888: 7, 24892: 11, 24896: 12, 24992: 13}
+        self.textnorm_dict = {"withitn": 14, "woitn": 15}
+        self.textnorm_int_dict = {25016: 14, 25017: 15}
-    def __call__(self,
-                 wav_content: Union[str, np.ndarray, List[str]],
-                 language: List,
-                 textnorm: List,
-                 tokenizer=None,
-                 **kwargs) -> List:
+    def __call__(self, wav_content: Union[str, np.ndarray, List[str]], **kwargs):
+        language = self.lid_dict[kwargs.get("language", "auto")]
+        use_itn = kwargs.get("use_itn", False)
+        textnorm = kwargs.get("text_norm", None)
+        if textnorm is None:
+            textnorm = "withitn" if use_itn else "woitn"
+        textnorm = self.textnorm_dict[textnorm]
        waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq)
        waveform_nums = len(waveform_list)
        asr_res = []
        for beg_idx in range(0, waveform_nums, self.batch_size):
            end_idx = min(waveform_nums, beg_idx + self.batch_size)
            feats, feats_len = self.extract_feat(waveform_list[beg_idx:end_idx])
-            ctc_logits, encoder_out_lens = self.infer(feats,
-                                                      feats_len,
-                                                      np.array(language, dtype=np.int32),
-                                                      np.array(textnorm, dtype=np.int32)
-                                                      )
+            ctc_logits, encoder_out_lens = self.infer(
+                feats,
+                feats_len,
+                np.array(language, dtype=np.int32),
+                np.array(textnorm, dtype=np.int32),
+            )
            # back to torch.Tensor
            ctc_logits = torch.from_numpy(ctc_logits).float()
            # support batch_size=1 only currently
@@ -91,11 +119,9 @@ class SenseVoiceSmallONNX:
            mask = yseq != self.blank_id
            token_int = yseq[mask].tolist()
-            if tokenizer is not None:
-                asr_res.append(tokenizer.tokens2text(token_int))
-            else:
-                asr_res.append(token_int)
+            asr_res.append(self.tokenizer.decode(token_int))  # token IDs -> text
        return asr_res

    def load_data(self, wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List:
@@ -136,10 +162,12 @@ class SenseVoiceSmallONNX:
        feats = np.array(feat_res).astype(np.float32)
        return feats

-    def infer(self,
-              feats: np.ndarray,
-              feats_len: np.ndarray,
-              language: np.ndarray,
-              textnorm: np.ndarray,) -> Tuple[np.ndarray, np.ndarray]:
+    def infer(
+        self,
+        feats: np.ndarray,
+        feats_len: np.ndarray,
+        language: np.ndarray,
+        textnorm: np.ndarray,
+    ) -> Tuple[np.ndarray, np.ndarray]:
        outputs = self.ort_infer([feats, feats_len, language, textnorm])
        return outputs
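For reference, a minimal sketch of how the new keyword-style __call__ resolves language/use_itn into the integer IDs the ONNX graph consumes (tables copied from __init__ above):

# Sketch only: mirrors the kwargs-to-ID resolution in the new __call__.
lid_dict = {"auto": 0, "zh": 3, "en": 4, "yue": 7, "ja": 11, "ko": 12, "nospeech": 13}
textnorm_dict = {"withitn": 14, "woitn": 15}

def resolve(language="auto", use_itn=False, text_norm=None):
    # falls back to "withitn"/"woitn" exactly as __call__ does
    if text_norm is None:
        text_norm = "withitn" if use_itn else "woitn"
    return lid_dict[language], textnorm_dict[text_norm]

print(resolve(language="en", use_itn=True))  # (4, 14)
print(resolve())                             # (0, 15)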


@@ -296,3 +296,123 @@ def sentence_postprocess_sentencepiece(words):
        real_word_lists.append(ch)
    sentence = "".join(word_lists)
    return sentence, real_word_lists
emo_dict = {
    "<|HAPPY|>": "😊",
    "<|SAD|>": "😔",
    "<|ANGRY|>": "😡",
    "<|NEUTRAL|>": "",
    "<|FEARFUL|>": "😰",
    "<|DISGUSTED|>": "🤢",
    "<|SURPRISED|>": "😮",
}

event_dict = {
    "<|BGM|>": "🎼",
    "<|Speech|>": "",
    "<|Applause|>": "👏",
    "<|Laughter|>": "😀",
    "<|Cry|>": "😭",
    "<|Sneeze|>": "🤧",
    "<|Breath|>": "",
    "<|Cough|>": "🤧",
}

lang_dict = {
    "<|zh|>": "<|lang|>",
    "<|en|>": "<|lang|>",
    "<|yue|>": "<|lang|>",
    "<|ja|>": "<|lang|>",
    "<|ko|>": "<|lang|>",
    "<|nospeech|>": "<|lang|>",
}

emoji_dict = {
    "<|nospeech|><|Event_UNK|>": "",
    "<|zh|>": "",
    "<|en|>": "",
    "<|yue|>": "",
    "<|ja|>": "",
    "<|ko|>": "",
    "<|nospeech|>": "",
    "<|HAPPY|>": "😊",
    "<|SAD|>": "😔",
    "<|ANGRY|>": "😡",
    "<|NEUTRAL|>": "",
    "<|BGM|>": "🎼",
    "<|Speech|>": "",
    "<|Applause|>": "👏",
    "<|Laughter|>": "😀",
    "<|FEARFUL|>": "😰",
    "<|DISGUSTED|>": "🤢",
    "<|SURPRISED|>": "😮",
    "<|Cry|>": "😭",
    "<|EMO_UNKNOWN|>": "",
    "<|Sneeze|>": "🤧",
    "<|Breath|>": "",
    "<|Cough|>": "😷",
    "<|Sing|>": "",
    "<|Speech_Noise|>": "",
    "<|withitn|>": "",
    "<|woitn|>": "",
    "<|GBG|>": "",
    "<|Event_UNK|>": "",
}

emo_set = {"😊", "😔", "😡", "😰", "🤢", "😮"}
event_set = {
    "🎼",
    "👏",
    "😀",
    "😭",
    "🤧",
    "😷",
}
def format_str_v2(s):
    sptk_dict = {}
    for sptk in emoji_dict:
        sptk_dict[sptk] = s.count(sptk)
        s = s.replace(sptk, "")
    emo = "<|NEUTRAL|>"
    for e in emo_dict:
        if sptk_dict[e] > sptk_dict[emo]:
            emo = e
    for e in event_dict:
        if sptk_dict[e] > 0:
            s = event_dict[e] + s
    s = s + emo_dict[emo]
    for emoji in emo_set.union(event_set):
        s = s.replace(" " + emoji, emoji)
        s = s.replace(emoji + " ", emoji)
    return s.strip()
def rich_transcription_postprocess(s):
    def get_emo(s):
        return s[-1] if s[-1] in emo_set else None

    def get_event(s):
        return s[0] if s[0] in event_set else None

    s = s.replace("<|nospeech|><|Event_UNK|>", "")
    for lang in lang_dict:
        s = s.replace(lang, "<|lang|>")
    s_list = [format_str_v2(s_i).strip(" ") for s_i in s.split("<|lang|>")]
    new_s = " " + s_list[0]
    cur_ent_event = get_event(new_s)
    for i in range(1, len(s_list)):
        if len(s_list[i]) == 0:
            continue
        if get_event(s_list[i]) == cur_ent_event and get_event(s_list[i]) != None:
            s_list[i] = s_list[i][1:]
        # else:
        cur_ent_event = get_event(s_list[i])
        if get_emo(s_list[i]) != None and get_emo(s_list[i]) == get_emo(new_s):
            new_s = new_s[:-1]
        new_s += s_list[i].strip().lstrip()
    new_s = new_s.replace("The.", " ")
    return new_s.strip()
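A worked example of the helpers above (inputs are illustrative): format_str_v2 prepends one emoji per detected event and appends the majority emotion, while rich_transcription_postprocess merges the per-language segments:

from funasr_onnx.utils.postprocess_utils import (
    format_str_v2,
    rich_transcription_postprocess,
)

# Majority emotion wins: two <|HAPPY|> outvote one <|SAD|>, so 😊 is appended;
# the <|BGM|> event is rendered as a leading 🎼.
print(format_str_v2("<|BGM|><|HAPPY|>good morning <|HAPPY|>everyone <|SAD|>today"))
# -> "🎼good morning everyone today😊"

# Language tags mark segment boundaries and are stripped during the merge.
print(rich_transcription_postprocess(
    "<|zh|><|NEUTRAL|><|Speech|><|withitn|>你好。<|en|><|NEUTRAL|><|Speech|><|withitn|>Hello."
))
# -> "你好。Hello."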


@@ -0,0 +1,53 @@
from pathlib import Path
from typing import Iterable
from typing import List
from typing import Union

import sentencepiece as spm


class SentencepiecesTokenizer:
    def __init__(self, bpemodel: Union[Path, str], **kwargs):
        super().__init__(**kwargs)
        self.bpemodel = str(bpemodel)
        # NOTE(kamo):
        # Don't build SentencePieceProcessor in __init__()
        # because it's not picklable and it may cause following error,
        # "TypeError: can't pickle SwigPyObject objects",
        # when giving it as argument of "multiprocessing.Process()".
        self.sp = None
        self._build_sentence_piece_processor()

    def __repr__(self):
        return f'{self.__class__.__name__}(model="{self.bpemodel}")'

    def _build_sentence_piece_processor(self):
        # Build SentencePieceProcessor lazily.
        if self.sp is None:
            self.sp = spm.SentencePieceProcessor()
            self.sp.load(self.bpemodel)

    def text2tokens(self, line: str) -> List[str]:
        self._build_sentence_piece_processor()
        return self.sp.EncodeAsPieces(line)

    def tokens2text(self, tokens: Iterable[str]) -> str:
        self._build_sentence_piece_processor()
        return self.sp.DecodePieces(list(tokens))

    def encode(self, line: str, **kwargs) -> List[int]:
        self._build_sentence_piece_processor()
        return self.sp.EncodeAsIds(line)

    def decode(self, line: List[int], **kwargs):
        self._build_sentence_piece_processor()
        return self.sp.DecodeIds(line)

    def get_vocab_size(self):
        return self.sp.GetPieceSize()

    def ids2tokens(self, *args, **kwargs):
        # delegates to decode(); despite the name, maps IDs back to text
        return self.decode(*args, **kwargs)

    def tokens2ids(self, *args, **kwargs):
        # delegates to encode(); despite the name, maps text to IDs
        return self.encode(*args, **kwargs)
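A usage sketch for the wrapper above (the model path is illustrative; the real .bpe.model file ships inside the downloaded SenseVoiceSmall model directory):

from funasr_onnx.utils.sentencepiece_tokenizer import SentencepiecesTokenizer

# Path is illustrative: point at the BPE model inside the model directory.
tok = SentencepiecesTokenizer(bpemodel="chn_jpn_yue_eng_ko_spectok.bpe.model")
ids = tok.encode("hello world")   # text -> piece IDs (EncodeAsIds)
print(tok.decode(ids))            # IDs back to text, e.g. "hello world"
print(tok.text2tokens("hello"))   # text -> surface pieces (EncodeAsPieces)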


@@ -31,10 +31,11 @@ setuptools.setup(
"librosa",
"onnxruntime>=1.7.0",
"scipy",
"numpy>=1.19.3",
"numpy<=1.26.4",
"kaldi-native-fbank",
"PyYAML>=5.1.2",
"onnx",
"sentencepiece",
],
packages=[MODULE_NAME, f"{MODULE_NAME}.utils"],
keywords=["funasr,asr"],
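The dependency changes track the code above: sentencepiece backs the new tokenizer module, and the numpy ceiling most plausibly keeps installs below the numpy 2.0 ABI break (released June 2024), which contemporaneous onnxruntime wheels were not built against. A quick environment sanity check, as a sketch:

# Sketch: verify the pinned stack imports together.
import numpy
import onnxruntime
import sentencepiece

assert tuple(int(v) for v in numpy.__version__.split(".")[:2]) <= (1, 26)
print(numpy.__version__, onnxruntime.__version__)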