torchscripts

This commit is contained in:
游雁 2023-03-02 20:20:44 +08:00
parent 905a1f2585
commit 548153260b
3 changed files with 9 additions and 8 deletions

View File

@ -2,7 +2,7 @@ import torch
import numpy as np
if __name__ == '__main__':
onnx_path = "/mnt/workspace/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/model.torchscripts"
onnx_path = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/model.torchscripts"
loaded = torch.jit.load(onnx_path)
x = torch.rand([2, 21, 560])

View File

@ -1,7 +1,7 @@
# -*- encoding: utf-8 -*-
from pathlib import Path
import setuptools
from setuptools import find_packages
def get_readme():
root_dir = Path(__file__).resolve().parent
@ -29,7 +29,7 @@ setuptools.setup(
"scipy", "numpy>=1.19.3",
"typeguard", "kaldi-native-fbank",
"PyYAML>=5.1.2"],
packages=['torch_paraformer'],
packages=find_packages(include=["torch_paraformer*"]),
keywords=[
'funasr,paraformer'
],

View File

@ -27,7 +27,7 @@ class Paraformer():
if not Path(model_dir).exists():
raise FileNotFoundError(f'{model_dir} does not exist.')
model_file = os.path.join(model_dir, 'model.onnx')
model_file = os.path.join(model_dir, 'model.torchscripts')
config_file = os.path.join(model_dir, 'config.yaml')
cmvn_file = os.path.join(model_dir, 'am.mvn')
config = read_yaml(config_file)
@ -52,9 +52,8 @@ class Paraformer():
feats, feats_len = self.extract_feat(waveform_list[beg_idx:end_idx])
try:
outputs = self.infer(feats, feats_len)
outs = outputs[0], outputs[1]
am_scores, valid_token_lens = outs[0], outs[1]
outputs = self.ort_infer(feats, feats_len)
am_scores, valid_token_lens = outputs[0], outputs[1]
if len(outputs) == 4:
# for BiCifParaformer Inference
us_alphas, us_cif_peak = outputs[2], outputs[3]
@ -65,7 +64,7 @@ class Paraformer():
logging.warning("input wav is silence or noise")
preds = ['']
else:
am_scores, valid_token_lens = am_scores.cpu().numpy(), valid_token_lens.cpu().numpy()
am_scores, valid_token_lens = am_scores.detach().cpu().numpy(), valid_token_lens.detach().cpu().numpy()
preds, raw_token = self.decode(am_scores, valid_token_lens)[0]
res['preds'] = preds
if us_cif_peak is not None:
@ -105,6 +104,8 @@ class Paraformer():
feats = self.pad_feats(feats, np.max(feats_len))
feats_len = np.array(feats_len).astype(np.int32)
feats = torch.from_numpy(feats).type(torch.float32)
feats_len = torch.from_numpy(feats_len).type(torch.int32)
return feats, feats_len
@staticmethod