From 2ae59b6ce06305724e2eaf30b9f9e93447a7832e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=BB=B4=E7=9F=B3?= Date: Mon, 22 Jul 2024 16:58:27 +0800 Subject: [PATCH] ONNX and torchscript export for sensevoice --- .../sense_voice/export.py | 15 ++ funasr/models/sense_voice/export_meta.py | 58 +++---- funasr/utils/export_utils.py | 44 +++--- .../python/libtorch/demo_sensevoicesmall.py | 38 +++++ .../python/libtorch/funasr_torch/__init__.py | 1 + .../libtorch/funasr_torch/sensevoice_bin.py | 130 ++++++++++++++++ .../onnxruntime/demo_sencevoicesmall.py | 38 +++++ .../onnxruntime/funasr_onnx/__init__.py | 1 + .../onnxruntime/funasr_onnx/sensevoice_bin.py | 145 ++++++++++++++++++ 9 files changed, 412 insertions(+), 58 deletions(-) create mode 100644 examples/industrial_data_pretraining/sense_voice/export.py create mode 100644 runtime/python/libtorch/demo_sensevoicesmall.py create mode 100644 runtime/python/libtorch/funasr_torch/sensevoice_bin.py create mode 100644 runtime/python/onnxruntime/demo_sencevoicesmall.py create mode 100644 runtime/python/onnxruntime/funasr_onnx/sensevoice_bin.py diff --git a/examples/industrial_data_pretraining/sense_voice/export.py b/examples/industrial_data_pretraining/sense_voice/export.py new file mode 100644 index 000000000..7376c8a8c --- /dev/null +++ b/examples/industrial_data_pretraining/sense_voice/export.py @@ -0,0 +1,15 @@ +#!/usr/bin/env python3 +# -*- encoding: utf-8 -*- +# Copyright FunASR (https://github.com/FunAudioLLM/SenseVoice). All Rights Reserved. +# MIT License (https://opensource.org/licenses/MIT) + +from funasr import AutoModel + + +model_dir = "iic/SenseVoiceSmall" +model = AutoModel( + model=model_dir, + device="cuda:0", +) + +res = model.export(type="onnx", quantize=False) \ No newline at end of file diff --git a/funasr/models/sense_voice/export_meta.py b/funasr/models/sense_voice/export_meta.py index fe09ee149..449388ef7 100644 --- a/funasr/models/sense_voice/export_meta.py +++ b/funasr/models/sense_voice/export_meta.py @@ -5,31 +5,20 @@ import types import torch -import torch.nn as nn -from funasr.register import tables +from funasr.utils.torch_function import sequence_mask def export_rebuild_model(model, **kwargs): model.device = kwargs.get("device") - is_onnx = kwargs.get("type", "onnx") == "onnx" - # encoder_class = tables.encoder_classes.get(kwargs["encoder"] + "Export") - # model.encoder = encoder_class(model.encoder, onnx=is_onnx) - - from funasr.utils.torch_function import sequence_mask - model.make_pad_mask = sequence_mask(kwargs["max_seq_len"], flip=False) - model.forward = types.MethodType(export_forward, model) model.export_dummy_inputs = types.MethodType(export_dummy_inputs, model) model.export_input_names = types.MethodType(export_input_names, model) model.export_output_names = types.MethodType(export_output_names, model) model.export_dynamic_axes = types.MethodType(export_dynamic_axes, model) model.export_name = types.MethodType(export_name, model) - - model.export_name = "model" return model - def export_forward( self, speech: torch.Tensor, @@ -38,33 +27,29 @@ def export_forward( textnorm: torch.Tensor, **kwargs, ): - speech = speech.to(device=kwargs["device"]) - speech_lengths = speech_lengths.to(device=kwargs["device"]) - - language_query = self.embed(language).to(speech.device) - - textnorm_query = self.embed(textnorm).to(speech.device) + # speech = speech.to(device="cuda") + # speech_lengths = speech_lengths.to(device="cuda") + language_query = self.embed(language.to(speech.device)).unsqueeze(1) + textnorm_query = self.embed(textnorm.to(speech.device)).unsqueeze(1) + speech = torch.cat((textnorm_query, speech), dim=1) - speech_lengths += 1 - + event_emo_query = self.embed(torch.LongTensor([[1, 2]]).to(speech.device)).repeat( speech.size(0), 1, 1 ) input_query = torch.cat((language_query, event_emo_query), dim=1) speech = torch.cat((input_query, speech), dim=1) - speech_lengths += 3 - - # Encoder - encoder_out, encoder_out_lens = self.encoder(speech, speech_lengths) + + speech_lengths_new = speech_lengths + 4 + encoder_out, encoder_out_lens = self.encoder(speech, speech_lengths_new) + if isinstance(encoder_out, tuple): encoder_out = encoder_out[0] - # c. Passed the encoder result and the beam search - ctc_logits = self.ctc.log_softmax(encoder_out) - + ctc_logits = self.ctc.ctc_lo(encoder_out) + return ctc_logits, encoder_out_lens - def export_dummy_inputs(self): speech = torch.randn(2, 30, 560) speech_lengths = torch.tensor([6, 30], dtype=torch.int32) @@ -72,26 +57,21 @@ def export_dummy_inputs(self): textnorm = torch.tensor([15, 15], dtype=torch.int32) return (speech, speech_lengths, language, textnorm) - def export_input_names(self): return ["speech", "speech_lengths", "language", "textnorm"] - def export_output_names(self): return ["ctc_logits", "encoder_out_lens"] - def export_dynamic_axes(self): return { "speech": {0: "batch_size", 1: "feats_length"}, - "speech_lengths": { - 0: "batch_size", - }, - "logits": {0: "batch_size", 1: "logits_length"}, + "speech_lengths": {0: "batch_size"}, + "language": {0: "batch_size"}, + "textnorm": {0: "batch_size"}, + "ctc_logits": {0: "batch_size", 1: "logits_length"}, + "encoder_out_lens": {0: "batch_size"}, } - -def export_name( - self, -): +def export_name(self): return "model.onnx" diff --git a/funasr/utils/export_utils.py b/funasr/utils/export_utils.py index a6d079871..af9f37b96 100644 --- a/funasr/utils/export_utils.py +++ b/funasr/utils/export_utils.py @@ -54,7 +54,10 @@ def _onnx( verbose = kwargs.get("verbose", False) - export_name = model.export_name + ".onnx" + if isinstance(model.export_name, str): + export_name = model.export_name + ".onnx" + else: + export_name = model.export_name() model_path = os.path.join(export_dir, export_name) torch.onnx.export( model, @@ -72,35 +75,38 @@ def _onnx( import onnx quant_model_path = model_path.replace(".onnx", "_quant.onnx") - if not os.path.exists(quant_model_path): - onnx_model = onnx.load(model_path) - nodes = [n.name for n in onnx_model.graph.node] - nodes_to_exclude = [ - m for m in nodes if "output" in m or "bias_encoder" in m or "bias_decoder" in m - ] - quantize_dynamic( - model_input=model_path, - model_output=quant_model_path, - op_types_to_quantize=["MatMul"], - per_channel=True, - reduce_range=False, - weight_type=QuantType.QUInt8, - nodes_to_exclude=nodes_to_exclude, - ) + onnx_model = onnx.load(model_path) + nodes = [n.name for n in onnx_model.graph.node] + nodes_to_exclude = [ + m for m in nodes if "output" in m or "bias_encoder" in m or "bias_decoder" in m + ] + print("Quantizing model from {} to {}".format(model_path, quant_model_path)) + quantize_dynamic( + model_input=model_path, + model_output=quant_model_path, + op_types_to_quantize=["MatMul"], + per_channel=True, + reduce_range=False, + weight_type=QuantType.QUInt8, + nodes_to_exclude=nodes_to_exclude, + ) def _torchscripts(model, path, device="cuda"): dummy_input = model.export_dummy_inputs() - + if device == "cuda": model = model.cuda() if isinstance(dummy_input, torch.Tensor): dummy_input = dummy_input.cuda() else: dummy_input = tuple([i.cuda() for i in dummy_input]) - + model_script = torch.jit.trace(model, dummy_input) - model_script.save(os.path.join(path, f"{model.export_name}.torchscript")) + if isinstance(model.export_name, str): + model_script.save(os.path.join(path, f"{model.export_name}".replace("onnx", "torchscript"))) + else: + model_script.save(os.path.join(path, f"{model.export_name()}".replace("onnx", "torchscript"))) def _bladedisc_opt(model, model_inputs, enable_fp16=True): diff --git a/runtime/python/libtorch/demo_sensevoicesmall.py b/runtime/python/libtorch/demo_sensevoicesmall.py new file mode 100644 index 000000000..5c70f346d --- /dev/null +++ b/runtime/python/libtorch/demo_sensevoicesmall.py @@ -0,0 +1,38 @@ +#!/usr/bin/env python3 +# -*- encoding: utf-8 -*- +# Copyright FunASR (https://github.com/FunAudioLLM/SenseVoice). All Rights Reserved. +# MIT License (https://opensource.org/licenses/MIT) + +import os +import torch +from pathlib import Path +from funasr import AutoModel +from funasr_torch import SenseVoiceSmallTorchScript as SenseVoiceSmall +from funasr.utils.postprocess_utils import rich_transcription_postprocess + + +model_dir = "iic/SenseVoiceSmall" +model = AutoModel( + model=model_dir, + device="cuda:0", +) + +# res = model.export(type="torchscript", quantize=False) + +# export model init +model_path = "{}/.cache/modelscope/hub/{}".format(Path.home(), model_dir) +model_bin = SenseVoiceSmall(model_path) + +# build tokenizer +try: + from funasr.tokenizer.sentencepiece_tokenizer import SentencepiecesTokenizer + tokenizer = SentencepiecesTokenizer(bpemodel=os.path.join(model_path, "chn_jpn_yue_eng_ko_spectok.bpe.model")) +except: + tokenizer = None + +# inference +wav_or_scp = "/Users/shixian/Downloads/asr_example_hotword.wav" +language_list = [0] +textnorm_list = [15] +res = model_bin(wav_or_scp, language_list, textnorm_list, tokenizer=tokenizer) +print([rich_transcription_postprocess(i) for i in res]) diff --git a/runtime/python/libtorch/funasr_torch/__init__.py b/runtime/python/libtorch/funasr_torch/__init__.py index 647f9fadc..4669cedbf 100644 --- a/runtime/python/libtorch/funasr_torch/__init__.py +++ b/runtime/python/libtorch/funasr_torch/__init__.py @@ -1,2 +1,3 @@ # -*- encoding: utf-8 -*- from .paraformer_bin import Paraformer +from .sensevoice_bin import SenseVoiceSmallTorchScript diff --git a/runtime/python/libtorch/funasr_torch/sensevoice_bin.py b/runtime/python/libtorch/funasr_torch/sensevoice_bin.py new file mode 100644 index 000000000..d2e3cde00 --- /dev/null +++ b/runtime/python/libtorch/funasr_torch/sensevoice_bin.py @@ -0,0 +1,130 @@ +#!/usr/bin/env python3 +# -*- encoding: utf-8 -*- +# Copyright FunASR (https://github.com/FunAudioLLM/SenseVoice). All Rights Reserved. +# MIT License (https://opensource.org/licenses/MIT) + + +import torch +import os.path +import librosa +import numpy as np +from pathlib import Path +from typing import List, Union, Tuple + +from .utils.utils import ( + CharTokenizer, + get_logger, + read_yaml, +) +from .utils.frontend import WavFrontend + +logging = get_logger() + + +class SenseVoiceSmallTorchScript: + """ + Author: Speech Lab of DAMO Academy, Alibaba Group + Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition + https://arxiv.org/abs/2206.08317 + """ + + def __init__( + self, + model_dir: Union[str, Path] = None, + batch_size: int = 1, + device_id: Union[str, int] = "-1", + plot_timestamp_to: str = "", + quantize: bool = False, + intra_op_num_threads: int = 4, + cache_dir: str = None, + **kwargs, + ): + if quantize: + model_file = os.path.join(model_dir, "model_quant.torchscript") + else: + model_file = os.path.join(model_dir, "model.torchscript") + + config_file = os.path.join(model_dir, "config.yaml") + cmvn_file = os.path.join(model_dir, "am.mvn") + config = read_yaml(config_file) + # token_list = os.path.join(model_dir, "tokens.json") + # with open(token_list, "r", encoding="utf-8") as f: + # token_list = json.load(f) + + # self.converter = TokenIDConverter(token_list) + self.tokenizer = CharTokenizer() + config["frontend_conf"]['cmvn_file'] = cmvn_file + self.frontend = WavFrontend(**config["frontend_conf"]) + self.ort_infer = torch.jit.load(model_file) + self.batch_size = batch_size + self.blank_id = 0 + + def __call__(self, + wav_content: Union[str, np.ndarray, List[str]], + language: List, + textnorm: List, + tokenizer=None, + **kwargs) -> List: + waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq) + waveform_nums = len(waveform_list) + asr_res = [] + for beg_idx in range(0, waveform_nums, self.batch_size): + end_idx = min(waveform_nums, beg_idx + self.batch_size) + feats, feats_len = self.extract_feat(waveform_list[beg_idx:end_idx]) + ctc_logits, encoder_out_lens = self.ort_infer(torch.Tensor(feats), + torch.Tensor(feats_len), + torch.tensor(language), + torch.tensor(textnorm) + ) + # support batch_size=1 only currently + x = ctc_logits[0, : encoder_out_lens[0].item(), :] + yseq = x.argmax(dim=-1) + yseq = torch.unique_consecutive(yseq, dim=-1) + + mask = yseq != self.blank_id + token_int = yseq[mask].tolist() + + if tokenizer is not None: + asr_res.append(tokenizer.tokens2text(token_int)) + else: + asr_res.append(token_int) + return asr_res + + def load_data(self, wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List: + def load_wav(path: str) -> np.ndarray: + waveform, _ = librosa.load(path, sr=fs) + return waveform + + if isinstance(wav_content, np.ndarray): + return [wav_content] + + if isinstance(wav_content, str): + return [load_wav(wav_content)] + + if isinstance(wav_content, list): + return [load_wav(path) for path in wav_content] + + raise TypeError(f"The type of {wav_content} is not in [str, np.ndarray, list]") + + def extract_feat(self, waveform_list: List[np.ndarray]) -> Tuple[np.ndarray, np.ndarray]: + feats, feats_len = [], [] + for waveform in waveform_list: + speech, _ = self.frontend.fbank(waveform) + feat, feat_len = self.frontend.lfr_cmvn(speech) + feats.append(feat) + feats_len.append(feat_len) + + feats = self.pad_feats(feats, np.max(feats_len)) + feats_len = np.array(feats_len).astype(np.int32) + return feats, feats_len + + @staticmethod + def pad_feats(feats: List[np.ndarray], max_feat_len: int) -> np.ndarray: + def pad_feat(feat: np.ndarray, cur_len: int) -> np.ndarray: + pad_width = ((0, max_feat_len - cur_len), (0, 0)) + return np.pad(feat, pad_width, "constant", constant_values=0) + + feat_res = [pad_feat(feat, feat.shape[0]) for feat in feats] + feats = np.array(feat_res).astype(np.float32) + return feats + diff --git a/runtime/python/onnxruntime/demo_sencevoicesmall.py b/runtime/python/onnxruntime/demo_sencevoicesmall.py new file mode 100644 index 000000000..27f01799e --- /dev/null +++ b/runtime/python/onnxruntime/demo_sencevoicesmall.py @@ -0,0 +1,38 @@ +#!/usr/bin/env python3 +# -*- encoding: utf-8 -*- +# Copyright FunASR (https://github.com/FunAudioLLM/SenseVoice). All Rights Reserved. +# MIT License (https://opensource.org/licenses/MIT) + +import os +import torch +from pathlib import Path +from funasr import AutoModel +from funasr_onnx import SenseVoiceSmallONNX as SenseVoiceSmall +from funasr.utils.postprocess_utils import rich_transcription_postprocess + + +model_dir = "iic/SenseVoiceSmall" +model = AutoModel( + model=model_dir, + device="cuda:0", +) + +res = model.export(type="onnx", quantize=False) + +# export model init +model_path = "{}/.cache/modelscope/hub/{}".format(Path.home(), model_dir) +model_bin = SenseVoiceSmall(model_path) + +# build tokenizer +try: + from funasr.tokenizer.sentencepiece_tokenizer import SentencepiecesTokenizer + tokenizer = SentencepiecesTokenizer(bpemodel=os.path.join(model_path, "chn_jpn_yue_eng_ko_spectok.bpe.model")) +except: + tokenizer = None + +# inference +wav_or_scp = "/Users/shixian/Downloads/asr_example_hotword.wav" +language_list = [0] +textnorm_list = [15] +res = model_bin(wav_or_scp, language_list, textnorm_list, tokenizer=tokenizer) +print([rich_transcription_postprocess(i) for i in res]) diff --git a/runtime/python/onnxruntime/funasr_onnx/__init__.py b/runtime/python/onnxruntime/funasr_onnx/__init__.py index d0d665152..42566299c 100644 --- a/runtime/python/onnxruntime/funasr_onnx/__init__.py +++ b/runtime/python/onnxruntime/funasr_onnx/__init__.py @@ -4,3 +4,4 @@ from .vad_bin import Fsmn_vad from .vad_bin import Fsmn_vad_online from .punc_bin import CT_Transformer from .punc_bin import CT_Transformer_VadRealtime +from .sensevoice_bin import SenseVoiceSmallONNX diff --git a/runtime/python/onnxruntime/funasr_onnx/sensevoice_bin.py b/runtime/python/onnxruntime/funasr_onnx/sensevoice_bin.py new file mode 100644 index 000000000..fcfcede35 --- /dev/null +++ b/runtime/python/onnxruntime/funasr_onnx/sensevoice_bin.py @@ -0,0 +1,145 @@ +#!/usr/bin/env python3 +# -*- encoding: utf-8 -*- +# Copyright FunASR (https://github.com/FunAudioLLM/SenseVoice). All Rights Reserved. +# MIT License (https://opensource.org/licenses/MIT) + + +import torch +import os.path +import librosa +import numpy as np +from pathlib import Path +from typing import List, Union, Tuple + +from .utils.utils import ( + CharTokenizer, + Hypothesis, + ONNXRuntimeError, + OrtInferSession, + TokenIDConverter, + get_logger, + read_yaml, +) +from .utils.frontend import WavFrontend + +logging = get_logger() + + +class SenseVoiceSmallONNX: + """ + Author: Speech Lab of DAMO Academy, Alibaba Group + Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition + https://arxiv.org/abs/2206.08317 + """ + + def __init__( + self, + model_dir: Union[str, Path] = None, + batch_size: int = 1, + device_id: Union[str, int] = "-1", + plot_timestamp_to: str = "", + quantize: bool = False, + intra_op_num_threads: int = 4, + cache_dir: str = None, + **kwargs, + ): + if quantize: + model_file = os.path.join(model_dir, "model_quant.onnx") + else: + model_file = os.path.join(model_dir, "model.onnx") + + config_file = os.path.join(model_dir, "config.yaml") + cmvn_file = os.path.join(model_dir, "am.mvn") + config = read_yaml(config_file) + # token_list = os.path.join(model_dir, "tokens.json") + # with open(token_list, "r", encoding="utf-8") as f: + # token_list = json.load(f) + + # self.converter = TokenIDConverter(token_list) + self.tokenizer = CharTokenizer() + config["frontend_conf"]['cmvn_file'] = cmvn_file + self.frontend = WavFrontend(**config["frontend_conf"]) + self.ort_infer = OrtInferSession( + model_file, device_id, intra_op_num_threads=intra_op_num_threads + ) + self.batch_size = batch_size + self.blank_id = 0 + + def __call__(self, + wav_content: Union[str, np.ndarray, List[str]], + language: List, + textnorm: List, + tokenizer=None, + **kwargs) -> List: + waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq) + waveform_nums = len(waveform_list) + asr_res = [] + for beg_idx in range(0, waveform_nums, self.batch_size): + end_idx = min(waveform_nums, beg_idx + self.batch_size) + feats, feats_len = self.extract_feat(waveform_list[beg_idx:end_idx]) + ctc_logits, encoder_out_lens = self.infer(feats, + feats_len, + np.array(language, dtype=np.int32), + np.array(textnorm, dtype=np.int32) + ) + # back to torch.Tensor + ctc_logits = torch.from_numpy(ctc_logits).float() + # support batch_size=1 only currently + x = ctc_logits[0, : encoder_out_lens[0].item(), :] + yseq = x.argmax(dim=-1) + yseq = torch.unique_consecutive(yseq, dim=-1) + + mask = yseq != self.blank_id + token_int = yseq[mask].tolist() + + if tokenizer is not None: + asr_res.append(tokenizer.tokens2text(token_int)) + else: + asr_res.append(token_int) + return asr_res + + def load_data(self, wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List: + def load_wav(path: str) -> np.ndarray: + waveform, _ = librosa.load(path, sr=fs) + return waveform + + if isinstance(wav_content, np.ndarray): + return [wav_content] + + if isinstance(wav_content, str): + return [load_wav(wav_content)] + + if isinstance(wav_content, list): + return [load_wav(path) for path in wav_content] + + raise TypeError(f"The type of {wav_content} is not in [str, np.ndarray, list]") + + def extract_feat(self, waveform_list: List[np.ndarray]) -> Tuple[np.ndarray, np.ndarray]: + feats, feats_len = [], [] + for waveform in waveform_list: + speech, _ = self.frontend.fbank(waveform) + feat, feat_len = self.frontend.lfr_cmvn(speech) + feats.append(feat) + feats_len.append(feat_len) + + feats = self.pad_feats(feats, np.max(feats_len)) + feats_len = np.array(feats_len).astype(np.int32) + return feats, feats_len + + @staticmethod + def pad_feats(feats: List[np.ndarray], max_feat_len: int) -> np.ndarray: + def pad_feat(feat: np.ndarray, cur_len: int) -> np.ndarray: + pad_width = ((0, max_feat_len - cur_len), (0, 0)) + return np.pad(feat, pad_width, "constant", constant_values=0) + + feat_res = [pad_feat(feat, feat.shape[0]) for feat in feats] + feats = np.array(feat_res).astype(np.float32) + return feats + + def infer(self, + feats: np.ndarray, + feats_len: np.ndarray, + language: np.ndarray, + textnorm: np.ndarray,) -> Tuple[np.ndarray, np.ndarray]: + outputs = self.ort_infer([feats, feats_len, language, textnorm]) + return outputs