mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
update libtorch inference
This commit is contained in:
parent
c5339e8302
commit
fd0992af3d
@ -275,7 +275,7 @@ class ContextualParaformer(Paraformer):
|
|||||||
model_eb_file = os.path.join(model_dir, "model_eb.torchscripts")
|
model_eb_file = os.path.join(model_dir, "model_eb.torchscripts")
|
||||||
|
|
||||||
if not (os.path.exists(model_eb_file) and os.path.exists(model_bb_file)):
|
if not (os.path.exists(model_eb_file) and os.path.exists(model_bb_file)):
|
||||||
print(".onnx is not exist, begin to export onnx")
|
print(".onnx does not exist, begin to export onnx")
|
||||||
try:
|
try:
|
||||||
from funasr import AutoModel
|
from funasr import AutoModel
|
||||||
except:
|
except:
|
||||||
@ -316,8 +316,7 @@ class ContextualParaformer(Paraformer):
|
|||||||
) -> List:
|
) -> List:
|
||||||
# make hotword list
|
# make hotword list
|
||||||
hotwords, hotwords_length = self.proc_hotword(hotwords)
|
hotwords, hotwords_length = self.proc_hotword(hotwords)
|
||||||
# import pdb; pdb.set_trace()
|
[bias_embed] = self.eb_infer(torch.Tensor(hotwords), torch.Tensor(hotwords_length))
|
||||||
[bias_embed] = self.eb_infer(hotwords, hotwords_length)
|
|
||||||
# index from bias_embed
|
# index from bias_embed
|
||||||
bias_embed = bias_embed.transpose(1, 0, 2)
|
bias_embed = bias_embed.transpose(1, 0, 2)
|
||||||
_ind = np.arange(0, len(hotwords)).tolist()
|
_ind = np.arange(0, len(hotwords)).tolist()
|
||||||
@ -333,10 +332,10 @@ class ContextualParaformer(Paraformer):
|
|||||||
try:
|
try:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
if int(self.device_id) == -1:
|
if int(self.device_id) == -1:
|
||||||
outputs = self.ort_infer(feats, feats_len)
|
outputs = self.bb_infer(feats, feats_len)
|
||||||
am_scores, valid_token_lens = outputs[0], outputs[1]
|
am_scores, valid_token_lens = outputs[0], outputs[1]
|
||||||
else:
|
else:
|
||||||
outputs = self.ort_infer(feats.cuda(), feats_len.cuda())
|
outputs = self.bb_infer_infer(feats.cuda(), feats_len.cuda())
|
||||||
am_scores, valid_token_lens = outputs[0].cpu(), outputs[1].cpu()
|
am_scores, valid_token_lens = outputs[0].cpu(), outputs[1].cpu()
|
||||||
except:
|
except:
|
||||||
# logging.warning(traceback.format_exc())
|
# logging.warning(traceback.format_exc())
|
||||||
@ -374,13 +373,13 @@ class ContextualParaformer(Paraformer):
|
|||||||
return hotwords, hotwords_length
|
return hotwords, hotwords_length
|
||||||
|
|
||||||
def bb_infer(
|
def bb_infer(
|
||||||
self, feats: np.ndarray, feats_len: np.ndarray, bias_embed
|
self, feats, feats_len, bias_embed
|
||||||
) -> Tuple[np.ndarray, np.ndarray]:
|
) -> Tuple[np.ndarray, np.ndarray]:
|
||||||
outputs = self.ort_infer_bb([feats, feats_len, bias_embed])
|
outputs = self.ort_infer_bb([feats, feats_len, bias_embed])
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
def eb_infer(self, hotwords, hotwords_length):
|
def eb_infer(self, hotwords, hotwords_length):
|
||||||
outputs = self.ort_infer_eb([hotwords.astype(np.int32), hotwords_length.astype(np.int32)])
|
outputs = self.ort_infer_eb([hotwords, hotwords_length])
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
def decode(self, am_scores: np.ndarray, token_nums: int) -> List[str]:
|
def decode(self, am_scores: np.ndarray, token_nums: int) -> List[str]:
|
||||||
|
|||||||
@ -285,7 +285,7 @@ class ContextualParaformer(Paraformer):
|
|||||||
model_eb_file = os.path.join(model_dir, "model_eb.onnx")
|
model_eb_file = os.path.join(model_dir, "model_eb.onnx")
|
||||||
|
|
||||||
if not (os.path.exists(model_eb_file) and os.path.exists(model_bb_file)):
|
if not (os.path.exists(model_eb_file) and os.path.exists(model_bb_file)):
|
||||||
print(".onnx is not exist, begin to export onnx")
|
print(".onnx does not exist, begin to export onnx")
|
||||||
try:
|
try:
|
||||||
from funasr import AutoModel
|
from funasr import AutoModel
|
||||||
except:
|
except:
|
||||||
|
|||||||
@ -54,7 +54,7 @@ class Paraformer:
|
|||||||
encoder_model_file = os.path.join(model_dir, "model_quant.onnx")
|
encoder_model_file = os.path.join(model_dir, "model_quant.onnx")
|
||||||
decoder_model_file = os.path.join(model_dir, "decoder_quant.onnx")
|
decoder_model_file = os.path.join(model_dir, "decoder_quant.onnx")
|
||||||
if not os.path.exists(encoder_model_file) or not os.path.exists(decoder_model_file):
|
if not os.path.exists(encoder_model_file) or not os.path.exists(decoder_model_file):
|
||||||
print(".onnx is not exist, begin to export onnx")
|
print(".onnx does not exist, begin to export onnx")
|
||||||
try:
|
try:
|
||||||
from funasr import AutoModel
|
from funasr import AutoModel
|
||||||
except:
|
except:
|
||||||
|
|||||||
@ -52,7 +52,7 @@ class CT_Transformer:
|
|||||||
if quantize:
|
if quantize:
|
||||||
model_file = os.path.join(model_dir, "model_quant.onnx")
|
model_file = os.path.join(model_dir, "model_quant.onnx")
|
||||||
if not os.path.exists(model_file):
|
if not os.path.exists(model_file):
|
||||||
print(".onnx is not exist, begin to export onnx")
|
print(".onnx does not exist, begin to export onnx")
|
||||||
try:
|
try:
|
||||||
from funasr import AutoModel
|
from funasr import AutoModel
|
||||||
except:
|
except:
|
||||||
|
|||||||
@ -52,7 +52,7 @@ class Fsmn_vad:
|
|||||||
if quantize:
|
if quantize:
|
||||||
model_file = os.path.join(model_dir, "model_quant.onnx")
|
model_file = os.path.join(model_dir, "model_quant.onnx")
|
||||||
if not os.path.exists(model_file):
|
if not os.path.exists(model_file):
|
||||||
print(".onnx is not exist, begin to export onnx")
|
print(".onnx does not exist, begin to export onnx")
|
||||||
try:
|
try:
|
||||||
from funasr import AutoModel
|
from funasr import AutoModel
|
||||||
except:
|
except:
|
||||||
@ -221,7 +221,7 @@ class Fsmn_vad_online:
|
|||||||
if quantize:
|
if quantize:
|
||||||
model_file = os.path.join(model_dir, "model_quant.onnx")
|
model_file = os.path.join(model_dir, "model_quant.onnx")
|
||||||
if not os.path.exists(model_file):
|
if not os.path.exists(model_file):
|
||||||
print(".onnx is not exist, begin to export onnx")
|
print(".onnx does not exist, begin to export onnx")
|
||||||
try:
|
try:
|
||||||
from funasr import AutoModel
|
from funasr import AutoModel
|
||||||
except:
|
except:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user