From fd0992af3d1a2d2d098b1fab24f67f8c4cece39d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=BB=B4=E7=9F=B3?= Date: Mon, 3 Jun 2024 15:32:34 +0800 Subject: [PATCH] update libtorch inference --- .../python/libtorch/funasr_torch/paraformer_bin.py | 13 ++++++------- .../onnxruntime/funasr_onnx/paraformer_bin.py | 2 +- .../funasr_onnx/paraformer_online_bin.py | 2 +- runtime/python/onnxruntime/funasr_onnx/punc_bin.py | 2 +- runtime/python/onnxruntime/funasr_onnx/vad_bin.py | 4 ++-- 5 files changed, 11 insertions(+), 12 deletions(-) diff --git a/runtime/python/libtorch/funasr_torch/paraformer_bin.py b/runtime/python/libtorch/funasr_torch/paraformer_bin.py index b0cd8714d..ca96b47b3 100644 --- a/runtime/python/libtorch/funasr_torch/paraformer_bin.py +++ b/runtime/python/libtorch/funasr_torch/paraformer_bin.py @@ -275,7 +275,7 @@ class ContextualParaformer(Paraformer): model_eb_file = os.path.join(model_dir, "model_eb.torchscripts") if not (os.path.exists(model_eb_file) and os.path.exists(model_bb_file)): - print(".onnx is not exist, begin to export onnx") + print(".onnx does not exist, begin to export onnx") try: from funasr import AutoModel except: @@ -316,8 +316,7 @@ class ContextualParaformer(Paraformer): ) -> List: # make hotword list hotwords, hotwords_length = self.proc_hotword(hotwords) - # import pdb; pdb.set_trace() - [bias_embed] = self.eb_infer(hotwords, hotwords_length) + [bias_embed] = self.eb_infer(torch.Tensor(hotwords), torch.Tensor(hotwords_length)) # index from bias_embed bias_embed = bias_embed.transpose(1, 0, 2) _ind = np.arange(0, len(hotwords)).tolist() @@ -333,10 +332,10 @@ class ContextualParaformer(Paraformer): try: with torch.no_grad(): if int(self.device_id) == -1: - outputs = self.ort_infer(feats, feats_len) + outputs = self.bb_infer(feats, feats_len) am_scores, valid_token_lens = outputs[0], outputs[1] else: - outputs = self.ort_infer(feats.cuda(), feats_len.cuda()) + outputs = self.bb_infer_infer(feats.cuda(), feats_len.cuda()) am_scores, valid_token_lens = outputs[0].cpu(), outputs[1].cpu() except: # logging.warning(traceback.format_exc()) @@ -374,13 +373,13 @@ class ContextualParaformer(Paraformer): return hotwords, hotwords_length def bb_infer( - self, feats: np.ndarray, feats_len: np.ndarray, bias_embed + self, feats, feats_len, bias_embed ) -> Tuple[np.ndarray, np.ndarray]: outputs = self.ort_infer_bb([feats, feats_len, bias_embed]) return outputs def eb_infer(self, hotwords, hotwords_length): - outputs = self.ort_infer_eb([hotwords.astype(np.int32), hotwords_length.astype(np.int32)]) + outputs = self.ort_infer_eb([hotwords, hotwords_length]) return outputs def decode(self, am_scores: np.ndarray, token_nums: int) -> List[str]: diff --git a/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py b/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py index 8194283ce..871674eff 100644 --- a/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py +++ b/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py @@ -285,7 +285,7 @@ class ContextualParaformer(Paraformer): model_eb_file = os.path.join(model_dir, "model_eb.onnx") if not (os.path.exists(model_eb_file) and os.path.exists(model_bb_file)): - print(".onnx is not exist, begin to export onnx") + print(".onnx does not exist, begin to export onnx") try: from funasr import AutoModel except: diff --git a/runtime/python/onnxruntime/funasr_onnx/paraformer_online_bin.py b/runtime/python/onnxruntime/funasr_onnx/paraformer_online_bin.py index 9b68b2f05..ddd857d37 100644 --- a/runtime/python/onnxruntime/funasr_onnx/paraformer_online_bin.py +++ b/runtime/python/onnxruntime/funasr_onnx/paraformer_online_bin.py @@ -54,7 +54,7 @@ class Paraformer: encoder_model_file = os.path.join(model_dir, "model_quant.onnx") decoder_model_file = os.path.join(model_dir, "decoder_quant.onnx") if not os.path.exists(encoder_model_file) or not os.path.exists(decoder_model_file): - print(".onnx is not exist, begin to export onnx") + print(".onnx does not exist, begin to export onnx") try: from funasr import AutoModel except: diff --git a/runtime/python/onnxruntime/funasr_onnx/punc_bin.py b/runtime/python/onnxruntime/funasr_onnx/punc_bin.py index 6208c09e0..ba55186d0 100644 --- a/runtime/python/onnxruntime/funasr_onnx/punc_bin.py +++ b/runtime/python/onnxruntime/funasr_onnx/punc_bin.py @@ -52,7 +52,7 @@ class CT_Transformer: if quantize: model_file = os.path.join(model_dir, "model_quant.onnx") if not os.path.exists(model_file): - print(".onnx is not exist, begin to export onnx") + print(".onnx does not exist, begin to export onnx") try: from funasr import AutoModel except: diff --git a/runtime/python/onnxruntime/funasr_onnx/vad_bin.py b/runtime/python/onnxruntime/funasr_onnx/vad_bin.py index c195bb3b3..92928a840 100644 --- a/runtime/python/onnxruntime/funasr_onnx/vad_bin.py +++ b/runtime/python/onnxruntime/funasr_onnx/vad_bin.py @@ -52,7 +52,7 @@ class Fsmn_vad: if quantize: model_file = os.path.join(model_dir, "model_quant.onnx") if not os.path.exists(model_file): - print(".onnx is not exist, begin to export onnx") + print(".onnx does not exist, begin to export onnx") try: from funasr import AutoModel except: @@ -221,7 +221,7 @@ class Fsmn_vad_online: if quantize: model_file = os.path.join(model_dir, "model_quant.onnx") if not os.path.exists(model_file): - print(".onnx is not exist, begin to export onnx") + print(".onnx does not exist, begin to export onnx") try: from funasr import AutoModel except: