Dev gzf llm (#1506)

* update

* update onnx

* update with main (#1492)

* contextual&seaco ONNX export (#1481)

* contextual&seaco ONNX export

* update ContextualEmbedderExport2

* update code

* onnx (#1482)

* qwenaudio qwenaudiochat

* whisper

* llm

* export onnx

* dingding

* llm

* doc

* onnx

* v1.0.15

* qwenaudio

* issue doc

* update

* bugfix

* onnx

* update export calling

* update codes

* remove useless code

* update code

---------

Co-authored-by: zhifu gao <zhifu.gzf@alibaba-inc.com>

* acknowledge

---------

Co-authored-by: Shi Xian <40013335+R1ckShi@users.noreply.github.com>

* update onnx

* train update

* punc update

---------

Co-authored-by: Shi Xian <40013335+R1ckShi@users.noreply.github.com>
zhifu gao 2024-03-15 21:14:08 +08:00 committed by GitHub
parent 5023dd0422
commit 675b4605e8
31 changed files with 214 additions and 174 deletions

View File

@@ -112,7 +112,7 @@ Notes: Support recognition of single audio file, as well as file list in Kaldi-s
 from funasr import AutoModel
 # paraformer-zh is a multi-functional asr model
 # use vad, punc, spk or not as you need
-model = AutoModel(model="paraformer-zh", vad_model="fsmn-vad", punc_model="ct-punc-c",
+model = AutoModel(model="paraformer-zh", vad_model="fsmn-vad", punc_model="ct-punc",
                   # spk_model="cam++",
                   )
 res = model.generate(input=f"{model.model_path}/example/asr_example.wav",

View File

@@ -106,7 +106,7 @@ funasr ++model=paraformer-zh ++vad_model="fsmn-vad" ++punc_model="ct-punc" ++inp
 from funasr import AutoModel
 # paraformer-zh is a multi-functional asr model
 # use vad, punc, spk or not as you need
-model = AutoModel(model="paraformer-zh", vad_model="fsmn-vad", punc_model="ct-punc-c",
+model = AutoModel(model="paraformer-zh", vad_model="fsmn-vad", punc_model="ct-punc",
                   # spk_model="cam++"
                   )
 res = model.generate(input=f"{model.model_path}/example/asr_example.wav",

View File

@@ -5,13 +5,13 @@
 from funasr import AutoModel
-model = AutoModel(model="damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
+model = AutoModel(model="iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
                   model_revision="v2.0.4",
-                  vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
+                  vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
                   vad_model_revision="v2.0.4",
-                  punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
+                  punc_model="iic/punc_ct-transformer_cn-en-common-vocab471067-large",
                   punc_model_revision="v2.0.4",
-                  # spk_model="damo/speech_campplus_sv_zh-cn_16k-common",
+                  # spk_model="iic/speech_campplus_sv_zh-cn_16k-common",
                   # spk_model_revision="v2.0.2",
                   )

View File

@@ -1,11 +1,12 @@
-model="damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
+model="iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
 model_revision="v2.0.4"
-vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"
+vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch"
 vad_model_revision="v2.0.4"
-punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
-punc_model_revision="v2.0.3"
-spk_model="damo/speech_campplus_sv_zh-cn_16k-common"
+#punc_model="iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
+punc_model="iic/punc_ct-transformer_cn-en-common-vocab471067-large"
+punc_model_revision="v2.0.4"
+spk_model="iic/speech_campplus_sv_zh-cn_16k-common"
 spk_model_revision="v2.0.2"
 python funasr/bin/inference.py \

View File

@@ -7,7 +7,7 @@
 from funasr import AutoModel
-model = AutoModel(model="damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
+model = AutoModel(model="iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
                   model_revision="v2.0.4", device="cpu")
 res = model.export(type="onnx", quantize=False)
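
As a rough usage sketch (not part of this diff): once export has run, the resulting file can be loaded with onnxruntime. This assumes the exported model is written as "model.onnx" under model.model_path, as the export_name() helper introduced below suggests.

import onnxruntime as ort  # hypothetical follow-up, not in this commit

sess = ort.InferenceSession(f"{model.model_path}/model.onnx")
print([(i.name, i.shape) for i in sess.get_inputs()])    # exported input signature
print([(o.name, o.shape) for o in sess.get_outputs()])   # exported output signature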

View File

@@ -5,7 +5,7 @@
 from funasr import AutoModel
-model = AutoModel(model="damo/speech_campplus_sv_zh-cn_16k-common",
+model = AutoModel(model="iic/speech_campplus_sv_zh-cn_16k-common",
                   model_revision="v2.0.2",
                   )

View File

@@ -5,7 +5,7 @@
 from funasr import AutoModel
-model = AutoModel(model="damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404", model_revision="v2.0.4")
+model = AutoModel(model="iic/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404", model_revision="v2.0.4")
 res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
                      hotword='达摩院 魔搭')

View File

@@ -1,5 +1,5 @@
-model="damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404"
+model="iic/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404"
 model_revision="v2.0.4"
 python ../../../funasr/bin/inference.py \

View File

@@ -5,7 +5,7 @@
 from funasr import AutoModel
-model = AutoModel(model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch", model_revision="v2.0.4")
+model = AutoModel(model="iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch", model_revision="v2.0.4")
 res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_text/punc_example.txt")
 print(res)
@@ -13,7 +13,7 @@ print(res)
 from funasr import AutoModel
-model = AutoModel(model="damo/punc_ct-transformer_cn-en-common-vocab471067-large", model_revision="v2.0.4")
+model = AutoModel(model="iic/punc_ct-transformer_cn-en-common-vocab471067-large", model_revision="v2.0.4")
 res = model.generate(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_text/punc_example.txt")
 print(res)

View File

@@ -1,8 +1,8 @@
-#model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
+#model="iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
 #model_revision="v2.0.4"
-model="damo/punc_ct-transformer_cn-en-common-vocab471067-large"
+model="iic/punc_ct-transformer_cn-en-common-vocab471067-large"
 model_revision="v2.0.4"
 python funasr/bin/inference.py \

View File

@@ -1,5 +1,5 @@
-model="damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727"
+model="iic/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727"
 model_revision="v2.0.4"
 python funasr/bin/inference.py \

View File

@@ -5,7 +5,7 @@
 from funasr import AutoModel
-# model="damo/emotion2vec_base"
+# model="iic/emotion2vec_base"
 model = AutoModel(model="iic/emotion2vec_base_finetuned", model_revision="v2.0.4")
 wav_file = f"{model.model_path}/example/test.wav"

View File

@@ -6,7 +6,7 @@
 from funasr import AutoModel
 wav_file = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav"
-model = AutoModel(model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch", model_revision="v2.0.4")
+model = AutoModel(model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch", model_revision="v2.0.4")
 res = model.generate(input=wav_file)
 print(res)

View File

@@ -1,6 +1,6 @@
-model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"
+model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch"
 model_revision="v2.0.4"
 python funasr/bin/inference.py \

View File

@@ -5,7 +5,7 @@
 from funasr import AutoModel
-model = AutoModel(model="damo/speech_timestamp_prediction-v1-16k-offline", model_revision="v2.0.4")
+model = AutoModel(model="iic/speech_timestamp_prediction-v1-16k-offline", model_revision="v2.0.4")
 res = model.generate(input=("https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
                             "欢迎大家来到魔搭社区进行体验"),

View File

@@ -1,5 +1,5 @@
-model="damo/speech_timestamp_prediction-v1-16k-offline"
+model="iic/speech_timestamp_prediction-v1-16k-offline"
 model_revision="v2.0.4"
 python funasr/bin/inference.py \

View File

@@ -5,13 +5,13 @@
 from funasr import AutoModel
-model = AutoModel(model="damo/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
+model = AutoModel(model="iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
                   model_revision="v2.0.4",
-                  vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
+                  vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
                   vad_model_revision="v2.0.4",
-                  punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
+                  punc_model="iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
                   punc_model_revision="v2.0.4",
-                  spk_model="damo/speech_campplus_sv_zh-cn_16k-common",
+                  spk_model="iic/speech_campplus_sv_zh-cn_16k-common",
                   spk_model_revision="v2.0.2"
                   )

View File

@@ -1,11 +1,11 @@
-model="damo/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
+model="iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
 model_revision="v2.0.4"
-vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"
+vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch"
 vad_model_revision="v2.0.4"
-punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
+punc_model="iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
 punc_model_revision="v2.0.4"
-spk_model="damo/speech_campplus_sv_zh-cn_16k-common"
+spk_model="iic/speech_campplus_sv_zh-cn_16k-common"
 spk_model_revision="v2.0.2"
 python funasr/bin/inference.py \

View File

@@ -22,7 +22,7 @@ print(res)
 ''' can not use currently
 from funasr import AutoFrontend
-frontend = AutoFrontend(model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", model_revision="v2.0.4")
+frontend = AutoFrontend(model="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", model_revision="v2.0.4")
 fbanks = frontend(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav", batch_size=2)

View File

@@ -32,7 +32,7 @@ torchrun \
 --nnodes 1 \
 --nproc_per_node ${gpu_num} \
 funasr/bin/train.py \
-++model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" \
+++model="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" \
 ++model_revision="v2.0.4" \
 ++train_data_set_list="${train_data}" \
 ++valid_data_set_list="${val_data}" \

View File

@@ -8,7 +8,7 @@ input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr
 output_dir="./outputs/debug"
-model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
+model="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
 model_revision="v2.0.4"
 device="cuda:0" # "cuda:0" for gpu0, "cuda:1" for gpu1, "cpu"

View File

@@ -1,5 +1,5 @@
-model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online"
+model="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online"
 model_revision="v2.0.4"
 python funasr/bin/inference.py \

View File

@@ -7,11 +7,11 @@ from funasr import AutoModel
 model = AutoModel(model="iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
                   model_revision="v2.0.4",
-                  # vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
+                  # vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch",
                   # vad_model_revision="v2.0.4",
-                  # punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
+                  # punc_model="iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
                   # punc_model_revision="v2.0.4",
-                  # spk_model="damo/speech_campplus_sv_zh-cn_16k-common",
+                  # spk_model="iic/speech_campplus_sv_zh-cn_16k-common",
                   # spk_model_revision="v2.0.2",
                   )

View File

@@ -1,9 +1,9 @@
-model="damo/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
+model="iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
 model_revision="v2.0.4"
-vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"
+vad_model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch"
 vad_model_revision="v2.0.4"
-punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
+punc_model="iic/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
 punc_model_revision="v2.0.4"
 python funasr/bin/inference.py \

View File

@@ -16,7 +16,7 @@ print(res)
 ''' can not use currently
 from funasr import AutoFrontend
-frontend = AutoFrontend(model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", model_revision="v2.0.4")
+frontend = AutoFrontend(model="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", model_revision="v2.0.4")
 fbanks = frontend(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav", batch_size=2)

View File

@@ -1,5 +1,5 @@
-model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
+model="iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
 model_revision="v2.0.4"
 python funasr/bin/inference.py \

View File

@@ -0,0 +1,67 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+import types
+
+import torch
+
+from funasr.register import tables
+
+
+def export_rebuild_model(model, **kwargs):
+    # Swap the encoder for its export-friendly wrapper and bind the
+    # export_* helpers below onto the model instance.
+    is_onnx = kwargs.get("type", "onnx") == "onnx"
+    encoder_class = tables.encoder_classes.get(kwargs["encoder"] + "Export")
+    model.encoder = encoder_class(model.encoder, onnx=is_onnx)
+
+    model.forward = types.MethodType(export_forward, model)
+    model.export_dummy_inputs = types.MethodType(export_dummy_inputs, model)
+    model.export_input_names = types.MethodType(export_input_names, model)
+    model.export_output_names = types.MethodType(export_output_names, model)
+    model.export_dynamic_axes = types.MethodType(export_dynamic_axes, model)
+    model.export_name = types.MethodType(export_name, model)
+    return model
+
+
+def export_forward(self, inputs: torch.Tensor, text_lengths: torch.Tensor):
+    """Export-time forward pass: predict punctuation logits.
+
+    Args:
+        inputs (torch.Tensor): Input token ids. (batch, len)
+        text_lengths (torch.Tensor): Valid lengths. (batch,)
+    """
+    x = self.embed(inputs)
+    h, _ = self.encoder(x, text_lengths)
+    y = self.decoder(h)
+    return y
+
+
+def export_dummy_inputs(self):
+    length = 120
+    text_indexes = torch.randint(0, self.embed.num_embeddings, (2, length)).type(torch.int32)
+    text_lengths = torch.tensor([length - 20, length], dtype=torch.int32)
+    return (text_indexes, text_lengths)
+
+
+def export_input_names(self):
+    return ['inputs', 'text_lengths']
+
+
+def export_output_names(self):
+    return ['logits']
+
+
+def export_dynamic_axes(self):
+    return {
+        'inputs': {
+            0: 'batch_size',
+            1: 'feats_length'
+        },
+        'text_lengths': {
+            0: 'batch_size',
+        },
+        'logits': {
+            0: 'batch_size',
+            1: 'logits_length'
+        },
+    }
+
+
+def export_name(self):
+    return "model.onnx"

View File

@@ -17,7 +17,10 @@ from funasr.train_utils.device_funcs import force_gatherable
 from funasr.utils.load_utils import load_audio_text_image_video
 from funasr.models.transformer.utils.nets_utils import make_pad_mask
 from funasr.models.ct_transformer.utils import split_to_mini_sentence, split_words
+try:
+    import jieba
+except:
+    pass
 
 if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
     from torch.cuda.amp import autocast
 else:
@@ -69,6 +72,10 @@ class CTTransformer(torch.nn.Module):
         self.sos = sos
         self.eos = eos
         self.sentence_end_id = sentence_end_id
+        self.jieba_usr_dict = None
+        if kwargs.get("jieba_usr_dict", None) is not None:
+            jieba.load_userdict(kwargs["jieba_usr_dict"])
+            self.jieba_usr_dict = jieba
@@ -237,14 +244,8 @@
         # text = data_in[0]
         # text_lengths = data_lengths[0] if data_lengths is not None else None
         split_size = kwargs.get("split_size", 20)
-        jieba_usr_dict = kwargs.get("jieba_usr_dict", None)
-        if jieba_usr_dict and isinstance(jieba_usr_dict, str):
-            import jieba
-            jieba.load_userdict(jieba_usr_dict)
-            jieba_usr_dict = jieba
-            kwargs["jieba_usr_dict"] = "jieba_usr_dict"
-        tokens = split_words(text, jieba_usr_dict=jieba_usr_dict)
+        tokens = split_words(text, jieba_usr_dict=self.jieba_usr_dict)
         tokens_int = tokenizer.encode(tokens)
 
         mini_sentences = split_to_mini_sentence(tokens, split_size)
@@ -347,7 +348,7 @@
         else:
             punc_array = torch.cat([punc_array, punctuations], dim=0)
         # post processing when using word level punc model
-        if jieba_usr_dict:
+        if self.jieba_usr_dict is not None:
             len_tokens = len(tokens)
             new_punc_array = copy.copy(punc_array).tolist()
             # for i, (token, punc_id) in enumerate(zip(tokens[::-1], punc_array.tolist()[::-1])):
@@ -364,57 +365,10 @@
             results.append(result_i)
         return results, meta_data
 
-    def export(
-        self,
-        **kwargs,
-    ):
-        is_onnx = kwargs.get("type", "onnx") == "onnx"
-        encoder_class = tables.encoder_classes.get(kwargs["encoder"]+"Export")
-        self.encoder = encoder_class(self.encoder, onnx=is_onnx)
-        self.forward = self.export_forward
-        return self
+    def export(self, **kwargs):
+        from .export_meta import export_rebuild_model
+        models = export_rebuild_model(model=self, **kwargs)
+        return models
 
-    def export_forward(self, inputs: torch.Tensor, text_lengths: torch.Tensor):
-        """Compute loss value from buffer sequences.
-
-        Args:
-            input (torch.Tensor): Input ids. (batch, len)
-            hidden (torch.Tensor): Target ids. (batch, len)
-        """
-        x = self.embed(inputs)
-        h, _ = self.encoder(x, text_lengths)
-        y = self.decoder(h)
-        return y
-
-    def export_dummy_inputs(self):
-        length = 120
-        text_indexes = torch.randint(0, self.embed.num_embeddings, (2, length)).type(torch.int32)
-        text_lengths = torch.tensor([length-20, length], dtype=torch.int32)
-        return (text_indexes, text_lengths)
-
-    def export_input_names(self):
-        return ['inputs', 'text_lengths']
-
-    def export_output_names(self):
-        return ['logits']
-
-    def export_dynamic_axes(self):
-        return {
-            'inputs': {
-                0: 'batch_size',
-                1: 'feats_length'
-            },
-            'text_lengths': {
-                0: 'batch_size',
-            },
-            'logits': {
-                0: 'batch_size',
-                1: 'logits_length'
-            },
-        }
-
-    def export_name(self):
-        return "model.onnx"

View File

@@ -0,0 +1,77 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+import types
+
+import torch
+
+from funasr.register import tables
+
+
+def export_rebuild_model(model, **kwargs):
+    is_onnx = kwargs.get("type", "onnx") == "onnx"
+    encoder_class = tables.encoder_classes.get(kwargs["encoder"] + "Export")
+    model.encoder = encoder_class(model.encoder, onnx=is_onnx)
+
+    model.forward = types.MethodType(export_forward, model)
+    model.export_dummy_inputs = types.MethodType(export_dummy_inputs, model)
+    model.export_input_names = types.MethodType(export_input_names, model)
+    model.export_output_names = types.MethodType(export_output_names, model)
+    model.export_dynamic_axes = types.MethodType(export_dynamic_axes, model)
+    model.export_name = types.MethodType(export_name, model)
+    return model
+
+
+def export_forward(self,
+                   inputs: torch.Tensor,
+                   text_lengths: torch.Tensor,
+                   vad_indexes: torch.Tensor,
+                   sub_masks: torch.Tensor,
+                   ):
+    """Export-time forward pass for the streaming (VAD real-time) model:
+    predict punctuation logits.
+
+    Args:
+        inputs (torch.Tensor): Input token ids. (batch, len)
+        text_lengths (torch.Tensor): Valid lengths. (batch,)
+        vad_indexes (torch.Tensor): VAD attention mask. (batch, 1, len, len)
+        sub_masks (torch.Tensor): Lower-triangular attention mask. (batch, 1, len, len)
+    """
+    x = self.embed(inputs)
+    # mask = self._target_mask(input)
+    h, _ = self.encoder(x, text_lengths, vad_indexes, sub_masks)
+    y = self.decoder(h)
+    return y
+
+
+def export_dummy_inputs(self):
+    length = 120
+    text_indexes = torch.randint(0, self.embed.num_embeddings, (1, length)).type(torch.int32)
+    text_lengths = torch.tensor([length], dtype=torch.int32)
+    vad_mask = torch.ones(length, length, dtype=torch.float32)[None, None, :, :]
+    sub_masks = torch.ones(length, length, dtype=torch.float32)
+    sub_masks = torch.tril(sub_masks).type(torch.float32)
+    return (text_indexes, text_lengths, vad_mask, sub_masks[None, None, :, :])
+
+
+def export_input_names(self):
+    return ['inputs', 'text_lengths', 'vad_masks', 'sub_masks']
+
+
+def export_output_names(self):
+    return ['logits']
+
+
+def export_dynamic_axes(self):
+    return {
+        'inputs': {
+            1: 'feats_length'
+        },
+        'vad_masks': {
+            2: 'feats_length1',
+            3: 'feats_length2'
+        },
+        'sub_masks': {
+            2: 'feats_length1',
+            3: 'feats_length2'
+        },
+        'logits': {
+            1: 'logits_length'
+        },
+    }
+
+
+def export_name(self):
+    return "model.onnx"

View File

@@ -173,68 +173,9 @@ class CTTransformerStreaming(CTTransformer):
         return results, meta_data
 
-    def export(
-        self,
-        **kwargs,
-    ):
-        is_onnx = kwargs.get("type", "onnx") == "onnx"
-        encoder_class = tables.encoder_classes.get(kwargs["encoder"] + "Export")
-        self.encoder = encoder_class(self.encoder, onnx=is_onnx)
-        self.forward = self.export_forward
-        return self
+    def export(self, **kwargs):
+        from .export_meta import export_rebuild_model
+        models = export_rebuild_model(model=self, **kwargs)
+        return models
 
-    def export_forward(self, inputs: torch.Tensor,
-                       text_lengths: torch.Tensor,
-                       vad_indexes: torch.Tensor,
-                       sub_masks: torch.Tensor,
-                       ):
-        """Compute loss value from buffer sequences.
-
-        Args:
-            input (torch.Tensor): Input ids. (batch, len)
-            hidden (torch.Tensor): Target ids. (batch, len)
-        """
-        x = self.embed(inputs)
-        # mask = self._target_mask(input)
-        h, _ = self.encoder(x, text_lengths, vad_indexes, sub_masks)
-        y = self.decoder(h)
-        return y
-
-    def export_dummy_inputs(self):
-        length = 120
-        text_indexes = torch.randint(0, self.embed.num_embeddings, (1, length)).type(torch.int32)
-        text_lengths = torch.tensor([length], dtype=torch.int32)
-        vad_mask = torch.ones(length, length, dtype=torch.float32)[None, None, :, :]
-        sub_masks = torch.ones(length, length, dtype=torch.float32)
-        sub_masks = torch.tril(sub_masks).type(torch.float32)
-        return (text_indexes, text_lengths, vad_mask, sub_masks[None, None, :, :])
-
-    def export_input_names(self):
-        return ['inputs', 'text_lengths', 'vad_masks', 'sub_masks']
-
-    def export_output_names(self):
-        return ['logits']
-
-    def export_dynamic_axes(self):
-        return {
-            'inputs': {
-                1: 'feats_length'
-            },
-            'vad_masks': {
-                2: 'feats_length1',
-                3: 'feats_length2'
-            },
-            'sub_masks': {
-                2: 'feats_length1',
-                3: 'feats_length2'
-            },
-            'logits': {
-                1: 'logits_length'
-            },
-        }
-
-    def export_name(self):
-        return "model.onnx"

View File

@@ -14,16 +14,14 @@ requirements = {
         "librosa",
         "jamo",  # For kss
         "PyYAML>=5.1.2",
-        # "soundfile>=0.12.1",
+        "soundfile>=0.12.1",
         "kaldiio>=2.17.0",
         "torch_complex",
         # "nltk>=3.4.5",
         # ASR
         "sentencepiece",  # train
+        "jieba",
-        # "rotary_embedding_torch",
+        "rotary_embedding_torch",
         # "ffmpeg-python",
-        # TTS
-        # "pypinyin>=0.44.0",
-        # "espnet_tts_frontend",
         # ENH
@@ -54,6 +52,7 @@
         "torch_optimizer",
         "fairscale",
         "transformers",
+        "openai-whisper"
     ],
     "setup": [
         "numpy",
@@ -96,6 +95,7 @@
     ],
 }
 requirements["all"].extend(requirements["train"])
+requirements["all"].extend(requirements["llm"])
 requirements["test"].extend(requirements["train"])
 install_requires = requirements["install"]
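
Assuming these dependency groups are exposed as pip extras via extras_require (the usual setup.py convention), the new llm group would be installed with pip install "funasr[llm]", and, per the extend above, "funasr[all]" now pulls it in as well.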