mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
vad_realtime onnx runnable
This commit is contained in:
parent
9f6445d39b
commit
79007d36f1
15
funasr/runtime/python/onnxruntime/demo_punc_online.py
Normal file
15
funasr/runtime/python/onnxruntime/demo_punc_online.py
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
from funasr_onnx import CT_Transformer_VadRealtime
|
||||||
|
|
||||||
|
model_dir = "/disk1/mengzhe.cmz/workspace/FunASR/funasr/export/damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727"
|
||||||
|
model = CT_Transformer_VadRealtime(model_dir)
|
||||||
|
|
||||||
|
text_in = "跨境河流是养育沿岸|人民的生命之源长期以来为帮助下游地区防灾减灾中方技术人员|在上游地区极为恶劣的自然条件下克服巨大困难甚至冒着生命危险|向印方提供汛期水文资料处理紧急事件中方重视印方在跨境河流>问题上的关切|愿意进一步完善双方联合工作机制|凡是|中方能做的我们|都会去做而且会做得更好我请印度朋友们放心中国在上游的|任何开发利用都会经过科学|规划和论证兼顾上下游的利益"
|
||||||
|
|
||||||
|
vads = text_in.split("|")
|
||||||
|
rec_result_all="outputs:"
|
||||||
|
param_dict = {"cache": []}
|
||||||
|
for vad in vads:
|
||||||
|
result = model(vad, param_dict=param_dict)
|
||||||
|
rec_result_all += result[0]
|
||||||
|
|
||||||
|
print(rec_result_all)
|
||||||
@ -2,4 +2,4 @@
|
|||||||
from .paraformer_bin import Paraformer
|
from .paraformer_bin import Paraformer
|
||||||
from .vad_bin import Fsmn_vad
|
from .vad_bin import Fsmn_vad
|
||||||
from .punc_bin import CT_Transformer
|
from .punc_bin import CT_Transformer
|
||||||
#from .punc_bin import VadRealtimeCT_Transformer
|
from .punc_bin import CT_Transformer_VadRealtime
|
||||||
|
|||||||
@ -117,3 +117,133 @@ class CT_Transformer():
|
|||||||
outputs = self.ort_infer([feats, feats_len])
|
outputs = self.ort_infer([feats, feats_len])
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
class CT_Transformer_VadRealtime(CT_Transformer):
|
||||||
|
def __init__(self, model_dir: Union[str, Path] = None,
|
||||||
|
batch_size: int = 1,
|
||||||
|
device_id: Union[str, int] = "-1",
|
||||||
|
quantize: bool = False,
|
||||||
|
intra_op_num_threads: int = 4
|
||||||
|
):
|
||||||
|
super(CT_Transformer_VadRealtime, self).__init__(model_dir, batch_size, device_id, quantize, intra_op_num_threads)
|
||||||
|
|
||||||
|
def __call__(self, text: str, param_dict: map, split_size=20):
|
||||||
|
cache_key = "cache"
|
||||||
|
assert cache_key in param_dict
|
||||||
|
cache = param_dict[cache_key]
|
||||||
|
if cache is not None and len(cache) > 0:
|
||||||
|
precache = "".join(cache)
|
||||||
|
else:
|
||||||
|
precache = ""
|
||||||
|
cache = []
|
||||||
|
full_text = precache + text
|
||||||
|
split_text = code_mix_split_words(full_text)
|
||||||
|
split_text_id = self.converter.tokens2ids(split_text)
|
||||||
|
mini_sentences = split_to_mini_sentence(split_text, split_size)
|
||||||
|
mini_sentences_id = split_to_mini_sentence(split_text_id, split_size)
|
||||||
|
new_mini_sentence_punc = []
|
||||||
|
assert len(mini_sentences) == len(mini_sentences_id)
|
||||||
|
|
||||||
|
cache_sent = []
|
||||||
|
cache_sent_id = np.array([], dtype='int32')
|
||||||
|
sentence_punc_list = []
|
||||||
|
sentence_words_list = []
|
||||||
|
cache_pop_trigger_limit = 200
|
||||||
|
skip_num = 0
|
||||||
|
for mini_sentence_i in range(len(mini_sentences)):
|
||||||
|
mini_sentence = mini_sentences[mini_sentence_i]
|
||||||
|
mini_sentence_id = mini_sentences_id[mini_sentence_i]
|
||||||
|
mini_sentence = cache_sent + mini_sentence
|
||||||
|
mini_sentence_id = np.concatenate((cache_sent_id, mini_sentence_id), axis=0)
|
||||||
|
text_length = len(mini_sentence_id)
|
||||||
|
data = {
|
||||||
|
"input": mini_sentence_id[None,:],
|
||||||
|
"text_lengths": np.array([text_length], dtype='int32'),
|
||||||
|
"vad_mask": self.vad_mask(text_length, len(cache) - 1)[None, None, :, :].astype(np.float32),
|
||||||
|
"sub_masks": np.tril(np.ones((text_length, text_length), dtype=np.float32))[None, None, :, :].astype(np.float32)
|
||||||
|
}
|
||||||
|
try:
|
||||||
|
outputs = self.infer(data['input'], data['text_lengths'], data['vad_mask'], data["sub_masks"])
|
||||||
|
y = outputs[0]
|
||||||
|
punctuations = np.argmax(y,axis=-1)[0]
|
||||||
|
assert punctuations.size == len(mini_sentence)
|
||||||
|
except ONNXRuntimeError:
|
||||||
|
logging.warning("error")
|
||||||
|
|
||||||
|
# Search for the last Period/QuestionMark as cache
|
||||||
|
if mini_sentence_i < len(mini_sentences) - 1:
|
||||||
|
sentenceEnd = -1
|
||||||
|
last_comma_index = -1
|
||||||
|
for i in range(len(punctuations) - 2, 1, -1):
|
||||||
|
if self.punc_list[punctuations[i]] == "。" or self.punc_list[punctuations[i]] == "?":
|
||||||
|
sentenceEnd = i
|
||||||
|
break
|
||||||
|
if last_comma_index < 0 and self.punc_list[punctuations[i]] == ",":
|
||||||
|
last_comma_index = i
|
||||||
|
|
||||||
|
if sentenceEnd < 0 and len(mini_sentence) > cache_pop_trigger_limit and last_comma_index >= 0:
|
||||||
|
# The sentence it too long, cut off at a comma.
|
||||||
|
sentenceEnd = last_comma_index
|
||||||
|
punctuations[sentenceEnd] = self.period
|
||||||
|
cache_sent = mini_sentence[sentenceEnd + 1:]
|
||||||
|
cache_sent_id = mini_sentence_id[sentenceEnd + 1:]
|
||||||
|
mini_sentence = mini_sentence[0:sentenceEnd + 1]
|
||||||
|
punctuations = punctuations[0:sentenceEnd + 1]
|
||||||
|
|
||||||
|
punctuations_np = [int(x) for x in punctuations]
|
||||||
|
new_mini_sentence_punc += punctuations_np
|
||||||
|
sentence_punc_list += [self.punc_list[int(x)] for x in punctuations_np]
|
||||||
|
sentence_words_list += mini_sentence
|
||||||
|
|
||||||
|
assert len(sentence_punc_list) == len(sentence_words_list)
|
||||||
|
words_with_punc = []
|
||||||
|
sentence_punc_list_out = []
|
||||||
|
for i in range(0, len(sentence_words_list)):
|
||||||
|
if i > 0:
|
||||||
|
if len(sentence_words_list[i][0].encode()) == 1 and len(sentence_words_list[i - 1][-1].encode()) == 1:
|
||||||
|
sentence_words_list[i] = " " + sentence_words_list[i]
|
||||||
|
if skip_num < len(cache):
|
||||||
|
skip_num += 1
|
||||||
|
else:
|
||||||
|
words_with_punc.append(sentence_words_list[i])
|
||||||
|
if skip_num >= len(cache):
|
||||||
|
sentence_punc_list_out.append(sentence_punc_list[i])
|
||||||
|
if sentence_punc_list[i] != "_":
|
||||||
|
words_with_punc.append(sentence_punc_list[i])
|
||||||
|
sentence_out = "".join(words_with_punc)
|
||||||
|
|
||||||
|
sentenceEnd = -1
|
||||||
|
for i in range(len(sentence_punc_list) - 2, 1, -1):
|
||||||
|
if sentence_punc_list[i] == "。" or sentence_punc_list[i] == "?":
|
||||||
|
sentenceEnd = i
|
||||||
|
break
|
||||||
|
cache_out = sentence_words_list[sentenceEnd + 1:]
|
||||||
|
if sentence_out[-1] in self.punc_list:
|
||||||
|
sentence_out = sentence_out[:-1]
|
||||||
|
sentence_punc_list_out[-1] = "_"
|
||||||
|
param_dict[cache_key] = cache_out
|
||||||
|
return sentence_out, sentence_punc_list_out, cache_out
|
||||||
|
|
||||||
|
def vad_mask(self, size, vad_pos, dtype=np.bool):
|
||||||
|
"""Create mask for decoder self-attention.
|
||||||
|
|
||||||
|
:param int size: size of mask
|
||||||
|
:param int vad_pos: index of vad index
|
||||||
|
:param torch.dtype dtype: result dtype
|
||||||
|
:rtype: torch.Tensor (B, Lmax, Lmax)
|
||||||
|
"""
|
||||||
|
ret = np.ones((size, size), dtype=dtype)
|
||||||
|
if vad_pos <= 0 or vad_pos >= size:
|
||||||
|
return ret
|
||||||
|
sub_corner = np.zeros(
|
||||||
|
(vad_pos - 1, size - vad_pos), dtype=dtype)
|
||||||
|
ret[0:vad_pos - 1, vad_pos:] = sub_corner
|
||||||
|
return ret
|
||||||
|
|
||||||
|
def infer(self, feats: np.ndarray,
|
||||||
|
feats_len: np.ndarray,
|
||||||
|
vad_mask: np.ndarray,
|
||||||
|
sub_masks: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
||||||
|
outputs = self.ort_infer([feats, feats_len, vad_mask, sub_masks])
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user