FunASR/funasr/utils/export_utils.py
zhifu gao 35b1c051f6
Dev gzf llm (#1493)
* update

* update

* update

* update onnx

* update with main (#1492)

* contextual&seaco ONNX export (#1481)

* contextual&seaco ONNX export

* update ContextualEmbedderExport2

* update ContextualEmbedderExport2

* update code

* onnx (#1482)

* qwenaudio qwenaudiochat

* qwenaudio qwenaudiochat

* whisper

* whisper

* llm

* llm

* llm

* llm

* llm

* llm

* llm

* llm

* export onnx

* export onnx

* export onnx

* dingding

* dingding

* llm

* doc

* onnx

* onnx

* onnx

* onnx

* onnx

* onnx

* v1.0.15

* qwenaudio

* qwenaudio

* issue doc

* update

* update

* bugfix

* onnx

* update export calling

* update codes

* remove useless code

* update code

---------

Co-authored-by: zhifu gao <zhifu.gzf@alibaba-inc.com>

* acknowledge

---------

Co-authored-by: Shi Xian <40013335+R1ckShi@users.noreply.github.com>

* update onnx

* update onnx

---------

Co-authored-by: Shi Xian <40013335+R1ckShi@users.noreply.github.com>
2024-03-14 09:33:30 +08:00

68 lines
1.9 KiB
Python

import os
import torch
def export_onnx(model, data_in=None, quantize: bool = False, opset_version: int = 14, **kwargs):
    """Export a model (or each of its exportable sub-modules) to ONNX.

    Calls ``model.export(**kwargs)`` to obtain one or more torch modules,
    then runs ``_onnx`` on each of them, writing the .onnx file(s) into the
    resolved output directory.

    Args:
        model: object exposing an ``export(**kwargs)`` method that returns a
            torch module or a list/tuple of torch modules.
        data_in: unused here; forwarded for interface compatibility.
        quantize: if True, also produce a dynamically-quantized ``*_quant.onnx``.
        opset_version: ONNX opset to target.
        **kwargs: must provide ``output_dir`` or ``init_param`` (the directory
            of ``init_param`` is used as a fallback output location).

    Returns:
        The directory the ONNX file(s) were written to.
    """
    model_scripts = model.export(**kwargs)
    # Resolve the output directory lazily: the original eager-default form
    # os.path.dirname(kwargs.get("init_param")) raised TypeError whenever
    # init_param was absent, even if output_dir was explicitly supplied.
    export_dir = kwargs.get("output_dir")
    if export_dir is None:
        export_dir = os.path.dirname(kwargs.get("init_param"))
    os.makedirs(export_dir, exist_ok=True)

    if not isinstance(model_scripts, (list, tuple)):
        model_scripts = (model_scripts,)
    for m in model_scripts:
        m.eval()  # inference mode: disables dropout/batch-norm updates before tracing
        _onnx(
            m,
            data_in=data_in,
            quantize=quantize,
            opset_version=opset_version,
            export_dir=export_dir,
            **kwargs,
        )
    print("output dir: {}".format(export_dir))

    return export_dir
def _onnx(model,
data_in=None,
quantize: bool = False,
opset_version: int = 14,
export_dir:str = None,
**kwargs):
dummy_input = model.export_dummy_inputs()
verbose = kwargs.get("verbose", False)
export_name = model.export_name() if hasattr(model, "export_name") else "model.onnx"
model_path = os.path.join(export_dir, export_name)
torch.onnx.export(
model,
dummy_input,
model_path,
verbose=verbose,
opset_version=opset_version,
input_names=model.export_input_names(),
output_names=model.export_output_names(),
dynamic_axes=model.export_dynamic_axes()
)
if quantize:
from onnxruntime.quantization import QuantType, quantize_dynamic
import onnx
quant_model_path = model_path.replace(".onnx", "_quant.onnx")
if not os.path.exists(quant_model_path):
onnx_model = onnx.load(model_path)
nodes = [n.name for n in onnx_model.graph.node]
nodes_to_exclude = [m for m in nodes if 'output' in m or 'bias_encoder' in m or 'bias_decoder' in m]
quantize_dynamic(
model_input=model_path,
model_output=quant_model_path,
op_types_to_quantize=['MatMul'],
per_channel=True,
reduce_range=False,
weight_type=QuantType.QUInt8,
nodes_to_exclude=nodes_to_exclude,
)