From 0c3c9be2c4c1c4e4da4628c3987708c9a0763391 Mon Sep 17 00:00:00 2001 From: will_wang <53147925+willnufe@users.noreply.github.com> Date: Wed, 4 Dec 2024 17:47:31 +0800 Subject: [PATCH] =?UTF-8?q?paraformer=20onnx=20fp16=E5=AF=BC=E5=87=BA?= =?UTF-8?q?=E6=96=B9=E6=A1=88=20(#2264)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * onnx fp16模型 * paraformer-offline [fp32 fp16 onnx-gpu] * paraformer-offline [fp32 fp16 onnx-gpu] * Update export.py --------- Co-authored-by: zhifu gao --- export.py | 9 +++ funasr/models/paraformer/cif_predictor.py | 12 ++-- funasr/models/paraformer/export_meta.py | 1 + funasr/utils/export_utils.py | 79 ++++++++++++++++++++++- 4 files changed, 93 insertions(+), 8 deletions(-) create mode 100644 export.py diff --git a/export.py b/export.py new file mode 100644 index 000000000..3307e96e1 --- /dev/null +++ b/export.py @@ -0,0 +1,9 @@ +# method2, inference from local path +from funasr import AutoModel + +model = AutoModel( + model="/raid/t3cv/wangch/WORK_SAPCE/ASR/models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" +) + +res = model.export(type="onnx", quantize=False, opset_version=13, device='cuda') # fp32 onnx-gpu +# res = model.export(type="onnx_fp16", quantize=False, opset_version=13, device='cuda') # fp16 onnx-gpu diff --git a/funasr/models/paraformer/cif_predictor.py b/funasr/models/paraformer/cif_predictor.py index 24145cd33..d59705032 100644 --- a/funasr/models/paraformer/cif_predictor.py +++ b/funasr/models/paraformer/cif_predictor.py @@ -245,7 +245,7 @@ class CifPredictorV2(torch.nn.Module): hidden, alphas, token_num, mask=None ) - acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold) + acoustic_embeds, cif_peak = cif_v1(hidden, alphas, self.threshold) if target_length is None and self.tail_threshold > 0.0: token_num_int = torch.max(token_num).type(torch.int32).item() acoustic_embeds = acoustic_embeds[:, :token_num_int, :] @@ -449,7 +449,7 @@ class CifPredictorV2Export(torch.nn.Module): mask = mask.transpose(-1, -2).float() mask = mask.squeeze(-1) hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, mask=mask) - acoustic_embeds, cif_peak = cif_export(hidden, alphas, self.threshold) + acoustic_embeds, cif_peak = cif_v1_export(hidden, alphas, self.threshold) return acoustic_embeds, token_num, alphas, cif_peak @@ -522,7 +522,7 @@ def cif_v1_export(hidden, alphas, threshold: float): fires = fires + prefix_sum - prefix_sum_floor # prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1) - prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1) + prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1).repeat((1, 1, hidden_size)) * hidden, dim=1) frames = prefix_sum_hidden[fire_idxs] shift_frames = torch.roll(frames, 1, dims=0) @@ -534,7 +534,7 @@ def cif_v1_export(hidden, alphas, threshold: float): remains = fires - torch.floor(fires) # remain_frames = remains[fire_idxs].unsqueeze(-1).tile((1, hidden_size)) * hidden[fire_idxs] - remain_frames = remains[fire_idxs].unsqueeze(-1).tile((1, hidden_size)) * hidden[fire_idxs] + remain_frames = remains[fire_idxs].unsqueeze(-1).repeat((1, hidden_size)) * hidden[fire_idxs] shift_remain_frames = torch.roll(remain_frames, 1, dims=0) shift_remain_frames[shift_batch_idxs] = 0 @@ -702,7 +702,7 @@ def cif_v1(hidden, alphas, threshold): # frames = torch.zeros(batch_size, len_time, hidden_size, dtype=dtype, device=device) # prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1) frames = torch.zeros(batch_size, len_time, hidden_size, dtype=dtype, device=device) - prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1).tile((1, 1, hidden_size)) * hidden, dim=1) + prefix_sum_hidden = torch.cumsum(alphas.unsqueeze(-1).repeat((1, 1, hidden_size)) * hidden, dim=1) frames = prefix_sum_hidden[fire_idxs] shift_frames = torch.roll(frames, 1, dims=0) @@ -715,7 +715,7 @@ def cif_v1(hidden, alphas, threshold): remains = fires - torch.floor(fires) # remain_frames = remains[fire_idxs].unsqueeze(-1).tile((1, hidden_size)) * hidden[fire_idxs] - remain_frames = remains[fire_idxs].unsqueeze(-1).tile((1, hidden_size)) * hidden[fire_idxs] + remain_frames = remains[fire_idxs].unsqueeze(-1).repeat((1, hidden_size)) * hidden[fire_idxs] shift_remain_frames = torch.roll(remain_frames, 1, dims=0) shift_remain_frames[shift_batch_idxs] = 0 diff --git a/funasr/models/paraformer/export_meta.py b/funasr/models/paraformer/export_meta.py index d3ace8fd6..f6a83c3b5 100644 --- a/funasr/models/paraformer/export_meta.py +++ b/funasr/models/paraformer/export_meta.py @@ -77,6 +77,7 @@ def export_dynamic_axes(self): 0: "batch_size", }, "logits": {0: "batch_size", 1: "logits_length"}, + "token_num": {0: "batch_size"} } diff --git a/funasr/utils/export_utils.py b/funasr/utils/export_utils.py index af9f37b96..667418ccb 100644 --- a/funasr/utils/export_utils.py +++ b/funasr/utils/export_utils.py @@ -1,6 +1,12 @@ import os import torch import functools +import onnx +from onnxconverter_common import float16 + +import warnings +warnings.filterwarnings("ignore") + def export( @@ -35,8 +41,17 @@ def export( if hasattr(m, "encoder") and hasattr(m, "decoder"): _bladedisc_opt_for_encdec(m, path=export_dir, enable_fp16=True) else: + print(f"export_dir: {export_dir}") _torchscripts(m, path=export_dir, device="cuda") - print("output dir: {}".format(export_dir)) + + + elif type=='onnx_fp16': + assert ( + torch.cuda.is_available() + ), "Currently onnx_fp16 optimization for FunASR only supports GPU" + + if hasattr(m, "encoder") and hasattr(m, "decoder"): + _onnx_opt_for_encdec(m, path=export_dir, enable_fp16=True) return export_dir @@ -51,6 +66,8 @@ def _onnx( ): dummy_input = model.export_dummy_inputs() + dummy_input = (dummy_input[0].to("cuda"), dummy_input[1].to("cuda")) + verbose = kwargs.get("verbose", False) @@ -64,6 +81,7 @@ def _onnx( dummy_input, model_path, verbose=verbose, + do_constant_folding=True, opset_version=opset_version, input_names=model.export_input_names(), output_names=model.export_output_names(), @@ -159,7 +177,7 @@ def _rescale_encoder_model(model, input_data): # Rescale encoder modules fp16_scale = int(2 * absmax // 65536) - print(f"rescale encoder modules with factor={fp16_scale}") + print(f"rescale encoder modules with factor={fp16_scale}\n\n") model.encoder.model.encoders0.register_forward_pre_hook( functools.partial(_rescale_input_hook, scale=fp16_scale), ) @@ -200,3 +218,60 @@ def _bladedisc_opt_for_encdec(model, path, enable_fp16): model.decoder = _bladedisc_opt(model.decoder, tuple(decoder_inputs)) model_script = torch.jit.trace(model, input_data) model_script.save(os.path.join(path, f"{model.export_name}_blade.torchscript")) + + + +def _onnx_opt_for_encdec(model, path, enable_fp16): + + # Get input data + # TODO: better to use real data + input_data = model.export_dummy_inputs() + + if isinstance(input_data, torch.Tensor): + input_data = input_data.cuda() + else: + input_data = tuple([i.cuda() for i in input_data]) + + # Get input data for decoder module + decoder_inputs = list() + + def get_input_hook(m, x): + decoder_inputs.extend(list(x)) + + hook = model.decoder.register_forward_pre_hook(get_input_hook) + model = model.cuda() + model(*input_data) + hook.remove() + + # Prevent FP16 overflow + if enable_fp16: + _rescale_encoder_model(model, input_data) + + fp32_model_path = f"{path}/{model.export_name}_hook.onnx" + print("*" * 50) + print(f"[_onnx_opt_for_encdec(fp32)]: {fp32_model_path}\n\n") + if not os.path.exists(fp32_model_path): + + torch.onnx.export( + model, + input_data, + fp32_model_path, + verbose=False, + do_constant_folding=True, + opset_version=13, + input_names=model.export_input_names(), + output_names=model.export_output_names(), + dynamic_axes=model.export_dynamic_axes(), + ) + + + # fp32 to fp16 + fp16_model_path = f"{path}/{model.export_name}_hook_fp16.onnx" + print("*" * 50) + print(f"[_onnx_opt_for_encdec(fp16)]: {fp16_model_path}\n\n") + if os.path.exists(fp32_model_path) and not os.path.exists(fp16_model_path): + fp32_onnx_model = onnx.load(fp32_model_path) + fp16_onnx_model = float16.convert_float_to_float16(fp32_onnx_model, keep_io_types=True) + onnx.save( + fp16_onnx_model, fp16_model_path + )