mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
Dev gzf llm (#1506)
* update * update * update * update onnx * update with main (#1492) * contextual&seaco ONNX export (#1481) * contextual&seaco ONNX export * update ContextualEmbedderExport2 * update ContextualEmbedderExport2 * update code * onnx (#1482) * qwenaudio qwenaudiochat * qwenaudio qwenaudiochat * whisper * whisper * llm * llm * llm * llm * llm * llm * llm * llm * export onnx * export onnx * export onnx * dingding * dingding * llm * doc * onnx * onnx * onnx * onnx * onnx * onnx * v1.0.15 * qwenaudio * qwenaudio * issue doc * update * update * bugfix * onnx * update export calling * update codes * remove useless code * update code --------- Co-authored-by: zhifu gao <zhifu.gzf@alibaba-inc.com> * acknowledge --------- Co-authored-by: Shi Xian <40013335+R1ckShi@users.noreply.github.com> * update onnx * update onnx * train update * train update * train update * train update * punc update --------- Co-authored-by: Shi Xian <40013335+R1ckShi@users.noreply.github.com>
This commit is contained in:
parent
5023dd0422
commit
675b4605e8
@ -112,7 +112,7 @@ Notes: Support recognition of single audio file, as well as file list in Kaldi-s
|
||||
from funasr import AutoModel
|
||||
# paraformer-zh is a multi-functional asr model
|
||||
# use vad, punc, spk or not as you need
|
||||
model = AutoModel(model="paraformer-zh", vad_model="fsmn-vad", punc_model="ct-punc-c",
|
||||
model = AutoModel(model="paraformer-zh", vad_model="fsmn-vad", punc_model="ct-punc",
|
||||
# spk_model="cam++",
|
||||
)
|
||||
res = model.generate(input=f"{model.model_path}/example/asr_example.wav",
|
||||
|
||||
@ -106,7 +106,7 @@ funasr ++model=paraformer-zh ++vad_model="fsmn-vad" ++punc_model="ct-punc" ++inp
|
||||
from funasr import AutoModel
|
||||
# paraformer-zh is a multi-functional asr model
|
||||
# use vad, punc, spk or not as you need
|
||||
model = AutoModel(model="paraformer-zh", vad_model="fsmn-vad", punc_model="ct-punc-c",
|
||||
model = AutoModel(model="paraformer-zh", vad_model="fsmn-vad", punc_model="ct-punc",
|
||||
# spk_model="cam++"
|
||||
)
|
||||
res = model.generate(input=f"{model.model_path}/example/asr_example.wav",
|
||||
|
||||
@ -5,13 +5,13 @@
|
||||
|
||||
from funasr import AutoModel
|
||||
|
||||
model = AutoModel(model="damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
|
||||
model = AutoModel(model="iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
|
||||
model_revision="v2.0.4",
|
||||
vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
|
||||
vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
|
||||
vad_model_revision="v2.0.4",
|
||||
punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
|
||||
punc_model="iic/punc_ct-transformer_cn-en-common-vocab471067-large",
|
||||
punc_model_revision="v2.0.4",
|
||||
# spk_model="damo/speech_campplus_sv_zh-cn_16k-common",
|
||||
# spk_model="iic/speech_campplus_sv_zh-cn_16k-common",
|
||||
# spk_model_revision="v2.0.2",
|
||||
)
|
||||
|
||||
|
||||
@ -1,11 +1,12 @@
|
||||
|
||||
model="damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
|
||||
model="iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
|
||||
model_revision="v2.0.4"
|
||||
vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"
|
||||
vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch"
|
||||
vad_model_revision="v2.0.4"
|
||||
punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
|
||||
punc_model_revision="v2.0.3"
|
||||
spk_model="damo/speech_campplus_sv_zh-cn_16k-common"
|
||||
#punc_model="iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
|
||||
punc_model="iic/punc_ct-transformer_cn-en-common-vocab471067-large"
|
||||
punc_model_revision="v2.0.4"
|
||||
spk_model="iic/speech_campplus_sv_zh-cn_16k-common"
|
||||
spk_model_revision="v2.0.2"
|
||||
|
||||
python funasr/bin/inference.py \
|
||||
|
||||
@ -7,7 +7,7 @@
|
||||
|
||||
from funasr import AutoModel
|
||||
|
||||
model = AutoModel(model="damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
|
||||
model = AutoModel(model="iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
|
||||
model_revision="v2.0.4", device="cpu")
|
||||
|
||||
res = model.export(type="onnx", quantize=False)
|
||||
|
||||
@ -5,7 +5,7 @@
|
||||
|
||||
from funasr import AutoModel
|
||||
|
||||
model = AutoModel(model="damo/speech_campplus_sv_zh-cn_16k-common",
|
||||
model = AutoModel(model="iic/speech_campplus_sv_zh-cn_16k-common",
|
||||
model_revision="v2.0.2",
|
||||
)
|
||||
|
||||
|
||||
@ -5,7 +5,7 @@
|
||||
|
||||
from funasr import AutoModel
|
||||
|
||||
model = AutoModel(model="damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404", model_revision="v2.0.4")
|
||||
model = AutoModel(model="iic/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404", model_revision="v2.0.4")
|
||||
|
||||
res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
|
||||
hotword='达摩院 魔搭')
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
|
||||
model="damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404"
|
||||
model="iic/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404"
|
||||
model_revision="v2.0.4"
|
||||
|
||||
python ../../../funasr/bin/inference.py \
|
||||
|
||||
@ -5,7 +5,7 @@
|
||||
|
||||
from funasr import AutoModel
|
||||
|
||||
model = AutoModel(model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch", model_revision="v2.0.4")
|
||||
model = AutoModel(model="iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch", model_revision="v2.0.4")
|
||||
|
||||
res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_text/punc_example.txt")
|
||||
print(res)
|
||||
@ -13,7 +13,7 @@ print(res)
|
||||
|
||||
from funasr import AutoModel
|
||||
|
||||
model = AutoModel(model="damo/punc_ct-transformer_cn-en-common-vocab471067-large", model_revision="v2.0.4")
|
||||
model = AutoModel(model="iic/punc_ct-transformer_cn-en-common-vocab471067-large", model_revision="v2.0.4")
|
||||
|
||||
res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_text/punc_example.txt")
|
||||
print(res)
|
||||
@ -1,8 +1,8 @@
|
||||
|
||||
#model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
|
||||
#model="iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
|
||||
#model_revision="v2.0.4"
|
||||
|
||||
model="damo/punc_ct-transformer_cn-en-common-vocab471067-large"
|
||||
model="iic/punc_ct-transformer_cn-en-common-vocab471067-large"
|
||||
model_revision="v2.0.4"
|
||||
|
||||
python funasr/bin/inference.py \
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
|
||||
model="damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727"
|
||||
model="iic/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727"
|
||||
model_revision="v2.0.4"
|
||||
|
||||
python funasr/bin/inference.py \
|
||||
|
||||
@ -5,7 +5,7 @@
|
||||
|
||||
from funasr import AutoModel
|
||||
|
||||
# model="damo/emotion2vec_base"
|
||||
# model="iic/emotion2vec_base"
|
||||
model = AutoModel(model="iic/emotion2vec_base_finetuned", model_revision="v2.0.4")
|
||||
|
||||
wav_file = f"{model.model_path}/example/test.wav"
|
||||
|
||||
@ -6,7 +6,7 @@
|
||||
from funasr import AutoModel
|
||||
wav_file = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav"
|
||||
|
||||
model = AutoModel(model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch", model_revision="v2.0.4")
|
||||
model = AutoModel(model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch", model_revision="v2.0.4")
|
||||
|
||||
res = model.generate(input=wav_file)
|
||||
print(res)
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
|
||||
|
||||
model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"
|
||||
model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch"
|
||||
model_revision="v2.0.4"
|
||||
|
||||
python funasr/bin/inference.py \
|
||||
|
||||
@ -5,7 +5,7 @@
|
||||
|
||||
from funasr import AutoModel
|
||||
|
||||
model = AutoModel(model="damo/speech_timestamp_prediction-v1-16k-offline", model_revision="v2.0.4")
|
||||
model = AutoModel(model="iic/speech_timestamp_prediction-v1-16k-offline", model_revision="v2.0.4")
|
||||
|
||||
res = model.generate(input=("https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
|
||||
"欢迎大家来到魔搭社区进行体验"),
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
|
||||
model="damo/speech_timestamp_prediction-v1-16k-offline"
|
||||
model="iic/speech_timestamp_prediction-v1-16k-offline"
|
||||
model_revision="v2.0.4"
|
||||
|
||||
python funasr/bin/inference.py \
|
||||
|
||||
@ -5,13 +5,13 @@
|
||||
|
||||
from funasr import AutoModel
|
||||
|
||||
model = AutoModel(model="damo/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
|
||||
model = AutoModel(model="iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
|
||||
model_revision="v2.0.4",
|
||||
vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
|
||||
vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
|
||||
vad_model_revision="v2.0.4",
|
||||
punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
|
||||
punc_model="iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
|
||||
punc_model_revision="v2.0.4",
|
||||
spk_model="damo/speech_campplus_sv_zh-cn_16k-common",
|
||||
spk_model="iic/speech_campplus_sv_zh-cn_16k-common",
|
||||
spk_model_revision="v2.0.2"
|
||||
)
|
||||
|
||||
|
||||
@ -1,11 +1,11 @@
|
||||
|
||||
model="damo/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
|
||||
model="iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
|
||||
model_revision="v2.0.4"
|
||||
vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"
|
||||
vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch"
|
||||
vad_model_revision="v2.0.4"
|
||||
punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
|
||||
punc_model="iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
|
||||
punc_model_revision="v2.0.4"
|
||||
spk_model="damo/speech_campplus_sv_zh-cn_16k-common"
|
||||
spk_model="iic/speech_campplus_sv_zh-cn_16k-common"
|
||||
spk_model_revision="v2.0.2"
|
||||
|
||||
python funasr/bin/inference.py \
|
||||
|
||||
@ -22,7 +22,7 @@ print(res)
|
||||
''' can not use currently
|
||||
from funasr import AutoFrontend
|
||||
|
||||
frontend = AutoFrontend(model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", model_revision="v2.0.4")
|
||||
frontend = AutoFrontend(model="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", model_revision="v2.0.4")
|
||||
|
||||
fbanks = frontend(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav", batch_size=2)
|
||||
|
||||
|
||||
@ -32,7 +32,7 @@ torchrun \
|
||||
--nnodes 1 \
|
||||
--nproc_per_node ${gpu_num} \
|
||||
funasr/bin/train.py \
|
||||
++model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" \
|
||||
++model="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" \
|
||||
++model_revision="v2.0.4" \
|
||||
++train_data_set_list="${train_data}" \
|
||||
++valid_data_set_list="${val_data}" \
|
||||
|
||||
@ -8,7 +8,7 @@ input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr
|
||||
|
||||
output_dir="./outputs/debug"
|
||||
|
||||
model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
|
||||
model="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
|
||||
model_revision="v2.0.4"
|
||||
|
||||
device="cuda:0" # "cuda:0" for gpu0, "cuda:1" for gpu1, "cpu"
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
|
||||
model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online"
|
||||
model="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online"
|
||||
model_revision="v2.0.4"
|
||||
|
||||
python funasr/bin/inference.py \
|
||||
|
||||
@ -7,11 +7,11 @@ from funasr import AutoModel
|
||||
|
||||
model = AutoModel(model="iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
|
||||
model_revision="v2.0.4",
|
||||
# vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
|
||||
# vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
|
||||
# vad_model_revision="v2.0.4",
|
||||
# punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
|
||||
# punc_model="iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
|
||||
# punc_model_revision="v2.0.4",
|
||||
# spk_model="damo/speech_campplus_sv_zh-cn_16k-common",
|
||||
# spk_model="iic/speech_campplus_sv_zh-cn_16k-common",
|
||||
# spk_model_revision="v2.0.2",
|
||||
)
|
||||
|
||||
|
||||
@ -1,9 +1,9 @@
|
||||
|
||||
model="damo/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
|
||||
model="iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
|
||||
model_revision="v2.0.4"
|
||||
vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"
|
||||
vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch"
|
||||
vad_model_revision="v2.0.4"
|
||||
punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
|
||||
punc_model="iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
|
||||
punc_model_revision="v2.0.4"
|
||||
|
||||
python funasr/bin/inference.py \
|
||||
|
||||
@ -16,7 +16,7 @@ print(res)
|
||||
''' can not use currently
|
||||
from funasr import AutoFrontend
|
||||
|
||||
frontend = AutoFrontend(model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", model_revision="v2.0.4")
|
||||
frontend = AutoFrontend(model="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", model_revision="v2.0.4")
|
||||
|
||||
fbanks = frontend(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav", batch_size=2)
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
|
||||
model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
|
||||
model="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
|
||||
model_revision="v2.0.4"
|
||||
|
||||
python funasr/bin/inference.py \
|
||||
|
||||
67
funasr/models/ct_transformer/export_meta.py
Normal file
67
funasr/models/ct_transformer/export_meta.py
Normal file
@ -0,0 +1,67 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- encoding: utf-8 -*-
|
||||
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
|
||||
# MIT License (https://opensource.org/licenses/MIT)
|
||||
|
||||
import types
|
||||
import torch
|
||||
from funasr.register import tables
|
||||
|
||||
|
||||
def export_rebuild_model(model, **kwargs):
    """Rewire *model* in place so it can be exported (e.g. to ONNX).

    Replaces the encoder with its registered "<Encoder>Export" wrapper and
    binds the module-level export_* helpers onto the model instance.

    Args:
        model: the CT-Transformer punctuation model to rebuild.
        **kwargs: must contain "encoder" (registry name); "type" selects the
            export backend (defaults to "onnx").

    Returns:
        The same model object, mutated for export.
    """
    onnx_mode = kwargs.get("type", "onnx") == "onnx"
    export_encoder_cls = tables.encoder_classes.get(kwargs["encoder"] + "Export")
    model.encoder = export_encoder_cls(model.encoder, onnx=onnx_mode)

    # Bind the free functions below as instance methods on the model.
    bindings = {
        "forward": export_forward,
        "export_dummy_inputs": export_dummy_inputs,
        "export_input_names": export_input_names,
        "export_output_names": export_output_names,
        "export_dynamic_axes": export_dynamic_axes,
        "export_name": export_name,
    }
    for attr_name, func in bindings.items():
        setattr(model, attr_name, types.MethodType(func, model))

    return model
|
||||
|
||||
|
||||
def export_forward(self, inputs: torch.Tensor, text_lengths: torch.Tensor):
    """Export-time forward pass: embed, encode, then project to logits.

    Args:
        inputs (torch.Tensor): input token ids, shape (batch, len).
        text_lengths (torch.Tensor): valid length per sample, shape (batch,).

    Returns:
        The decoder projection of the encoder output (punctuation logits).
    """
    embedded = self.embed(inputs)
    encoded, _ = self.encoder(embedded, text_lengths)
    return self.decoder(encoded)
|
||||
|
||||
def export_dummy_inputs(self):
    """Build random dummy inputs (token ids + lengths) for export tracing."""
    seq_len = 120
    vocab_size = self.embed.num_embeddings
    # Batch of 2 with one padded sample so length handling is traced.
    token_ids = torch.randint(0, vocab_size, (2, seq_len), dtype=torch.int32)
    lengths = torch.tensor([seq_len - 20, seq_len], dtype=torch.int32)
    return token_ids, lengths
|
||||
|
||||
def export_input_names(self):
    """ONNX graph input names, ordered to match export_dummy_inputs."""
    return ["inputs", "text_lengths"]
|
||||
|
||||
def export_output_names(self):
    """ONNX graph output names."""
    return ["logits"]
|
||||
|
||||
def export_dynamic_axes(self):
    """Dynamic-axis spec for torch.onnx.export: batch and length dims vary."""
    return {
        "inputs": {0: "batch_size", 1: "feats_length"},
        "text_lengths": {0: "batch_size"},
        "logits": {0: "batch_size", 1: "logits_length"},
    }
|
||||
def export_name(self):
    """File name used for the exported model artifact."""
    return "model.onnx"
|
||||
@ -17,7 +17,10 @@ from funasr.train_utils.device_funcs import force_gatherable
|
||||
from funasr.utils.load_utils import load_audio_text_image_video
|
||||
from funasr.models.transformer.utils.nets_utils import make_pad_mask
|
||||
from funasr.models.ct_transformer.utils import split_to_mini_sentence, split_words
|
||||
|
||||
try:
|
||||
import jieba
|
||||
except:
|
||||
pass
|
||||
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
|
||||
from torch.cuda.amp import autocast
|
||||
else:
|
||||
@ -69,6 +72,10 @@ class CTTransformer(torch.nn.Module):
|
||||
self.sos = sos
|
||||
self.eos = eos
|
||||
self.sentence_end_id = sentence_end_id
|
||||
self.jieba_usr_dict = None
|
||||
if kwargs.get("jieba_usr_dict", None) is not None:
|
||||
jieba.load_userdict(kwargs["jieba_usr_dict"])
|
||||
self.jieba_usr_dict = jieba
|
||||
|
||||
|
||||
|
||||
@ -237,14 +244,8 @@ class CTTransformer(torch.nn.Module):
|
||||
# text = data_in[0]
|
||||
# text_lengths = data_lengths[0] if data_lengths is not None else None
|
||||
split_size = kwargs.get("split_size", 20)
|
||||
|
||||
jieba_usr_dict = kwargs.get("jieba_usr_dict", None)
|
||||
if jieba_usr_dict and isinstance(jieba_usr_dict, str):
|
||||
import jieba
|
||||
jieba.load_userdict(jieba_usr_dict)
|
||||
jieba_usr_dict = jieba
|
||||
kwargs["jieba_usr_dict"] = "jieba_usr_dict"
|
||||
tokens = split_words(text, jieba_usr_dict=jieba_usr_dict)
|
||||
|
||||
tokens = split_words(text, jieba_usr_dict=self.jieba_usr_dict)
|
||||
tokens_int = tokenizer.encode(tokens)
|
||||
|
||||
mini_sentences = split_to_mini_sentence(tokens, split_size)
|
||||
@ -347,7 +348,7 @@ class CTTransformer(torch.nn.Module):
|
||||
else:
|
||||
punc_array = torch.cat([punc_array, punctuations], dim=0)
|
||||
# post processing when using word level punc model
|
||||
if jieba_usr_dict:
|
||||
if self.jieba_usr_dict is not None:
|
||||
len_tokens = len(tokens)
|
||||
new_punc_array = copy.copy(punc_array).tolist()
|
||||
# for i, (token, punc_id) in enumerate(zip(tokens[::-1], punc_array.tolist()[::-1])):
|
||||
@ -364,57 +365,10 @@ class CTTransformer(torch.nn.Module):
|
||||
results.append(result_i)
|
||||
return results, meta_data
|
||||
|
||||
def export(
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
def export(self, **kwargs):
|
||||
|
||||
from .export_meta import export_rebuild_model
|
||||
models = export_rebuild_model(model=self, **kwargs)
|
||||
return models
|
||||
|
||||
is_onnx = kwargs.get("type", "onnx") == "onnx"
|
||||
encoder_class = tables.encoder_classes.get(kwargs["encoder"]+"Export")
|
||||
self.encoder = encoder_class(self.encoder, onnx=is_onnx)
|
||||
|
||||
self.forward = self.export_forward
|
||||
|
||||
return self
|
||||
|
||||
def export_forward(self, inputs: torch.Tensor, text_lengths: torch.Tensor):
|
||||
"""Compute loss value from buffer sequences.
|
||||
|
||||
Args:
|
||||
input (torch.Tensor): Input ids. (batch, len)
|
||||
hidden (torch.Tensor): Target ids. (batch, len)
|
||||
|
||||
"""
|
||||
x = self.embed(inputs)
|
||||
h, _ = self.encoder(x, text_lengths)
|
||||
y = self.decoder(h)
|
||||
return y
|
||||
|
||||
def export_dummy_inputs(self):
|
||||
length = 120
|
||||
text_indexes = torch.randint(0, self.embed.num_embeddings, (2, length)).type(torch.int32)
|
||||
text_lengths = torch.tensor([length-20, length], dtype=torch.int32)
|
||||
return (text_indexes, text_lengths)
|
||||
|
||||
def export_input_names(self):
|
||||
return ['inputs', 'text_lengths']
|
||||
|
||||
def export_output_names(self):
|
||||
return ['logits']
|
||||
|
||||
def export_dynamic_axes(self):
|
||||
return {
|
||||
'inputs': {
|
||||
0: 'batch_size',
|
||||
1: 'feats_length'
|
||||
},
|
||||
'text_lengths': {
|
||||
0: 'batch_size',
|
||||
},
|
||||
'logits': {
|
||||
0: 'batch_size',
|
||||
1: 'logits_length'
|
||||
},
|
||||
}
|
||||
def export_name(self):
|
||||
return "model.onnx"
|
||||
77
funasr/models/ct_transformer_streaming/export_meta.py
Normal file
77
funasr/models/ct_transformer_streaming/export_meta.py
Normal file
@ -0,0 +1,77 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- encoding: utf-8 -*-
|
||||
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
|
||||
# MIT License (https://opensource.org/licenses/MIT)
|
||||
|
||||
import types
|
||||
import torch
|
||||
from funasr.register import tables
|
||||
|
||||
|
||||
def export_rebuild_model(model, **kwargs):
    """Rewire the streaming punctuation *model* in place for export.

    Swaps the encoder for its registered "<Encoder>Export" wrapper and binds
    the module-level export_* helpers as instance methods.

    Args:
        model: the streaming CT-Transformer model to rebuild.
        **kwargs: must contain "encoder" (registry name); "type" selects the
            export backend (defaults to "onnx").

    Returns:
        The same model object, mutated for export.
    """
    onnx_mode = kwargs.get("type", "onnx") == "onnx"
    export_encoder_cls = tables.encoder_classes.get(kwargs["encoder"] + "Export")
    model.encoder = export_encoder_cls(model.encoder, onnx=onnx_mode)

    # Bind the free functions below as instance methods on the model.
    bindings = {
        "forward": export_forward,
        "export_dummy_inputs": export_dummy_inputs,
        "export_input_names": export_input_names,
        "export_output_names": export_output_names,
        "export_dynamic_axes": export_dynamic_axes,
        "export_name": export_name,
    }
    for attr_name, func in bindings.items():
        setattr(model, attr_name, types.MethodType(func, model))

    return model
|
||||
|
||||
def export_forward(self, inputs: torch.Tensor,
                   text_lengths: torch.Tensor,
                   vad_indexes: torch.Tensor,
                   sub_masks: torch.Tensor,
                   ):
    """Export-time forward pass for the streaming punctuation model.

    Args:
        inputs (torch.Tensor): input token ids, shape (batch, len).
        text_lengths (torch.Tensor): valid length per sample.
        vad_indexes (torch.Tensor): VAD attention mask passed to the encoder.
        sub_masks (torch.Tensor): subsequent (causal) attention mask.

    Returns:
        The decoder projection of the encoder output (punctuation logits).
    """
    embedded = self.embed(inputs)
    encoded, _ = self.encoder(embedded, text_lengths, vad_indexes, sub_masks)
    return self.decoder(encoded)
|
||||
|
||||
def export_dummy_inputs(self):
    """Build dummy inputs (ids, lengths, VAD mask, causal mask) for tracing."""
    seq_len = 120
    token_ids = torch.randint(0, self.embed.num_embeddings, (1, seq_len), dtype=torch.int32)
    lengths = torch.tensor([seq_len], dtype=torch.int32)
    # Masks carry the broadcastable (1, 1, len, len) attention-mask layout.
    vad_mask = torch.ones(seq_len, seq_len, dtype=torch.float32)[None, None, :, :]
    causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.float32))
    return token_ids, lengths, vad_mask, causal_mask[None, None, :, :]
|
||||
|
||||
def export_input_names(self):
    """ONNX graph input names, ordered to match export_dummy_inputs."""
    return ["inputs", "text_lengths", "vad_masks", "sub_masks"]
|
||||
|
||||
def export_output_names(self):
    """ONNX graph output names."""
    return ["logits"]
|
||||
|
||||
def export_dynamic_axes(self):
    """Dynamic-axis spec for torch.onnx.export: sequence-length dims vary."""
    return {
        "inputs": {1: "feats_length"},
        "vad_masks": {2: "feats_length1", 3: "feats_length2"},
        "sub_masks": {2: "feats_length1", 3: "feats_length2"},
        "logits": {1: "logits_length"},
    }
|
||||
def export_name(self):
    """File name used for the exported model artifact."""
    return "model.onnx"
|
||||
@ -173,68 +173,9 @@ class CTTransformerStreaming(CTTransformer):
|
||||
|
||||
return results, meta_data
|
||||
|
||||
def export(
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
def export(self, **kwargs):
|
||||
|
||||
is_onnx = kwargs.get("type", "onnx") == "onnx"
|
||||
encoder_class = tables.encoder_classes.get(kwargs["encoder"] + "Export")
|
||||
self.encoder = encoder_class(self.encoder, onnx=is_onnx)
|
||||
|
||||
self.forward = self.export_forward
|
||||
|
||||
return self
|
||||
from .export_meta import export_rebuild_model
|
||||
models = export_rebuild_model(model=self, **kwargs)
|
||||
return models
|
||||
|
||||
def export_forward(self, inputs: torch.Tensor,
|
||||
text_lengths: torch.Tensor,
|
||||
vad_indexes: torch.Tensor,
|
||||
sub_masks: torch.Tensor,
|
||||
):
|
||||
"""Compute loss value from buffer sequences.
|
||||
|
||||
Args:
|
||||
input (torch.Tensor): Input ids. (batch, len)
|
||||
hidden (torch.Tensor): Target ids. (batch, len)
|
||||
|
||||
"""
|
||||
x = self.embed(inputs)
|
||||
# mask = self._target_mask(input)
|
||||
h, _ = self.encoder(x, text_lengths, vad_indexes, sub_masks)
|
||||
y = self.decoder(h)
|
||||
return y
|
||||
|
||||
def export_dummy_inputs(self):
|
||||
length = 120
|
||||
text_indexes = torch.randint(0, self.embed.num_embeddings, (1, length)).type(torch.int32)
|
||||
text_lengths = torch.tensor([length], dtype=torch.int32)
|
||||
vad_mask = torch.ones(length, length, dtype=torch.float32)[None, None, :, :]
|
||||
sub_masks = torch.ones(length, length, dtype=torch.float32)
|
||||
sub_masks = torch.tril(sub_masks).type(torch.float32)
|
||||
return (text_indexes, text_lengths, vad_mask, sub_masks[None, None, :, :])
|
||||
|
||||
def export_input_names(self):
|
||||
return ['inputs', 'text_lengths', 'vad_masks', 'sub_masks']
|
||||
|
||||
def export_output_names(self):
|
||||
return ['logits']
|
||||
|
||||
def export_dynamic_axes(self):
|
||||
return {
|
||||
'inputs': {
|
||||
1: 'feats_length'
|
||||
},
|
||||
'vad_masks': {
|
||||
2: 'feats_length1',
|
||||
3: 'feats_length2'
|
||||
},
|
||||
'sub_masks': {
|
||||
2: 'feats_length1',
|
||||
3: 'feats_length2'
|
||||
},
|
||||
'logits': {
|
||||
1: 'logits_length'
|
||||
},
|
||||
}
|
||||
def export_name(self):
|
||||
return "model.onnx"
|
||||
|
||||
8
setup.py
8
setup.py
@ -14,16 +14,14 @@ requirements = {
|
||||
"librosa",
|
||||
"jamo", # For kss
|
||||
"PyYAML>=5.1.2",
|
||||
# "soundfile>=0.12.1",
|
||||
"soundfile>=0.12.1",
|
||||
"kaldiio>=2.17.0",
|
||||
"torch_complex",
|
||||
# "nltk>=3.4.5",
|
||||
# ASR
|
||||
"sentencepiece", # train
|
||||
"jieba",
|
||||
# "rotary_embedding_torch",
|
||||
"rotary_embedding_torch",
|
||||
# "ffmpeg-python",
|
||||
# TTS
|
||||
# "pypinyin>=0.44.0",
|
||||
# "espnet_tts_frontend",
|
||||
# ENH
|
||||
@ -54,6 +52,7 @@ requirements = {
|
||||
"torch_optimizer",
|
||||
"fairscale",
|
||||
"transformers",
|
||||
"openai-whisper"
|
||||
],
|
||||
"setup": [
|
||||
"numpy",
|
||||
@ -96,6 +95,7 @@ requirements = {
|
||||
],
|
||||
}
|
||||
requirements["all"].extend(requirements["train"])
|
||||
requirements["all"].extend(requirements["llm"])
|
||||
requirements["test"].extend(requirements["train"])
|
||||
|
||||
install_requires = requirements["install"]
|
||||
|
||||
Loading…
Reference in New Issue
Block a user