update libtorch inference

This commit is contained in:
维石 2024-06-03 15:32:34 +08:00
parent c5339e8302
commit fd0992af3d
5 changed files with 11 additions and 12 deletions

View File

@ -275,7 +275,7 @@ class ContextualParaformer(Paraformer):
model_eb_file = os.path.join(model_dir, "model_eb.torchscripts") model_eb_file = os.path.join(model_dir, "model_eb.torchscripts")
if not (os.path.exists(model_eb_file) and os.path.exists(model_bb_file)): if not (os.path.exists(model_eb_file) and os.path.exists(model_bb_file)):
print(".onnx is not exist, begin to export onnx") print(".onnx does not exist, begin to export onnx")
try: try:
from funasr import AutoModel from funasr import AutoModel
except: except:
@ -316,8 +316,7 @@ class ContextualParaformer(Paraformer):
) -> List: ) -> List:
# make hotword list # make hotword list
hotwords, hotwords_length = self.proc_hotword(hotwords) hotwords, hotwords_length = self.proc_hotword(hotwords)
# import pdb; pdb.set_trace() [bias_embed] = self.eb_infer(torch.Tensor(hotwords), torch.Tensor(hotwords_length))
[bias_embed] = self.eb_infer(hotwords, hotwords_length)
# index from bias_embed # index from bias_embed
bias_embed = bias_embed.transpose(1, 0, 2) bias_embed = bias_embed.transpose(1, 0, 2)
_ind = np.arange(0, len(hotwords)).tolist() _ind = np.arange(0, len(hotwords)).tolist()
@ -333,10 +332,10 @@ class ContextualParaformer(Paraformer):
try: try:
with torch.no_grad(): with torch.no_grad():
if int(self.device_id) == -1: if int(self.device_id) == -1:
outputs = self.ort_infer(feats, feats_len) outputs = self.bb_infer(feats, feats_len)
am_scores, valid_token_lens = outputs[0], outputs[1] am_scores, valid_token_lens = outputs[0], outputs[1]
else: else:
outputs = self.ort_infer(feats.cuda(), feats_len.cuda()) outputs = self.bb_infer_infer(feats.cuda(), feats_len.cuda())
am_scores, valid_token_lens = outputs[0].cpu(), outputs[1].cpu() am_scores, valid_token_lens = outputs[0].cpu(), outputs[1].cpu()
except: except:
# logging.warning(traceback.format_exc()) # logging.warning(traceback.format_exc())
@ -374,13 +373,13 @@ class ContextualParaformer(Paraformer):
return hotwords, hotwords_length return hotwords, hotwords_length
def bb_infer( def bb_infer(
self, feats: np.ndarray, feats_len: np.ndarray, bias_embed self, feats, feats_len, bias_embed
) -> Tuple[np.ndarray, np.ndarray]: ) -> Tuple[np.ndarray, np.ndarray]:
outputs = self.ort_infer_bb([feats, feats_len, bias_embed]) outputs = self.ort_infer_bb([feats, feats_len, bias_embed])
return outputs return outputs
def eb_infer(self, hotwords, hotwords_length): def eb_infer(self, hotwords, hotwords_length):
outputs = self.ort_infer_eb([hotwords.astype(np.int32), hotwords_length.astype(np.int32)]) outputs = self.ort_infer_eb([hotwords, hotwords_length])
return outputs return outputs
def decode(self, am_scores: np.ndarray, token_nums: int) -> List[str]: def decode(self, am_scores: np.ndarray, token_nums: int) -> List[str]:

View File

@ -285,7 +285,7 @@ class ContextualParaformer(Paraformer):
model_eb_file = os.path.join(model_dir, "model_eb.onnx") model_eb_file = os.path.join(model_dir, "model_eb.onnx")
if not (os.path.exists(model_eb_file) and os.path.exists(model_bb_file)): if not (os.path.exists(model_eb_file) and os.path.exists(model_bb_file)):
print(".onnx is not exist, begin to export onnx") print(".onnx does not exist, begin to export onnx")
try: try:
from funasr import AutoModel from funasr import AutoModel
except: except:

View File

@ -54,7 +54,7 @@ class Paraformer:
encoder_model_file = os.path.join(model_dir, "model_quant.onnx") encoder_model_file = os.path.join(model_dir, "model_quant.onnx")
decoder_model_file = os.path.join(model_dir, "decoder_quant.onnx") decoder_model_file = os.path.join(model_dir, "decoder_quant.onnx")
if not os.path.exists(encoder_model_file) or not os.path.exists(decoder_model_file): if not os.path.exists(encoder_model_file) or not os.path.exists(decoder_model_file):
print(".onnx is not exist, begin to export onnx") print(".onnx does not exist, begin to export onnx")
try: try:
from funasr import AutoModel from funasr import AutoModel
except: except:

View File

@ -52,7 +52,7 @@ class CT_Transformer:
if quantize: if quantize:
model_file = os.path.join(model_dir, "model_quant.onnx") model_file = os.path.join(model_dir, "model_quant.onnx")
if not os.path.exists(model_file): if not os.path.exists(model_file):
print(".onnx is not exist, begin to export onnx") print(".onnx does not exist, begin to export onnx")
try: try:
from funasr import AutoModel from funasr import AutoModel
except: except:

View File

@ -52,7 +52,7 @@ class Fsmn_vad:
if quantize: if quantize:
model_file = os.path.join(model_dir, "model_quant.onnx") model_file = os.path.join(model_dir, "model_quant.onnx")
if not os.path.exists(model_file): if not os.path.exists(model_file):
print(".onnx is not exist, begin to export onnx") print(".onnx does not exist, begin to export onnx")
try: try:
from funasr import AutoModel from funasr import AutoModel
except: except:
@ -221,7 +221,7 @@ class Fsmn_vad_online:
if quantize: if quantize:
model_file = os.path.join(model_dir, "model_quant.onnx") model_file = os.path.join(model_dir, "model_quant.onnx")
if not os.path.exists(model_file): if not os.path.exists(model_file):
print(".onnx is not exist, begin to export onnx") print(".onnx does not exist, begin to export onnx")
try: try:
from funasr import AutoModel from funasr import AutoModel
except: except: