# FunASR/funasr/export/export_model.py
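"""Export FunASR models to ONNX or TorchScript.

Handles Paraformer ASR, VAD, and punctuation models (downloaded from ModelScope
when a 'damo/...' model id is given), and can additionally write quantized
copies: onnxruntime dynamic quantization for ONNX, torch_quant post-training
quantization for TorchScript.
"""
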
import os
import torch
import random
import logging
import numpy as np
from pathlib import Path
from typing import Union, Dict, List
from funasr.export.models import get_model
from funasr.utils.types import str2bool, str2triple_str
# torch_version = float(".".join(torch.__version__.split(".")[:2]))
# assert torch_version > 1.9
class ModelExport:
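    """Build a model from a local directory or a ModelScope model id and export it via `export()`."""
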
    def __init__(
        self,
        cache_dir: Union[Path, str] = None,
        onnx: bool = True,
        device: str = "cpu",
        quant: bool = True,
        fallback_num: int = 0,
        audio_in: str = None,
        calib_num: int = 200,
        model_revision: str = None,
    ):
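        """
        Args:
            cache_dir: export output directory; also used as the ModelScope download cache.
            onnx: export ONNX if True, otherwise TorchScript.
            device: 'cpu' or 'cuda'.
            quant: additionally export a quantized model.
            fallback_num: number of layers to fall back to float during mixed-precision quantization.
            audio_in: wav file or wav.scp used as calibration data (dummy inputs are used if None).
            calib_num: maximum number of calibration utterances read from a wav.scp.
            model_revision: ModelScope model revision passed to snapshot_download.
        """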
        self.set_all_random_seed(0)
        self.cache_dir = cache_dir
        self.export_config = dict(
            feats_dim=560,
            onnx=False,
        )
        self.onnx = onnx
        self.device = device
        self.quant = quant
        self.fallback_num = fallback_num
        self.frontend = None
        self.audio_in = audio_in
        self.calib_num = calib_num
        self.model_revision = model_revision

    def _export(
        self,
        model,
        tag_name: str = None,
        verbose: bool = False,
    ):
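        """Wrap the trained model with its export-friendly counterpart(s) and write them to cache_dir."""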
        export_dir = self.cache_dir
        os.makedirs(export_dir, exist_ok=True)

        # export encoder1
        self.export_config["model_name"] = "model"
        model = get_model(
            model,
            self.export_config,
        )
        if isinstance(model, List):
            for m in model:
                m.eval()
                if self.onnx:
                    self._export_onnx(m, verbose, export_dir)
                else:
                    self._export_torchscripts(m, verbose, export_dir)
                print("output dir: {}".format(export_dir))
        else:
            model.eval()
            # self._export_onnx(model, verbose, export_dir)
            if self.onnx:
                self._export_onnx(model, verbose, export_dir)
            else:
                self._export_torchscripts(model, verbose, export_dir)
            print("output dir: {}".format(export_dir))

    def _torch_quantize(self, model):
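        """Quantize the model with torch_quant, calibrating on real features or dummy inputs."""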
        def _run_calibration_data(m):
            # calibrate on real audio features if provided, otherwise fall back to dummy inputs
            if self.audio_in is not None:
                feats, feats_len = self.load_feats(self.audio_in)
                for feat, feat_len in zip(feats, feats_len):
                    with torch.no_grad():
                        m(feat, feat_len)
            else:
                dummy_input = model.get_dummy_inputs()
                m(*dummy_input)

        from torch_quant.module import ModuleFilter
        from torch_quant.quantizer import Backend, Quantizer
        from funasr.export.models.modules.decoder_layer import DecoderLayerSANM
        from funasr.export.models.modules.encoder_layer import EncoderLayerSANM

        module_filter = ModuleFilter(include_classes=[EncoderLayerSANM, DecoderLayerSANM])
        module_filter.exclude_op_types = [torch.nn.Conv1d]
        quantizer = Quantizer(
            module_filter=module_filter,
            backend=Backend.FBGEMM,
        )
        model.eval()
        calib_model = quantizer.calib(model)
        _run_calibration_data(calib_model)
        if self.fallback_num > 0:
            # perform automatic mixed precision quantization
            amp_model = quantizer.amp(model)
            _run_calibration_data(amp_model)
            quantizer.fallback(amp_model, num=self.fallback_num)
            print('Fallback layers:')
            print('\n'.join(quantizer.module_filter.exclude_names))
        quant_model = quantizer.quantize(model)
        return quant_model

    def _export_torchscripts(self, model, verbose, path, enc_size=None):
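        """Trace the model with dummy inputs and save TorchScript files (plus a quantized copy if enabled)."""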
        if enc_size:
            dummy_input = model.get_dummy_inputs(enc_size)
        else:
            dummy_input = model.get_dummy_inputs()
        if self.device == 'cuda':
            model = model.cuda()
            dummy_input = tuple([i.cuda() for i in dummy_input])
        # model_script = torch.jit.script(model)
        model_script = torch.jit.trace(model, dummy_input)
        model_script.save(os.path.join(path, f'{model.model_name}.torchscripts'))

        if self.quant:
            quant_model = self._torch_quantize(model)
            model_script = torch.jit.trace(quant_model, dummy_input)
            model_script.save(os.path.join(path, f'{model.model_name}_quant.torchscripts'))

    def set_all_random_seed(self, seed: int):
        random.seed(seed)
        np.random.seed(seed)
        torch.random.manual_seed(seed)

    def parse_audio_in(self, audio_in):
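        """Return (wav paths, utterance names) from a single wav file or a Kaldi-style wav.scp."""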
        wav_list, name_list = [], []
        if audio_in.endswith(".scp"):
            with open(audio_in, 'r') as f:
                lines = f.readlines()[:self.calib_num]
            for line in lines:
                name, path = line.strip().split()
                name_list.append(name)
                wav_list.append(path)
        else:
            wav_list = [audio_in,]
            name_list = ["test",]
        return wav_list, name_list

    def load_feats(self, audio_in: str = None):
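        """Load calibration audio, resample to the frontend's sampling rate, and compute features."""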
        import torchaudio
        wav_list, name_list = self.parse_audio_in(audio_in)
        feats = []
        feats_len = []
        for line in wav_list:
            path = line.strip()
            waveform, sampling_rate = torchaudio.load(path)
            if sampling_rate != self.frontend.fs:
                waveform = torchaudio.transforms.Resample(orig_freq=sampling_rate,
                                                          new_freq=self.frontend.fs)(waveform)
            fbank, fbank_len = self.frontend(waveform, [waveform.size(1)])
            feats.append(fbank)
            feats_len.append(fbank_len)
        return feats, feats_len

    def export(self,
               tag_name: str = 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch',
               mode: str = None,
               ):
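        """Download the model if a ModelScope id is given, build it for the detected mode, and export it."""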
        model_dir = tag_name
        if model_dir.startswith('damo'):
            from modelscope.hub.snapshot_download import snapshot_download
            model_dir = snapshot_download(model_dir, cache_dir=self.cache_dir, revision=self.model_revision)
        self.cache_dir = model_dir
        if mode is None:
            import json
            json_file = os.path.join(model_dir, 'configuration.json')
            with open(json_file, 'r') as f:
                config_data = json.load(f)
            if config_data['task'] == "punctuation":
                mode = config_data['model']['punc_model_config']['mode']
            else:
                mode = config_data['model']['model_config']['mode']
        if mode.startswith('paraformer'):
            from funasr.tasks.asr import ASRTaskParaformer as ASRTask
            config = os.path.join(model_dir, 'config.yaml')
            model_file = os.path.join(model_dir, 'model.pb')
            cmvn_file = os.path.join(model_dir, 'am.mvn')
            model, asr_train_args = ASRTask.build_model_from_file(
                config, model_file, cmvn_file, 'cpu'
            )
            self.frontend = model.frontend
            self.export_config["feats_dim"] = 560
        elif mode.startswith('offline'):
            from funasr.tasks.vad import VADTask
            config = os.path.join(model_dir, 'vad.yaml')
            model_file = os.path.join(model_dir, 'vad.pb')
            cmvn_file = os.path.join(model_dir, 'vad.mvn')
            model, vad_infer_args = VADTask.build_model_from_file(
                config, model_file, cmvn_file=cmvn_file, device='cpu'
            )
            self.export_config["feats_dim"] = 400
            self.frontend = model.frontend
        elif mode.startswith('punc_VadRealtime'):
            # check the more specific prefix first, otherwise the 'punc' branch would always match
            from funasr.tasks.punctuation import PunctuationTask as PUNCTask
            punc_train_config = os.path.join(model_dir, 'config.yaml')
            punc_model_file = os.path.join(model_dir, 'punc.pb')
            model, punc_train_args = PUNCTask.build_model_from_file(
                punc_train_config, punc_model_file, 'cpu'
            )
        elif mode.startswith('punc'):
            from funasr.tasks.punctuation import PunctuationTask as PUNCTask
            punc_train_config = os.path.join(model_dir, 'config.yaml')
            punc_model_file = os.path.join(model_dir, 'punc.pb')
            model, punc_train_args = PUNCTask.build_model_from_file(
                punc_train_config, punc_model_file, 'cpu'
            )

        self._export(model, tag_name)

    def _export_onnx(self, model, verbose, path, enc_size=None):
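        """Export the model to ONNX (opset 14) and, if enabled, a dynamically quantized copy."""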
        if enc_size:
            dummy_input = model.get_dummy_inputs(enc_size)
        else:
            dummy_input = model.get_dummy_inputs()

        # model_script = torch.jit.script(model)
        model_script = model  # torch.jit.trace(model)
        model_path = os.path.join(path, f'{model.model_name}.onnx')
        # if not os.path.exists(model_path):
        torch.onnx.export(
            model_script,
            dummy_input,
            model_path,
            verbose=verbose,
            opset_version=14,
            input_names=model.get_input_names(),
            output_names=model.get_output_names(),
            dynamic_axes=model.get_dynamic_axes()
        )

        if self.quant:
            from onnxruntime.quantization import QuantType, quantize_dynamic
            import onnx
            quant_model_path = os.path.join(path, f'{model.model_name}_quant.onnx')
            if not os.path.exists(quant_model_path):
                onnx_model = onnx.load(model_path)
                nodes = [n.name for n in onnx_model.graph.node]
                nodes_to_exclude = [m for m in nodes if 'output' in m or 'bias_encoder' in m or 'bias_decoder' in m]
                quantize_dynamic(
                    model_input=model_path,
                    model_output=quant_model_path,
                    op_types_to_quantize=['MatMul'],
                    per_channel=True,
                    reduce_range=False,
                    weight_type=QuantType.QUInt8,
                    nodes_to_exclude=nodes_to_exclude,
                )


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    # parser.add_argument('--model-name', type=str, required=True)
    parser.add_argument('--model-name', type=str, action="append", required=True, default=[])
    parser.add_argument('--export-dir', type=str, required=True)
    parser.add_argument('--type', type=str, default='onnx', help='["onnx", "torch"]')
    parser.add_argument('--device', type=str, default='cpu', help='["cpu", "cuda"]')
    parser.add_argument('--quantize', type=str2bool, default=False, help='export quantized model')
    parser.add_argument('--fallback-num', type=int, default=0, help='amp fallback number')
    parser.add_argument('--audio_in', type=str, default=None, help='["wav", "wav.scp"]')
    parser.add_argument('--calib_num', type=int, default=200, help='calib max num')
    parser.add_argument('--model_revision', type=str, default=None, help='model_revision')
    args = parser.parse_args()

    export_model = ModelExport(
        cache_dir=args.export_dir,
        onnx=args.type == 'onnx',
        device=args.device,
        quant=args.quantize,
        fallback_num=args.fallback_num,
        audio_in=args.audio_in,
        calib_num=args.calib_num,
        model_revision=args.model_revision,
    )
    for model_name in args.model_name:
        print("export model: {}".format(model_name))
        export_model.export(model_name)
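
# Example invocation (model id and output path are illustrative; any exportable
# FunASR model id or local model directory can be passed via --model-name):
#   python funasr/export/export_model.py \
#       --model-name damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch \
#       --export-dir ./export --type onnx --quantize True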