mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
update libtorch infer
This commit is contained in:
parent
509d09f50d
commit
f591f33111
@ -20,10 +20,12 @@ def export(model, data_in=None, quantize: bool = False, opset_version: int = 14,
|
||||
export_dir=export_dir,
|
||||
**kwargs
|
||||
)
|
||||
elif type == 'torchscript':
|
||||
elif type == 'torchscripts':
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
_torchscripts(
|
||||
m,
|
||||
path=export_dir,
|
||||
device=device
|
||||
)
|
||||
print("output dir: {}".format(export_dir))
|
||||
|
||||
@ -88,6 +90,5 @@ def _torchscripts(model, path, device='cuda'):
|
||||
else:
|
||||
dummy_input = tuple([i.cuda() for i in dummy_input])
|
||||
|
||||
# model_script = torch.jit.script(model)
|
||||
model_script = torch.jit.trace(model, dummy_input)
|
||||
model_script.save(os.path.join(path, f'{model.export_name}.torchscripts'))
|
||||
|
||||
@ -1,17 +1,11 @@
|
||||
from funasr_torch import Paraformer
|
||||
from pathlib import Path
|
||||
from funasr_torch import Paraformer
|
||||
|
||||
|
||||
model_dir = (
|
||||
"iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
|
||||
)
|
||||
model_dir = "iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
|
||||
|
||||
model = Paraformer(model_dir, batch_size=1) # cpu
|
||||
# model = Paraformer(model_dir, batch_size=1, device_id=0) # gpu
|
||||
|
||||
# when using paraformer-large-vad-punc model, you can set plot_timestamp_to="./xx.png" to get figure of alignment besides timestamps
|
||||
# model = Paraformer(model_dir, batch_size=1, plot_timestamp_to="test.png")
|
||||
|
||||
wav_path = "{}/.cache/modelscope/hub/{}/example/asr_example.wav".format(Path.home(), model_dir)
|
||||
|
||||
result = model(wav_path)
|
||||
|
||||
@ -1,22 +1,21 @@
|
||||
# -*- encoding: utf-8 -*-
|
||||
import json
|
||||
import copy
|
||||
import torch
|
||||
import os.path
|
||||
import librosa
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from typing import List, Union, Tuple
|
||||
|
||||
import copy
|
||||
import librosa
|
||||
import numpy as np
|
||||
|
||||
from .utils.utils import CharTokenizer, Hypothesis, TokenIDConverter, get_logger, read_yaml
|
||||
from .utils.postprocess_utils import sentence_postprocess
|
||||
from .utils.utils import pad_list
|
||||
from .utils.frontend import WavFrontend
|
||||
from .utils.timestamp_utils import time_stamp_lfr6_onnx
|
||||
from .utils.postprocess_utils import sentence_postprocess
|
||||
from .utils.utils import CharTokenizer, Hypothesis, TokenIDConverter, get_logger, read_yaml
|
||||
|
||||
logging = get_logger()
|
||||
|
||||
import torch
|
||||
import json
|
||||
|
||||
|
||||
class Paraformer:
|
||||
"""
|
||||
@ -32,7 +31,6 @@ class Paraformer:
|
||||
device_id: Union[str, int] = "-1",
|
||||
plot_timestamp_to: str = "",
|
||||
quantize: bool = False,
|
||||
intra_op_num_threads: int = 4,
|
||||
cache_dir: str = None,
|
||||
**kwargs,
|
||||
):
|
||||
@ -236,4 +234,186 @@ class Paraformer:
|
||||
token = self.converter.ids2tokens(token_int)
|
||||
token = token[: valid_token_num - self.pred_bias]
|
||||
# texts = sentence_postprocess(token)
|
||||
return token
|
||||
return token
|
||||
|
||||
|
||||
class ContextualParaformer(Paraformer):
|
||||
"""
|
||||
Author: Speech Lab of DAMO Academy, Alibaba Group
|
||||
Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
|
||||
https://arxiv.org/abs/2206.08317
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_dir: Union[str, Path] = None,
|
||||
batch_size: int = 1,
|
||||
device_id: Union[str, int] = "-1",
|
||||
plot_timestamp_to: str = "",
|
||||
quantize: bool = False,
|
||||
cache_dir: str = None,
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
if not Path(model_dir).exists():
|
||||
try:
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
except:
|
||||
raise "You are exporting model from modelscope, please install modelscope and try it again. To install modelscope, you could:\n" "\npip3 install -U modelscope\n" "For the users in China, you could install with the command:\n" "\npip3 install -U modelscope -i https://mirror.sjtu.edu.cn/pypi/web/simple"
|
||||
try:
|
||||
model_dir = snapshot_download(model_dir, cache_dir=cache_dir)
|
||||
except:
|
||||
raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format(
|
||||
model_dir
|
||||
)
|
||||
|
||||
if quantize:
|
||||
model_bb_file = os.path.join(model_dir, "model_bb_quant.torchscripts")
|
||||
model_eb_file = os.path.join(model_dir, "model_eb_quant.torchscripts")
|
||||
else:
|
||||
model_bb_file = os.path.join(model_dir, "model_bb.torchscripts")
|
||||
model_eb_file = os.path.join(model_dir, "model_eb.torchscripts")
|
||||
|
||||
if not (os.path.exists(model_eb_file) and os.path.exists(model_bb_file)):
|
||||
print(".onnx is not exist, begin to export onnx")
|
||||
try:
|
||||
from funasr import AutoModel
|
||||
except:
|
||||
raise "You are exporting onnx, please install funasr and try it again. To install funasr, you could:\n" "\npip3 install -U funasr\n" "For the users in China, you could install with the command:\n" "\npip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple"
|
||||
|
||||
model = AutoModel(model=model_dir)
|
||||
model_dir = model.export(type="onnx", quantize=quantize, **kwargs)
|
||||
|
||||
config_file = os.path.join(model_dir, "config.yaml")
|
||||
cmvn_file = os.path.join(model_dir, "am.mvn")
|
||||
config = read_yaml(config_file)
|
||||
token_list = os.path.join(model_dir, "tokens.json")
|
||||
with open(token_list, "r", encoding="utf-8") as f:
|
||||
token_list = json.load(f)
|
||||
|
||||
# revert token_list into vocab dict
|
||||
self.vocab = {}
|
||||
for i, token in enumerate(token_list):
|
||||
self.vocab[token] = i
|
||||
|
||||
self.converter = TokenIDConverter(token_list)
|
||||
self.tokenizer = CharTokenizer()
|
||||
self.frontend = WavFrontend(cmvn_file=cmvn_file, **config["frontend_conf"])
|
||||
|
||||
self.ort_infer_bb = torch.jit.load(model_bb_file)
|
||||
self.ort_infer_eb = torch.jit.load(model_eb_file)
|
||||
self.device_id = device_id
|
||||
|
||||
self.batch_size = batch_size
|
||||
self.plot_timestamp_to = plot_timestamp_to
|
||||
if "predictor_bias" in config["model_conf"].keys():
|
||||
self.pred_bias = config["model_conf"]["predictor_bias"]
|
||||
else:
|
||||
self.pred_bias = 0
|
||||
|
||||
def __call__(
|
||||
self, wav_content: Union[str, np.ndarray, List[str]], hotwords: str, **kwargs
|
||||
) -> List:
|
||||
# make hotword list
|
||||
hotwords, hotwords_length = self.proc_hotword(hotwords)
|
||||
# import pdb; pdb.set_trace()
|
||||
[bias_embed] = self.eb_infer(hotwords, hotwords_length)
|
||||
# index from bias_embed
|
||||
bias_embed = bias_embed.transpose(1, 0, 2)
|
||||
_ind = np.arange(0, len(hotwords)).tolist()
|
||||
bias_embed = bias_embed[_ind, hotwords_length.tolist()]
|
||||
waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq)
|
||||
waveform_nums = len(waveform_list)
|
||||
asr_res = []
|
||||
for beg_idx in range(0, waveform_nums, self.batch_size):
|
||||
end_idx = min(waveform_nums, beg_idx + self.batch_size)
|
||||
feats, feats_len = self.extract_feat(waveform_list[beg_idx:end_idx])
|
||||
bias_embed = np.expand_dims(bias_embed, axis=0)
|
||||
bias_embed = np.repeat(bias_embed, feats.shape[0], axis=0)
|
||||
try:
|
||||
with torch.no_grad():
|
||||
if int(self.device_id) == -1:
|
||||
outputs = self.ort_infer(feats, feats_len)
|
||||
am_scores, valid_token_lens = outputs[0], outputs[1]
|
||||
else:
|
||||
outputs = self.ort_infer(feats.cuda(), feats_len.cuda())
|
||||
am_scores, valid_token_lens = outputs[0].cpu(), outputs[1].cpu()
|
||||
except:
|
||||
# logging.warning(traceback.format_exc())
|
||||
logging.warning("input wav is silence or noise")
|
||||
preds = [""]
|
||||
else:
|
||||
preds = self.decode(am_scores, valid_token_lens)
|
||||
for pred in preds:
|
||||
pred = sentence_postprocess(pred)
|
||||
asr_res.append({"preds": pred})
|
||||
return asr_res
|
||||
|
||||
def proc_hotword(self, hotwords):
|
||||
hotwords = hotwords.split(" ")
|
||||
hotwords_length = [len(i) - 1 for i in hotwords]
|
||||
hotwords_length.append(0)
|
||||
hotwords_length = np.array(hotwords_length)
|
||||
|
||||
# hotwords.append('<s>')
|
||||
def word_map(word):
|
||||
hotwords = []
|
||||
for c in word:
|
||||
if c not in self.vocab.keys():
|
||||
hotwords.append(8403)
|
||||
logging.warning(
|
||||
"oov character {} found in hotword {}, replaced by <unk>".format(c, word)
|
||||
)
|
||||
else:
|
||||
hotwords.append(self.vocab[c])
|
||||
return np.array(hotwords)
|
||||
|
||||
hotword_int = [word_map(i) for i in hotwords]
|
||||
hotword_int.append(np.array([1]))
|
||||
hotwords = pad_list(hotword_int, pad_value=0, max_len=10)
|
||||
return hotwords, hotwords_length
|
||||
|
||||
def bb_infer(
|
||||
self, feats: np.ndarray, feats_len: np.ndarray, bias_embed
|
||||
) -> Tuple[np.ndarray, np.ndarray]:
|
||||
outputs = self.ort_infer_bb([feats, feats_len, bias_embed])
|
||||
return outputs
|
||||
|
||||
def eb_infer(self, hotwords, hotwords_length):
|
||||
outputs = self.ort_infer_eb([hotwords.astype(np.int32), hotwords_length.astype(np.int32)])
|
||||
return outputs
|
||||
|
||||
def decode(self, am_scores: np.ndarray, token_nums: int) -> List[str]:
|
||||
return [
|
||||
self.decode_one(am_score, token_num)
|
||||
for am_score, token_num in zip(am_scores, token_nums)
|
||||
]
|
||||
|
||||
def decode_one(self, am_score: np.ndarray, valid_token_num: int) -> List[str]:
|
||||
yseq = am_score.argmax(axis=-1)
|
||||
score = am_score.max(axis=-1)
|
||||
score = np.sum(score, axis=-1)
|
||||
|
||||
# pad with mask tokens to ensure compatibility with sos/eos tokens
|
||||
# asr_model.sos:1 asr_model.eos:2
|
||||
yseq = np.array([1] + yseq.tolist() + [2])
|
||||
hyp = Hypothesis(yseq=yseq, score=score)
|
||||
|
||||
# remove sos/eos and get results
|
||||
last_pos = -1
|
||||
token_int = hyp.yseq[1:last_pos].tolist()
|
||||
|
||||
# remove blank symbol id, which is assumed to be 0
|
||||
token_int = list(filter(lambda x: x not in (0, 2), token_int))
|
||||
|
||||
# Change integer-ids to tokens
|
||||
token = self.converter.ids2tokens(token_int)
|
||||
token = token[: valid_token_num - self.pred_bias]
|
||||
# texts = sentence_postprocess(token)
|
||||
return token
|
||||
|
||||
|
||||
class SeacoParaformer(ContextualParaformer):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
# no difference with contextual_paraformer in method of calling onnx models
|
||||
|
||||
Loading…
Reference in New Issue
Block a user