mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
torchscripts
This commit is contained in:
parent
905a1f2585
commit
548153260b
@ -2,7 +2,7 @@ import torch
|
||||
import numpy as np
|
||||
|
||||
if __name__ == '__main__':
|
||||
onnx_path = "/mnt/workspace/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/model.torchscripts"
|
||||
onnx_path = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/model.torchscripts"
|
||||
loaded = torch.jit.load(onnx_path)
|
||||
|
||||
x = torch.rand([2, 21, 560])
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
# -*- encoding: utf-8 -*-
|
||||
from pathlib import Path
|
||||
import setuptools
|
||||
|
||||
from setuptools import find_packages
|
||||
|
||||
def get_readme():
|
||||
root_dir = Path(__file__).resolve().parent
|
||||
@ -29,7 +29,7 @@ setuptools.setup(
|
||||
"scipy", "numpy>=1.19.3",
|
||||
"typeguard", "kaldi-native-fbank",
|
||||
"PyYAML>=5.1.2"],
|
||||
packages=['torch_paraformer'],
|
||||
packages=find_packages(include=["torch_paraformer*"]),
|
||||
keywords=[
|
||||
'funasr,paraformer'
|
||||
],
|
||||
|
||||
@ -27,7 +27,7 @@ class Paraformer():
|
||||
if not Path(model_dir).exists():
|
||||
raise FileNotFoundError(f'{model_dir} does not exist.')
|
||||
|
||||
model_file = os.path.join(model_dir, 'model.onnx')
|
||||
model_file = os.path.join(model_dir, 'model.torchscripts')
|
||||
config_file = os.path.join(model_dir, 'config.yaml')
|
||||
cmvn_file = os.path.join(model_dir, 'am.mvn')
|
||||
config = read_yaml(config_file)
|
||||
@ -52,9 +52,8 @@ class Paraformer():
|
||||
feats, feats_len = self.extract_feat(waveform_list[beg_idx:end_idx])
|
||||
|
||||
try:
|
||||
outputs = self.infer(feats, feats_len)
|
||||
outs = outputs[0], outputs[1]
|
||||
am_scores, valid_token_lens = outs[0], outs[1]
|
||||
outputs = self.ort_infer(feats, feats_len)
|
||||
am_scores, valid_token_lens = outputs[0], outputs[1]
|
||||
if len(outputs) == 4:
|
||||
# for BiCifParaformer Inference
|
||||
us_alphas, us_cif_peak = outputs[2], outputs[3]
|
||||
@ -65,7 +64,7 @@ class Paraformer():
|
||||
logging.warning("input wav is silence or noise")
|
||||
preds = ['']
|
||||
else:
|
||||
am_scores, valid_token_lens = am_scores.cpu().numpy(), valid_token_lens.cpu().numpy()
|
||||
am_scores, valid_token_lens = am_scores.detach().cpu().numpy(), valid_token_lens.detach().cpu().numpy()
|
||||
preds, raw_token = self.decode(am_scores, valid_token_lens)[0]
|
||||
res['preds'] = preds
|
||||
if us_cif_peak is not None:
|
||||
@ -105,6 +104,8 @@ class Paraformer():
|
||||
|
||||
feats = self.pad_feats(feats, np.max(feats_len))
|
||||
feats_len = np.array(feats_len).astype(np.int32)
|
||||
feats = torch.from_numpy(feats).type(torch.float32)
|
||||
feats_len = torch.from_numpy(feats_len).type(torch.int32)
|
||||
return feats, feats_len
|
||||
|
||||
@staticmethod
|
||||
|
||||
Loading…
Reference in New Issue
Block a user