From ee9569ceef0c9707c8877d6b65733621dfbd3aeb Mon Sep 17 00:00:00 2001
From: "shixian.shi"
Date: Tue, 15 Aug 2023 17:31:27 +0800
Subject: [PATCH 1/4] Contextual Paraformer onnx export

---
 funasr/export/export_model.py                 |  52 ++---
 funasr/export/models/__init__.py              |  10 +-
 .../models/decoder/contextual_decoder.py      | 190 +++++++++++++++++
 .../models/e2e_asr_contextual_paraformer.py   | 173 ++++++++++++++++
 .../onnxruntime/demo_contextual_paraformer.py |  11 +
 .../onnxruntime/funasr_onnx/__init__.py       |   2 +-
 .../onnxruntime/funasr_onnx/paraformer_bin.py | 142 +++++++++++++
 .../onnxruntime/funasr_onnx/utils/utils.py    |  47 +++++
 8 files changed, 600 insertions(+), 27 deletions(-)
 create mode 100644 funasr/export/models/decoder/contextual_decoder.py
 create mode 100644 funasr/export/models/e2e_asr_contextual_paraformer.py
 create mode 100644 funasr/runtime/python/onnxruntime/demo_contextual_paraformer.py

diff --git a/funasr/export/export_model.py b/funasr/export/export_model.py
index 8c3108bf7..e0a9313ed 100644
--- a/funasr/export/export_model.py
+++ b/funasr/export/export_model.py
@@ -1,14 +1,11 @@
-import json
-from typing import Union, Dict
-from pathlib import Path
-
 import os
-import logging
 import torch
-
-from funasr.export.models import get_model
-import numpy as np
 import random
+import logging
+import numpy as np
+from pathlib import Path
+from typing import Union, Dict, List
+from funasr.export.models import get_model
 from funasr.utils.types import str2bool, str2triple_str
 
 # torch_version = float(".".join(torch.__version__.split(".")[:2]))
 # assert torch_version > 1.9
@@ -55,20 +52,25 @@ class ModelExport:
 
         # export encoder1
         self.export_config["model_name"] = "model"
-        models = get_model(
+        model = get_model(
             model,
             self.export_config,
         )
-        if not isinstance(models, tuple):
-            models = (models,)
-
-        for i, model in enumerate(models):
+        if isinstance(model, List):
+            for m in model:
+                m.eval()
+                if self.onnx:
+                    self._export_onnx(m, verbose, export_dir)
+                else:
+                    self._export_torchscripts(m, verbose, export_dir)
+            print("output dir: {}".format(export_dir))
+        else:
             model.eval()
+            # self._export_onnx(model, verbose, export_dir)
             if self.onnx:
                 self._export_onnx(model, verbose, export_dir)
             else:
                 self._export_torchscripts(model, verbose, export_dir)
-        print("output dir: {}".format(export_dir))
@@ -233,17 +235,17 @@ class ModelExport:
         # model_script = torch.jit.script(model)
         model_script = model  # torch.jit.trace(model)
         model_path = os.path.join(path, f'{model.model_name}.onnx')
-        if not os.path.exists(model_path):
-            torch.onnx.export(
-                model_script,
-                dummy_input,
-                model_path,
-                verbose=verbose,
-                opset_version=14,
-                input_names=model.get_input_names(),
-                output_names=model.get_output_names(),
-                dynamic_axes=model.get_dynamic_axes()
-            )
+        # if not os.path.exists(model_path):
+        torch.onnx.export(
+            model_script,
+            dummy_input,
+            model_path,
+            verbose=verbose,
+            opset_version=14,
+            input_names=model.get_input_names(),
+            output_names=model.get_output_names(),
+            dynamic_axes=model.get_dynamic_axes()
+        )
 
         if self.quant:
             from onnxruntime.quantization import QuantType, quantize_dynamic
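
The hunk above makes `_export_onnx` re-export unconditionally and drives `torch.onnx.export` through the per-model `get_input_names`/`get_output_names`/`get_dynamic_axes` hooks. A minimal, self-contained sketch of that contract (the toy module and names below are illustrative, not FunASR code):

    import torch

    class Toy(torch.nn.Module):
        def forward(self, x):
            return x.relu()

    # Mirrors the exporter's call: input/output names and dynamic axes come
    # from the wrapped model's get_*_names()/get_dynamic_axes() methods.
    torch.onnx.export(
        Toy().eval(), torch.randn(1, 10), "toy.onnx",
        opset_version=14,
        input_names=["x"], output_names=["y"],
        dynamic_axes={"x": {0: "batch"}, "y": {0: "batch"}},
    )
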
diff --git a/funasr/export/models/__init__.py b/funasr/export/models/__init__.py
index fd0a15c9c..cba92a865 100644
--- a/funasr/export/models/__init__.py
+++ b/funasr/export/models/__init__.py
@@ -12,9 +12,17 @@
 from funasr.models.vad_realtime_transformer import VadRealtimeTransformer
 from funasr.export.models.CT_Transformer import CT_Transformer_VadRealtime as CT_Transformer_VadRealtime_export
 from funasr.export.models.e2e_asr_paraformer import ParaformerOnline_encoder_predictor as ParaformerOnline_encoder_predictor_export
 from funasr.export.models.e2e_asr_paraformer import ParaformerOnline_decoder as ParaformerOnline_decoder_export
+from funasr.export.models.e2e_asr_contextual_paraformer import ContextualParaformer_backbone as ContextualParaformer_backbone_export
+from funasr.export.models.e2e_asr_contextual_paraformer import ContextualParaformer_embedder as ContextualParaformer_embedder_export
+from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
+
 
 def get_model(model, export_config=None):
-    if isinstance(model, BiCifParaformer):
+    if isinstance(model, NeatContextualParaformer):
+        backbone = ContextualParaformer_backbone_export(model, **export_config)
+        embedder = ContextualParaformer_embedder_export(model, **export_config)
+        return [embedder, backbone]
+    elif isinstance(model, BiCifParaformer):
         return BiCifParaformer_export(model, **export_config)
     elif isinstance(model, ParaformerOnline):
         return (ParaformerOnline_encoder_predictor_export(model, model_name="model"),
diff --git a/funasr/export/models/decoder/contextual_decoder.py b/funasr/export/models/decoder/contextual_decoder.py
new file mode 100644
index 000000000..4e11b5d6b
--- /dev/null
+++ b/funasr/export/models/decoder/contextual_decoder.py
@@ -0,0 +1,190 @@
+import os
+import torch
+import torch.nn as nn
+
+from funasr.export.utils.torch_function import MakePadMask
+from funasr.export.utils.torch_function import sequence_mask
+from funasr.modules.attention import MultiHeadedAttentionSANMDecoder
+from funasr.export.models.modules.multihead_att import MultiHeadedAttentionSANMDecoder as MultiHeadedAttentionSANMDecoder_export
+from funasr.modules.attention import MultiHeadedAttentionCrossAtt
+from funasr.export.models.modules.multihead_att import MultiHeadedAttentionCrossAtt as MultiHeadedAttentionCrossAtt_export
+from funasr.modules.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
+from funasr.export.models.modules.feedforward import PositionwiseFeedForwardDecoderSANM as PositionwiseFeedForwardDecoderSANM_export
+from funasr.export.models.modules.decoder_layer import DecoderLayerSANM as DecoderLayerSANM_export
+
+
+class ContextualSANMDecoder(nn.Module):
+    def __init__(self, model,
+                 max_seq_len=512,
+                 model_name='decoder',
+                 onnx: bool = True,):
+        super().__init__()
+        # self.embed = model.embed  # Embedding(model.embed, max_seq_len)
+        self.model = model
+        if onnx:
+            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
+        else:
+            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
+
+        for i, d in enumerate(self.model.decoders):
+            if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
+                d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
+            if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
+                d.self_attn = MultiHeadedAttentionSANMDecoder_export(d.self_attn)
+            if isinstance(d.src_attn, MultiHeadedAttentionCrossAtt):
+                d.src_attn = MultiHeadedAttentionCrossAtt_export(d.src_attn)
+            self.model.decoders[i] = DecoderLayerSANM_export(d)
+
+        if self.model.decoders2 is not None:
+            for i, d in enumerate(self.model.decoders2):
+                if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
+                    d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
+                if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
+                    d.self_attn = MultiHeadedAttentionSANMDecoder_export(d.self_attn)
+                self.model.decoders2[i] = DecoderLayerSANM_export(d)
+
+        for i, d in enumerate(self.model.decoders3):
+            if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
+                d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
+            self.model.decoders3[i] = DecoderLayerSANM_export(d)
+
+        self.output_layer = model.output_layer
+        self.after_norm = model.after_norm
+        self.model_name = model_name
+
+        # bias decoder
+        if isinstance(self.model.bias_decoder.src_attn, MultiHeadedAttentionCrossAtt):
+            self.model.bias_decoder.src_attn = MultiHeadedAttentionCrossAtt_export(self.model.bias_decoder.src_attn)
+        self.bias_decoder = self.model.bias_decoder
+        # last decoder
+        if isinstance(self.model.last_decoder.src_attn, MultiHeadedAttentionCrossAtt):
+            self.model.last_decoder.src_attn = MultiHeadedAttentionCrossAtt_export(self.model.last_decoder.src_attn)
+        if isinstance(self.model.last_decoder.self_attn, MultiHeadedAttentionSANMDecoder):
+            self.model.last_decoder.self_attn = MultiHeadedAttentionSANMDecoder_export(self.model.last_decoder.self_attn)
+        if isinstance(self.model.last_decoder.feed_forward, PositionwiseFeedForwardDecoderSANM):
+            self.model.last_decoder.feed_forward = PositionwiseFeedForwardDecoderSANM_export(self.model.last_decoder.feed_forward)
+        self.last_decoder = self.model.last_decoder
+        self.bias_output = self.model.bias_output
+        self.dropout = self.model.dropout
+
+    def prepare_mask(self, mask):
+        mask_3d_btd = mask[:, :, None]
+        if len(mask.shape) == 2:
+            mask_4d_bhlt = 1 - mask[:, None, None, :]
+        elif len(mask.shape) == 3:
+            mask_4d_bhlt = 1 - mask[:, None, :]
+        mask_4d_bhlt = mask_4d_bhlt * -10000.0
+
+        return mask_3d_btd, mask_4d_bhlt
+
+    def forward(
+        self,
+        hs_pad: torch.Tensor,
+        hlens: torch.Tensor,
+        ys_in_pad: torch.Tensor,
+        ys_in_lens: torch.Tensor,
+        bias_embed: torch.Tensor,
+    ):
+        tgt = ys_in_pad
+        tgt_mask = self.make_pad_mask(ys_in_lens)
+        tgt_mask, _ = self.prepare_mask(tgt_mask)
+        # tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
+
+        memory = hs_pad
+        memory_mask = self.make_pad_mask(hlens)
+        _, memory_mask = self.prepare_mask(memory_mask)
+        # memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
+
+        x = tgt
+        x, tgt_mask, memory, memory_mask, _ = self.model.decoders(
+            x, tgt_mask, memory, memory_mask
+        )
+
+        _, _, x_self_attn, x_src_attn = self.last_decoder(
+            x, tgt_mask, memory, memory_mask
+        )
+
+        # contextual paraformer related
+        contextual_length = torch.Tensor([bias_embed.shape[1]]).int().repeat(hs_pad.shape[0])
+        # contextual_mask = myutils.sequence_mask(contextual_length, device=memory.device)[:, None, :]
+        contextual_mask = self.make_pad_mask(contextual_length)
+        contextual_mask, _ = self.prepare_mask(contextual_mask)
+        contextual_mask = contextual_mask.transpose(2, 1).unsqueeze(1)
+        cx, tgt_mask, _, _, _ = self.bias_decoder(x_self_attn, tgt_mask, bias_embed, memory_mask=contextual_mask)
+
+        if self.bias_output is not None:
+            x = torch.cat([x_src_attn, cx], dim=2)
+            x = self.bias_output(x.transpose(1, 2)).transpose(1, 2)  # 2D -> D
+            x = x_self_attn + self.dropout(x)
+
+        if self.model.decoders2 is not None:
+            x, tgt_mask, memory, memory_mask, _ = self.model.decoders2(
+                x, tgt_mask, memory, memory_mask
+            )
+        x, tgt_mask, memory, memory_mask, _ = self.model.decoders3(
+            x, tgt_mask, memory, memory_mask
+        )
+        x = self.after_norm(x)
+        x = self.output_layer(x)
+
+        return x, ys_in_lens
+
+    def get_dummy_inputs(self, enc_size):
+        tgt = torch.LongTensor([0]).unsqueeze(0)
+        memory = torch.randn(1, 100, enc_size)
+        pre_acoustic_embeds = torch.randn(1, 1, enc_size)
+        cache_num = len(self.model.decoders) + len(self.model.decoders2)
+        cache = [
+            torch.zeros((1, self.model.decoders[0].size, self.model.decoders[0].self_attn.kernel_size))
+            for _ in range(cache_num)
+        ]
+        return (tgt, memory, pre_acoustic_embeds, cache)
+
+    def is_optimizable(self):
+        return True
+
+    def get_input_names(self):
+        cache_num = len(self.model.decoders) + len(self.model.decoders2)
+        return ['tgt', 'memory', 'pre_acoustic_embeds'] \
+            + ['cache_%d' % i for i in range(cache_num)]
+
+    def get_output_names(self):
+        cache_num = len(self.model.decoders) + len(self.model.decoders2)
+        return ['y'] \
+            + ['out_cache_%d' % i for i in range(cache_num)]
+
+    def get_dynamic_axes(self):
+        ret = {
+            'tgt': {
+                0: 'tgt_batch',
+                1: 'tgt_length'
+            },
+            'memory': {
+                0: 'memory_batch',
+                1: 'memory_length'
+            },
+            'pre_acoustic_embeds': {
+                0: 'acoustic_embeds_batch',
+                1: 'acoustic_embeds_length',
+            }
+        }
+        cache_num = len(self.model.decoders) + len(self.model.decoders2)
+        ret.update({
+            'cache_%d' % d: {
+                0: 'cache_%d_batch' % d,
+                2: 'cache_%d_length' % d
+            }
+            for d in range(cache_num)
+        })
+        return ret
+
+    def get_model_config(self, path):
+        return {
+            "dec_type": "XformerDecoder",
+            "model_path": os.path.join(path, f'{self.model_name}.onnx'),
+            "n_layers": len(self.model.decoders) + len(self.model.decoders2),
+            "odim": self.model.decoders[0].size
+        }
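
`prepare_mask` above turns a 0/1 pad mask into two forms: a (B, T, 1) multiplicative mask applied to features, and a (B, 1, 1, T) additive attention mask whose padded positions are -10000, so softmax attention effectively ignores them. An illustrative check with toy lengths (not part of the patch):

    import torch

    mask = torch.tensor([[1., 1., 1.],
                         [1., 1., 0.]])     # (B, T) pad mask, 1 = valid frame
    mask_3d_btd = mask[:, :, None]          # (B, T, 1), multiplied onto features
    mask_4d_bhlt = (1 - mask[:, None, None, :]) * -10000.0
    # (B, 1, 1, T): 0 on valid positions, -10000 on padding; added to the
    # attention logits, so softmax gives padded frames ~zero weight.
    # mask_4d_bhlt[1] -> [[[-0., -0., -10000.]]]
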
diff --git a/funasr/export/models/e2e_asr_contextual_paraformer.py b/funasr/export/models/e2e_asr_contextual_paraformer.py
new file mode 100644
index 000000000..61806c919
--- /dev/null
+++ b/funasr/export/models/e2e_asr_contextual_paraformer.py
@@ -0,0 +1,173 @@
+import logging
+import torch
+import torch.nn as nn
+import numpy as np
+
+from funasr.export.utils.torch_function import MakePadMask
+from funasr.export.utils.torch_function import sequence_mask
+from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
+from funasr.models.encoder.conformer_encoder import ConformerEncoder
+from funasr.export.models.encoder.sanm_encoder import SANMEncoder as SANMEncoder_export
+from funasr.export.models.encoder.conformer_encoder import ConformerEncoder as ConformerEncoder_export
+from funasr.models.predictor.cif import CifPredictorV2
+from funasr.export.models.predictor.cif import CifPredictorV2 as CifPredictorV2_export
+from funasr.models.decoder.sanm_decoder import ParaformerSANMDecoder
+from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
+from funasr.export.models.decoder.sanm_decoder import ParaformerSANMDecoder as ParaformerSANMDecoder_export
+from funasr.export.models.decoder.transformer_decoder import ParaformerDecoderSAN as ParaformerDecoderSAN_export
+from funasr.export.models.decoder.contextual_decoder import ContextualSANMDecoder as ContextualSANMDecoder_export
+from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder
+
+
+class ContextualParaformer_backbone(nn.Module):
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
+    https://arxiv.org/abs/2206.08317
+    """
+
+    def __init__(
+        self,
+        model,
+        max_seq_len=512,
+        feats_dim=560,
+        model_name='model',
+        **kwargs,
+    ):
+        super().__init__()
+        onnx = False
+        if "onnx" in kwargs:
+            onnx = kwargs["onnx"]
+        if isinstance(model.encoder, SANMEncoder):
+            self.encoder = SANMEncoder_export(model.encoder, onnx=onnx)
+        elif isinstance(model.encoder, ConformerEncoder):
+            self.encoder = ConformerEncoder_export(model.encoder, onnx=onnx)
+        if isinstance(model.predictor, CifPredictorV2):
+            self.predictor = CifPredictorV2_export(model.predictor)
+
+        # decoder
+        if isinstance(model.decoder, ContextualParaformerDecoder):
+            self.decoder = ContextualSANMDecoder_export(model.decoder, onnx=onnx)
+        elif isinstance(model.decoder, ParaformerSANMDecoder):
+            self.decoder = ParaformerSANMDecoder_export(model.decoder, onnx=onnx)
+        elif isinstance(model.decoder, ParaformerDecoderSAN):
+            self.decoder = ParaformerDecoderSAN_export(model.decoder, onnx=onnx)
+
+        self.feats_dim = feats_dim
+        self.model_name = model_name + '_bb'
+
+        if onnx:
+            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
+        else:
+            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
+
+    def forward(
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        bias_embed: torch.Tensor,
+    ):
+        # a. To device
+        batch = {"speech": speech, "speech_lengths": speech_lengths}
+        # batch = to_device(batch, device=self.device)
+
+        enc, enc_len = self.encoder(**batch)
+        mask = self.make_pad_mask(enc_len)[:, None, :]
+        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(enc, mask)
+        pre_token_length = pre_token_length.floor().type(torch.int32)
+
+        # bias_embed = bias_embed.squeeze(0).repeat([enc.shape[0], 1, 1])
+
+        decoder_out, _ = self.decoder(enc, enc_len, pre_acoustic_embeds, pre_token_length, bias_embed)
+        decoder_out = torch.log_softmax(decoder_out, dim=-1)
+        # sample_ids = decoder_out.argmax(dim=-1)
+        return decoder_out, pre_token_length
+
+    def get_dummy_inputs(self):
+        speech = torch.randn(2, 30, self.feats_dim)
+        speech_lengths = torch.tensor([6, 30], dtype=torch.int32)
+        bias_embed = torch.randn(2, 1, 512)
+        return (speech, speech_lengths, bias_embed)
+
+    def get_dummy_inputs_txt(self, txt_file: str = "/mnt/workspace/data_fbank/0207/12345.wav.fea.txt"):
+        import numpy as np
+        fbank = np.loadtxt(txt_file)
+        fbank_lengths = np.array([fbank.shape[0], ], dtype=np.int32)
+        speech = torch.from_numpy(fbank[None, :, :].astype(np.float32))
+        speech_lengths = torch.from_numpy(fbank_lengths.astype(np.int32))
+        return (speech, speech_lengths)
+
+    def get_input_names(self):
+        return ['speech', 'speech_lengths', 'bias_embed']
+
+    def get_output_names(self):
+        return ['logits', 'token_num']
+
+    def get_dynamic_axes(self):
+        return {
+            'speech': {
+                0: 'batch_size',
+                1: 'feats_length'
+            },
+            'speech_lengths': {
+                0: 'batch_size',
+            },
+            'bias_embed': {
+                0: 'batch_size',
+                1: 'num_hotwords'
+            },
+            'logits': {
+                0: 'batch_size',
+                1: 'logits_length'
+            },
+        }
+
+
+class ContextualParaformer_embedder(nn.Module):
+    def __init__(self,
+                 model,
+                 max_seq_len=512,
+                 feats_dim=560,
+                 model_name='model',
+                 **kwargs,):
+        super().__init__()
+        self.embedding = model.bias_embed
+        model.bias_encoder.batch_first = False
+        self.bias_encoder = model.bias_encoder
+        # self.bias_encoder.batch_first = False
+        self.feats_dim = feats_dim
+        self.model_name = "{}_eb".format(model_name)
+
+    def forward(self, hotword):
+        hotword = self.embedding(hotword).transpose(0, 1)  # batch second
+        hw_embed, (_, _) = self.bias_encoder(hotword)
+        return hw_embed
+
+    def get_dummy_inputs(self):
+        hotword = torch.tensor([
+            [10, 11, 12, 13, 14, 10, 11, 12, 13, 14],
+            [100, 101, 0, 0, 0, 0, 0, 0, 0, 0],
+            [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+            [10, 11, 12, 13, 14, 10, 11, 12, 13, 14],
+            [100, 101, 0, 0, 0, 0, 0, 0, 0, 0],
+            [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
+        ],
+            dtype=torch.int32)
+        # hotword_length = torch.tensor([10, 2, 1], dtype=torch.int32)
+        return (hotword)
+
+    def get_input_names(self):
+        return ['hotword']
+
+    def get_output_names(self):
+        return ['hw_embed']
+
+    def get_dynamic_axes(self):
+        return {
+            'hotword': {
+                0: 'num_hotwords',
+            },
+            'hw_embed': {
+                0: 'num_hotwords',
+            },
+        }
\ No newline at end of file
diff --git a/funasr/runtime/python/onnxruntime/demo_contextual_paraformer.py b/funasr/runtime/python/onnxruntime/demo_contextual_paraformer.py
new file mode 100644
index 000000000..984c0d6bc
--- /dev/null
+++ b/funasr/runtime/python/onnxruntime/demo_contextual_paraformer.py
@@ -0,0 +1,11 @@
+from funasr_onnx import ContextualParaformer
+from pathlib import Path
+
+model_dir = "./export/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404"
+model = ContextualParaformer(model_dir, batch_size=1)
+
+wav_path = ['{}/.cache/modelscope/hub/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/example/asr_example.wav'.format(Path.home())]
+hotwords = '随机热词 各种热词 魔搭 阿里巴巴'
+
+result = model(wav_path, hotwords)
+print(result)
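
The demo passes hotwords as a single space-separated string. `proc_hotword` in the runtime below splits it, maps each character to its id in tokens.txt, appends a trailing entry with id 1, and pads every row to length 10 with `pad_list`. A worked example with an assumed toy vocabulary (the real ids come from tokens.txt):

    hotwords = '魔搭 阿里巴巴'.split(' ')                     # ['魔搭', '阿里巴巴']
    vocab = {'魔': 21, '搭': 33, '阿': 7, '里': 8, '巴': 9}   # assumed, not the real vocabulary
    ids = [[vocab[ch] for ch in w] for w in hotwords]         # [[21, 33], [7, 8, 9, 9]]
    ids.append([1])                                           # trailing entry, as proc_hotword does
    padded = [row + [0] * (10 - len(row)) for row in ids]     # pad_list(..., pad_value=0, max_len=10)
    # padded == [[21, 33, 0, ...], [7, 8, 9, 9, 0, ...], [1, 0, ...]]
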
diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx/__init__.py b/funasr/runtime/python/onnxruntime/funasr_onnx/__init__.py
index 7d8d6620f..c03d7e52a 100644
--- a/funasr/runtime/python/onnxruntime/funasr_onnx/__init__.py
+++ b/funasr/runtime/python/onnxruntime/funasr_onnx/__init__.py
@@ -1,5 +1,5 @@
 # -*- encoding: utf-8 -*-
-from .paraformer_bin import Paraformer
+from .paraformer_bin import Paraformer, ContextualParaformer
 from .vad_bin import Fsmn_vad
 from .vad_bin import Fsmn_vad_online
 from .punc_bin import CT_Transformer
diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py b/funasr/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py
index f3e0f3d2b..5f866b808 100644
--- a/funasr/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py
+++ b/funasr/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py
@@ -7,6 +7,7 @@ from pathlib import Path
 from typing import List, Union, Tuple
 
 import copy
+import torch
 import librosa
 import numpy as np
 
@@ -16,6 +17,7 @@ from .utils.utils import (CharTokenizer, Hypothesis, ONNXRuntimeError,
 from .utils.postprocess_utils import sentence_postprocess
 from .utils.frontend import WavFrontend
 from .utils.timestamp_utils import time_stamp_lfr6_onnx
+from .utils.utils import pad_list, make_pad_mask
 
 logging = get_logger()
 
@@ -210,3 +212,143 @@ class Paraformer():
         # texts = sentence_postprocess(token)
 
         return token
+
+
+class ContextualParaformer(Paraformer):
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
+    https://arxiv.org/abs/2206.08317
+    """
+    def __init__(self, model_dir: Union[str, Path] = None,
+                 batch_size: int = 1,
+                 device_id: Union[str, int] = "-1",
+                 plot_timestamp_to: str = "",
+                 quantize: bool = False,
+                 intra_op_num_threads: int = 4,
+                 cache_dir: str = None
+                 ):
+
+        if not Path(model_dir).exists():
+            from modelscope.hub.snapshot_download import snapshot_download
+            try:
+                model_dir = snapshot_download(model_dir, cache_dir=cache_dir)
+            except Exception:
+                raise ValueError("model_dir must be a model name from modelscope or a local path downloaded from modelscope, but is {}".format(model_dir))
+
+        model_bb_file = os.path.join(model_dir, 'model_bb.onnx')
+        model_eb_file = os.path.join(model_dir, 'model_eb.onnx')
+
+        token_list_file = os.path.join(model_dir, 'tokens.txt')
+        self.vocab = {}
+        with open(Path(token_list_file), 'r') as fin:
+            for i, line in enumerate(fin.readlines()):
+                self.vocab[line.strip()] = i
+
+        # if quantize:
+        #     model_file = os.path.join(model_dir, 'model_quant.onnx')
+        # if not os.path.exists(model_file):
+        #     logging.error(".onnx model does not exist, please export it first.")
+
+        config_file = os.path.join(model_dir, 'config.yaml')
+        cmvn_file = os.path.join(model_dir, 'am.mvn')
+        config = read_yaml(config_file)
+
+        self.converter = TokenIDConverter(config['token_list'])
+        self.tokenizer = CharTokenizer()
+        self.frontend = WavFrontend(
+            cmvn_file=cmvn_file,
+            **config['frontend_conf']
+        )
+        self.ort_infer_bb = OrtInferSession(model_bb_file, device_id, intra_op_num_threads=intra_op_num_threads)
+        self.ort_infer_eb = OrtInferSession(model_eb_file, device_id, intra_op_num_threads=intra_op_num_threads)
+
+        self.batch_size = batch_size
+        self.plot_timestamp_to = plot_timestamp_to
+        if "predictor_bias" in config['model_conf'].keys():
+            self.pred_bias = config['model_conf']['predictor_bias']
+        else:
+            self.pred_bias = 0
+
+    def __call__(self,
+                 wav_content: Union[str, np.ndarray, List[str]],
+                 hotwords: str,
+                 **kwargs) -> List:
+        # make hotword list
+        hotwords, hotwords_length = self.proc_hotword(hotwords)
+        [bias_embed] = self.eb_infer(hotwords, hotwords_length)
+        # index from bias_embed
+        bias_embed = bias_embed.transpose(1, 0, 2)
+        _ind = np.arange(0, len(hotwords)).tolist()
+        bias_embed = bias_embed[_ind, hotwords_length.cpu().numpy().tolist()]
+        waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq)
+        waveform_nums = len(waveform_list)
+        asr_res = []
+        for beg_idx in range(0, waveform_nums, self.batch_size):
+            end_idx = min(waveform_nums, beg_idx + self.batch_size)
+            feats, feats_len = self.extract_feat(waveform_list[beg_idx:end_idx])
+            bias_embed = np.expand_dims(bias_embed, axis=0)
+            bias_embed = np.repeat(bias_embed, feats.shape[0], axis=0)
+            try:
+                outputs = self.bb_infer(feats, feats_len, bias_embed)
+                am_scores, valid_token_lens = outputs[0], outputs[1]
+            except ONNXRuntimeError:
+                # logging.warning(traceback.format_exc())
+                logging.warning("input wav is silence or noise")
+                preds = ['']
+            else:
+                preds = self.decode(am_scores, valid_token_lens)
+                for pred in preds:
+                    pred = sentence_postprocess(pred)
+                    asr_res.append({'preds': pred})
+        return asr_res
+
+    def proc_hotword(self, hotwords):
+        hotwords = hotwords.split(" ")
+        hotwords_length = [len(i) - 1 for i in hotwords]
+        hotwords_length.append(0)
+        hotwords_length = torch.Tensor(hotwords_length).to(torch.int32)
+        # hotwords.append('')
+        def word_map(word):
+            return torch.tensor([self.vocab[i] for i in word])
+        hotword_int = [word_map(i) for i in hotwords]
+        hotword_int.append(torch.tensor([1]))
+        hotwords = pad_list(hotword_int, pad_value=0, max_len=10)
+        return hotwords, hotwords_length
+
+    def bb_infer(self, feats: np.ndarray,
+                 feats_len: np.ndarray, bias_embed) -> Tuple[np.ndarray, np.ndarray]:
+        outputs = self.ort_infer_bb([feats, feats_len, bias_embed])
+        return outputs
+
+    def eb_infer(self, hotwords, hotwords_length):
+        outputs = self.ort_infer_eb([hotwords.to(torch.int32).numpy(), hotwords_length.to(torch.int32).numpy()])
+        return outputs
+
+    def decode(self, am_scores: np.ndarray, token_nums: int) -> List[str]:
+        return [self.decode_one(am_score, token_num)
+                for am_score, token_num in zip(am_scores, token_nums)]
+
+    def decode_one(self,
+                   am_score: np.ndarray,
+                   valid_token_num: int) -> List[str]:
+        yseq = am_score.argmax(axis=-1)
+        score = am_score.max(axis=-1)
+        score = np.sum(score, axis=-1)
+
+        # pad with mask tokens to ensure compatibility with sos/eos tokens
+        # asr_model.sos: 1, asr_model.eos: 2
+        yseq = np.array([1] + yseq.tolist() + [2])
+        hyp = Hypothesis(yseq=yseq, score=score)
+
+        # remove sos/eos and get results
+        last_pos = -1
+        token_int = hyp.yseq[1:last_pos].tolist()
+
+        # remove blank symbol id, which is assumed to be 0
+        token_int = list(filter(lambda x: x not in (0, 2), token_int))
+
+        # Change integer-ids to tokens
+        token = self.converter.ids2tokens(token_int)
+        token = token[:valid_token_num - self.pred_bias]
+        # texts = sentence_postprocess(token)
+        return token
\ No newline at end of file
diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx/utils/utils.py b/funasr/runtime/python/onnxruntime/funasr_onnx/utils/utils.py
index f1fc9a08e..cf742007b 100644
--- a/funasr/runtime/python/onnxruntime/funasr_onnx/utils/utils.py
+++ b/funasr/runtime/python/onnxruntime/funasr_onnx/utils/utils.py
@@ -7,6 +7,7 @@ from pathlib import Path
 from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union
 
 import re
+import torch
 import numpy as np
 import yaml
 try:
@@ -22,6 +23,52 @@ root_dir = Path(__file__).resolve().parent
 
 logger_initialized = {}
 
 
+def pad_list(xs, pad_value, max_len=None):
+    n_batch = len(xs)
+    if max_len is None:
+        max_len = max(x.size(0) for x in xs)
+    pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value)
+
+    for i in range(n_batch):
+        pad[i, : xs[i].size(0)] = xs[i]
+
+    return pad
+
+
+def make_pad_mask(lengths, xs=None, length_dim=-1, maxlen=None):
+    if length_dim == 0:
+        raise ValueError("length_dim cannot be 0: {}".format(length_dim))
+
+    if not isinstance(lengths, list):
+        lengths = lengths.tolist()
+    bs = int(len(lengths))
+    if maxlen is None:
+        if xs is None:
+            maxlen = int(max(lengths))
+        else:
+            maxlen = xs.size(length_dim)
+    else:
+        assert xs is None
+        assert maxlen >= int(max(lengths))
+
+    seq_range = torch.arange(0, maxlen, dtype=torch.int64)
+    seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen)
+    seq_length_expand = seq_range_expand.new(lengths).unsqueeze(-1)
+    mask = seq_range_expand >= seq_length_expand
+
+    if xs is not None:
+        assert xs.size(0) == bs, (xs.size(0), bs)
+
+        if length_dim < 0:
+            length_dim = xs.dim() + length_dim
+        # ind = (:, None, ..., None, :, None, ..., None)
+        ind = tuple(
+            slice(None) if i in (0, length_dim) else None for i in range(xs.dim())
+        )
+        mask = mask[ind].expand_as(xs).to(xs.device)
+    return mask
+
+
 class TokenIDConverter():
     def __init__(self, token_list: Union[List, str],
                  ):

From 7175dee6e19a4cbfc67ba83b6aefb93624eaffa7 Mon Sep 17 00:00:00 2001
From: "shixian.shi"
Date: Wed, 16 Aug 2023 14:42:44 +0800
Subject: [PATCH 2/4] update

---
 funasr/export/models/e2e_asr_contextual_paraformer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/funasr/export/models/e2e_asr_contextual_paraformer.py b/funasr/export/models/e2e_asr_contextual_paraformer.py
index 61806c919..0a3eba648 100644
--- a/funasr/export/models/e2e_asr_contextual_paraformer.py
+++ b/funasr/export/models/e2e_asr_contextual_paraformer.py
@@ -54,7 +54,7 @@ class ContextualParaformer_backbone(nn.Module):
             self.decoder = ParaformerDecoderSAN_export(model.decoder, onnx=onnx)
 
         self.feats_dim = feats_dim
-        self.model_name = model_name + '_bb'
+        self.model_name = model_name
 
         if onnx:
             self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
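
After [PATCH 2/4], the backbone is exported as plain model.onnx (the '_bb' suffix is dropped). For reference, the two exported sessions can also be driven directly with onnxruntime. This sketch assumes the export above has already produced model.onnx and model_eb.onnx, uses random features in place of the LFR+CMVN frontend, and crudely takes the last LSTM step instead of the per-hotword length indexing done by the runtime:

    import numpy as np
    import onnxruntime as ort

    eb = ort.InferenceSession("model_eb.onnx")
    bb = ort.InferenceSession("model.onnx")

    hotword_ids = np.zeros((3, 10), dtype=np.int32)         # padded hotword id rows
    hw_embed = eb.run(None, {"hotword": hotword_ids})[0]    # (10, num_hotwords, 512)

    feats = np.random.randn(1, 30, 560).astype(np.float32)  # stand-in for real features
    feats_len = np.array([30], dtype=np.int32)
    bias = hw_embed[-1][None, ...]                          # (1, num_hotwords, 512)
    logits, token_num = bb.run(None, {"speech": feats,
                                      "speech_lengths": feats_len,
                                      "bias_embed": bias})
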
"shixian.shi" Date: Wed, 16 Aug 2023 14:46:04 +0800 Subject: [PATCH 3/4] update --- funasr/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py b/funasr/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py index 5f866b808..d596d29ca 100644 --- a/funasr/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py +++ b/funasr/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py @@ -235,7 +235,7 @@ class ContextualParaformer(Paraformer): except: raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format(model_dir) - model_bb_file = os.path.join(model_dir, 'model_bb.onnx') + model_bb_file = os.path.join(model_dir, 'model.onnx') model_eb_file = os.path.join(model_dir, 'model_eb.onnx') token_list_file = os.path.join(model_dir, 'tokens.txt') From 57875a51d9e33754149f54e14304ee4fb27e4519 Mon Sep 17 00:00:00 2001 From: "shixian.shi" Date: Wed, 16 Aug 2023 14:53:07 +0800 Subject: [PATCH 4/4] quant inference --- .../python/onnxruntime/funasr_onnx/paraformer_bin.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py b/funasr/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py index d596d29ca..c9940363e 100644 --- a/funasr/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py +++ b/funasr/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py @@ -235,8 +235,12 @@ class ContextualParaformer(Paraformer): except: raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format(model_dir) - model_bb_file = os.path.join(model_dir, 'model.onnx') - model_eb_file = os.path.join(model_dir, 'model_eb.onnx') + if quantize: + model_bb_file = os.path.join(model_dir, 'model_quant.onnx') + model_eb_file = os.path.join(model_dir, 'model_eb_quant.onnx') + else: + model_bb_file = os.path.join(model_dir, 'model.onnx') + model_eb_file = os.path.join(model_dir, 'model_eb.onnx') token_list_file = os.path.join(model_dir, 'tokens.txt') self.vocab = {}