Mirror of https://github.com/modelscope/FunASR, synced 2025-09-15 14:48:36 +08:00
197 lines · 6.2 KiB · Python
import os
import torch
import functools

def export(
    model, data_in=None, quantize: bool = False, opset_version: int = 14, type="onnx", **kwargs
):
    model_scripts = model.export(**kwargs)
    export_dir = kwargs.get("output_dir", os.path.dirname(kwargs.get("init_param")))
    os.makedirs(export_dir, exist_ok=True)

    if not isinstance(model_scripts, (list, tuple)):
        model_scripts = (model_scripts,)
    for m in model_scripts:
        m.eval()
        if type == "onnx":
            _onnx(
                m,
                data_in=data_in,
                quantize=quantize,
                opset_version=opset_version,
                export_dir=export_dir,
                **kwargs,
            )
        elif type == "torchscript":
            device = "cuda" if torch.cuda.is_available() else "cpu"
            print("Exporting torchscripts on device {}".format(device))
            _torchscripts(m, path=export_dir, device=device)
        elif type == "bladedisc":
            assert (
                torch.cuda.is_available()
            ), "Currently bladedisc optimization for FunASR only supports GPU"
            # bladedisc only optimizes encoder/decoder modules
            if hasattr(m, "encoder") and hasattr(m, "decoder"):
                _bladedisc_opt_for_encdec(m, path=export_dir, enable_fp16=True)
            else:
                _torchscripts(m, path=export_dir, device="cuda")
        print("output dir: {}".format(export_dir))

    return export_dir
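
# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original file): export() above is
# typically reached through FunASR's AutoModel wrapper, which forwards its
# kwargs here. The model name and output_dir below are illustrative.
#
#   from funasr import AutoModel
#
#   model = AutoModel(model="paraformer-zh")
#   export_dir = model.export(type="onnx", quantize=True, output_dir="./export")
#   print(export_dir)  # directory holding model.onnx / model_quant.onnx
# ---------------------------------------------------------------------------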

def _onnx(
    model,
    data_in=None,
    quantize: bool = False,
    opset_version: int = 14,
    export_dir: str = None,
    **kwargs,
):
    dummy_input = model.export_dummy_inputs()

    verbose = kwargs.get("verbose", False)

    export_name = model.export_name + ".onnx"
    model_path = os.path.join(export_dir, export_name)
    torch.onnx.export(
        model,
        dummy_input,
        model_path,
        verbose=verbose,
        opset_version=opset_version,
        input_names=model.export_input_names(),
        output_names=model.export_output_names(),
        dynamic_axes=model.export_dynamic_axes(),
    )

    if quantize:
        from onnxruntime.quantization import QuantType, quantize_dynamic
        import onnx

        quant_model_path = model_path.replace(".onnx", "_quant.onnx")
        if not os.path.exists(quant_model_path):
            onnx_model = onnx.load(model_path)
            nodes = [n.name for n in onnx_model.graph.node]
            nodes_to_exclude = [
                m for m in nodes if "output" in m or "bias_encoder" in m or "bias_decoder" in m
            ]
            quantize_dynamic(
                model_input=model_path,
                model_output=quant_model_path,
                op_types_to_quantize=["MatMul"],
                per_channel=True,
                reduce_range=False,
                weight_type=QuantType.QUInt8,
                nodes_to_exclude=nodes_to_exclude,
            )
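
# Hedged sketch (assumption, not from the original file): loading the exported
# ONNX graph with onnxruntime. The actual input names and feature shape come
# from the concrete model's export_input_names()/export_dynamic_axes(); the
# "speech"/"speech_lengths" names and the (1, 100, 560) shape below follow the
# usual Paraformer layout but are illustrative here.
#
#   import numpy as np
#   import onnxruntime as ort
#
#   sess = ort.InferenceSession("export/model_quant.onnx")
#   feats = np.random.randn(1, 100, 560).astype(np.float32)
#   feats_len = np.array([100], dtype=np.int32)
#   outputs = sess.run(None, {"speech": feats, "speech_lengths": feats_len})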

def _torchscripts(model, path, device="cuda"):
    dummy_input = model.export_dummy_inputs()

    if device == "cuda":
        model = model.cuda()
        if isinstance(dummy_input, torch.Tensor):
            dummy_input = dummy_input.cuda()
        else:
            dummy_input = tuple([i.cuda() for i in dummy_input])

    model_script = torch.jit.trace(model, dummy_input)
    model_script.save(os.path.join(path, f"{model.export_name}.torchscript"))
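
# Hedged sketch (assumption): reloading the traced module saved above. The
# file name follows model.export_name; "model.torchscript" is illustrative.
#
#   import torch
#
#   m = torch.jit.load("export/model.torchscript", map_location="cpu")
#   m.eval()
#   # feed it tensors shaped like model.export_dummy_inputs()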

def _bladedisc_opt(model, model_inputs, enable_fp16=True):
    model = model.eval()
    try:
        import torch_blade
    except Exception as e:
        print(
            "Warning: torch_blade is required for bladedisc export; "
            "please install it and try again: pip install -U torch_blade"
        )
        # without torch_blade the optimization below cannot proceed
        raise e
    torch_config = torch_blade.config.Config()
    torch_config.enable_fp16 = enable_fp16
    with torch.no_grad(), torch_config:
        opt_model = torch_blade.optimize(
            model,
            allow_tracing=True,
            model_inputs=model_inputs,
        )
    return opt_model

def _rescale_input_hook(m, x, scale):
    if len(x) > 1:
        return (x[0] / scale, *x[1:])
    else:
        return (x[0] / scale,)


def _rescale_output_hook(m, x, y, scale):
    if isinstance(y, tuple):
        return (y[0] / scale, *y[1:])
    else:
        return y / scale

def _rescale_encoder_model(model, input_data):
    # Calculate absmax
    absmax = torch.tensor(0).cuda()

    def stat_input_hook(m, x, y):
        val = x[0] if isinstance(x, tuple) else x
        absmax.copy_(torch.max(absmax, val.detach().abs().max()))

    encoders = model.encoder.model.encoders
    hooks = [m.register_forward_hook(stat_input_hook) for m in encoders]
    model = model.cuda()
    model(*input_data)
    for h in hooks:
        h.remove()

    # Rescale encoder modules
    fp16_scale = int(2 * absmax // 65536)
    print(f"rescale encoder modules with factor={fp16_scale}")
    model.encoder.model.encoders0.register_forward_pre_hook(
        functools.partial(_rescale_input_hook, scale=fp16_scale),
    )
    for name, m in model.encoder.model.named_modules():
        if name.endswith("self_attn"):
            m.register_forward_hook(functools.partial(_rescale_output_hook, scale=fp16_scale))
        if name.endswith("feed_forward.w_2"):
            state_dict = {k: v / fp16_scale for k, v in m.state_dict().items()}
            m.load_state_dict(state_dict)
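
# Illustrative check of the rescale arithmetic above (the absmax value is made
# up): fp16 overflows past 65504, so dividing by int(2 * absmax // 65536)
# keeps the observed activation peak near half the representable range.
#
#   absmax = 2.1e6
#   fp16_scale = int(2 * absmax // 65536)  # -> 64
#   assert absmax / fp16_scale < 65504     # rescaled peak stays finite in fp16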

def _bladedisc_opt_for_encdec(model, path, enable_fp16):
    # Get input data
    # TODO: better to use real data
    input_data = model.export_dummy_inputs()
    if isinstance(input_data, torch.Tensor):
        input_data = input_data.cuda()
    else:
        input_data = tuple([i.cuda() for i in input_data])

    # Get input data for decoder module
    decoder_inputs = list()

    def get_input_hook(m, x):
        decoder_inputs.extend(list(x))

    hook = model.decoder.register_forward_pre_hook(get_input_hook)
    model = model.cuda()
    model(*input_data)
    hook.remove()

    # Prevent FP16 overflow
    if enable_fp16:
        _rescale_encoder_model(model, input_data)

    # Export and optimize encoder/decoder modules
    model.encoder = _bladedisc_opt(model.encoder, input_data[:2])
    model.decoder = _bladedisc_opt(model.decoder, tuple(decoder_inputs))
    model_script = torch.jit.trace(model, input_data)
    model_script.save(os.path.join(path, f"{model.export_name}_blade.torchscript"))
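
# Hedged end-to-end sketch (assumption): the bladedisc path is selected via
# export(..., type="bladedisc") and needs a CUDA device plus torch_blade.
# The model name and output dir are illustrative.
#
#   from funasr import AutoModel
#
#   model = AutoModel(model="paraformer-zh")
#   # traces the model, rescales the encoder for fp16, optimizes encoder and
#   # decoder with torch_blade, and saves {export_name}_blade.torchscript
#   model.export(type="bladedisc", output_dir="./export")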