diff --git a/fbank.py b/fbank.py
new file mode 100644
index 000000000..26daa45f6
--- /dev/null
+++ b/fbank.py
@@ -0,0 +1,123 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+# Part of the implementation is borrowed from espnet/espnet.
+
+from typing import Tuple
+
+import numpy as np
+import torch
+import torchaudio.compliance.kaldi as kaldi
+from funasr.models.frontend.abs_frontend import AbsFrontend
+from typeguard import check_argument_types
+from torch.nn.utils.rnn import pad_sequence
+import kaldi_native_fbank as knf
+
+class WavFrontend(AbsFrontend):
+    """Conventional frontend structure for ASR.
+
+    Extracts Kaldi-style filterbank features for every utterance of a batch
+    with ``torchaudio.compliance.kaldi.fbank`` and zero-pads the per-utterance
+    feature matrices to a common number of frames.
+
+    Args:
+        cmvn_file: Presumably the path of a CMVN statistics file. It is only
+            stored on the instance; ``forward`` in this class does not use it.
+        fs: Passed to ``kaldi.fbank`` as ``sample_frequency``.
+        window: Passed to ``kaldi.fbank`` as ``window_type``.
+        n_mels: Passed to ``kaldi.fbank`` as ``num_mel_bins``.
+        frame_length: Passed to ``kaldi.fbank`` as ``frame_length``.
+        frame_shift: Passed to ``kaldi.fbank`` as ``frame_shift``.
+        filter_length_min: Stored only; not used by ``forward``.
+        filter_length_max: Stored only; not used by ``forward``.
+        lfr_m: Only used by ``output_size`` (``n_mels * lfr_m``).
+        lfr_n: Stored only; not used by ``forward``.
+        dither: Passed to ``kaldi.fbank`` as ``dither``.
+        snip_edges: Passed to ``kaldi.fbank`` as ``snip_edges``.
+        upsacle_samples: When true, samples are multiplied by ``1 << 15``
+            before feature extraction. The (misspelled) keyword is part of
+            the public signature and is kept for compatibility.
+    """
+
+    def __init__(
+        self,
+        cmvn_file: str = None,
+        fs: int = 16000,
+        window: str = 'hamming',
+        n_mels: int = 80,
+        frame_length: int = 25,
+        frame_shift: int = 10,
+        filter_length_min: int = -1,
+        filter_length_max: int = -1,
+        lfr_m: int = 1,
+        lfr_n: int = 1,
+        dither: float = 1.0,
+        snip_edges: bool = True,
+        upsacle_samples: bool = True,
+    ):
+        assert check_argument_types()
+        super().__init__()
+        self.fs = fs
+        self.window = window
+        self.n_mels = n_mels
+        self.frame_length = frame_length
+        self.frame_shift = frame_shift
+        self.filter_length_min = filter_length_min
+        self.filter_length_max = filter_length_max
+        self.lfr_m = lfr_m
+        self.lfr_n = lfr_n
+        self.cmvn_file = cmvn_file
+        self.dither = dither
+        self.snip_edges = snip_edges
+        self.upsacle_samples = upsacle_samples
+
+    def output_size(self) -> int:
+        """Return the feature dimension reported by this frontend."""
+        return self.n_mels * self.lfr_m
+
+    def forward(
+            self,
+            input: torch.Tensor,
+            input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Compute padded fbank features for a batch of waveforms.
+
+        Args:
+            input: Batch of waveforms, indexed along dimension 0
+                (assumed to be ``(batch, samples)`` -- TODO confirm).
+            input_lengths: Per-utterance number of valid samples; only
+                ``input_lengths[i]`` is read, so any indexable works.
+
+        Returns:
+            ``(feats_pad, feats_lens)``: the features padded with ``0.0``
+            (batch first) and the number of frames of each utterance.
+        """
+        batch_size = input.size(0)
+        feats = []
+        feats_lens = []
+        for i in range(batch_size):
+            waveform_length = input_lengths[i]
+            waveform = input[i][:waveform_length]
+            # Honor the constructor flag instead of always rescaling.
+            if self.upsacle_samples:
+                waveform = waveform * (1 << 15)
+            # kaldi.fbank is called with a leading dimension of size 1.
+            waveform = waveform.unsqueeze(0)
+            mat = kaldi.fbank(waveform,
+                              num_mel_bins=self.n_mels,
+                              frame_length=self.frame_length,
+                              frame_shift=self.frame_shift,
+                              dither=self.dither,
+                              energy_floor=0.0,
+                              window_type=self.window,
+                              sample_frequency=self.fs,
+                              snip_edges=self.snip_edges)
+
+            feat_length = mat.size(0)
+            feats.append(mat)
+            feats_lens.append(feat_length)
+
+        feats_lens = torch.as_tensor(feats_lens)
+        feats_pad = pad_sequence(feats,
+                                 batch_first=True,
+                                 padding_value=0.0)
+        return feats_pad, feats_lens
+
+import kaldi_native_fbank as knf
+
+def fbank_knf(waveform, fs=16000, n_mels=80, dither=0.0):
+    """Compute filterbank features with ``kaldi_native_fbank``.
+
+    Args:
+        waveform: Audio samples. The object must support multiplication by
+            an int and ``.tolist()`` (for example a 1-D ``numpy.ndarray``).
+            Samples are multiplied by ``1 << 15`` before extraction.
+        fs: Sample rate given to the extractor (default 16000).
+        n_mels: Number of mel bins (default 80).
+        dither: Dither value given to the extractor (default 0.0).
+
+    Returns:
+        ``numpy.ndarray`` with one row per ready frame and ``n_mels`` columns.
+    """
+    opts = knf.FbankOptions()
+    opts.frame_opts.samp_freq = fs
+    opts.frame_opts.dither = dither
+    opts.frame_opts.window_type = "hamming"
+    opts.frame_opts.frame_shift_ms = 10.0
+    opts.frame_opts.frame_length_ms = 25.0
+    opts.mel_opts.num_bins = n_mels
+    opts.energy_floor = 1
+    opts.frame_opts.snip_edges = True
+    opts.mel_opts.debug_mel = False
+
+    fbank = knf.OnlineFbank(opts)
+    waveform = waveform * (1 << 15)
+    fbank.accept_waveform(opts.frame_opts.samp_freq, waveform.tolist())
+    frames = fbank.num_frames_ready
+    # Copy the frames one by one into a dense (frames, bins) matrix.
+    mat = np.empty([frames, opts.mel_opts.num_bins])
+    for i in range(frames):
+        mat[i, :] = fbank.get_frame(i)
+    return mat
+
+if __name__ == '__main__':
+    # Manual check: compare the kaldi-native-fbank features with the
+    # torchaudio-based WavFrontend features for a single wav file.
+    import sys
+
+    import librosa
+
+    default_path = "/home/zhifu.gzf/.cache/modelscope/hub/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/example/asr_example.wav"
+    # Allow another wav file to be given as the first command-line argument.
+    path = sys.argv[1] if len(sys.argv) > 1 else default_path
+    waveform, fs = librosa.load(path, sr=None)
+    fbank = fbank_knf(waveform)
+    frontend = WavFrontend(dither=0.0)
+    waveform_tensor = torch.from_numpy(waveform)[None, :]
+    fbank_torch, _ = frontend.forward(waveform_tensor, [waveform_tensor.size(1)])
+    fbank_torch = fbank_torch.cpu().numpy()[0, :, :]
+    # ``diff`` is a numpy array, which has no ``.abs()`` method; use np.abs.
+    diff = np.abs(fbank - fbank_torch)
+    diff_max = diff.max()
+    diff_sum = diff.sum()
+    print('max abs diff: {}, sum abs diff: {}'.format(diff_max, diff_sum))
\ No newline at end of file
diff --git a/funasr/models/frontend/wav_frontend.py b/funasr/models/frontend/wav_frontend.py
index 7a6425be3..ed8cb3646 100644
--- a/funasr/models/frontend/wav_frontend.py
+++ b/funasr/models/frontend/wav_frontend.py
@@ -171,10 +171,7 @@ class WavFrontend(AbsFrontend):
window_type=self.window,
sample_frequency=self.fs)
- # if self.lfr_m != 1 or self.lfr_n != 1:
- # mat = apply_lfr(mat, self.lfr_m, self.lfr_n)
- # if self.cmvn_file is not None:
- # mat = apply_cmvn(mat, self.cmvn_file)
+
feat_length = mat.size(0)
feats.append(mat)
feats_lens.append(feat_length)
diff --git a/funasr/runtime/__init__.py b/funasr/runtime/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/funasr/runtime/python/__init__.py b/funasr/runtime/python/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/funasr/runtime/python/onnxruntime/__init__.py b/funasr/runtime/python/onnxruntime/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/funasr/runtime/python/onnxruntime/.gitignore b/funasr/runtime/python/onnxruntime/paraformer/.gitignore
similarity index 100%
rename from funasr/runtime/python/onnxruntime/.gitignore
rename to funasr/runtime/python/onnxruntime/paraformer/.gitignore
diff --git a/funasr/runtime/python/onnxruntime/README.md b/funasr/runtime/python/onnxruntime/paraformer/README.md
similarity index 82%
rename from funasr/runtime/python/onnxruntime/README.md
rename to funasr/runtime/python/onnxruntime/paraformer/README.md
index ee3ce0a6a..d68600f6b 100644
--- a/funasr/runtime/python/onnxruntime/README.md
+++ b/funasr/runtime/python/onnxruntime/paraformer/README.md
@@ -29,12 +29,6 @@
│ └── utils.py
├── README.md
├── requirements.txt
- ├── resources
- │ ├── config.yaml
- │ └── models
- │ ├── am.mvn
- │ ├── model.onnx # Put it here.
- │ └── token_list.pkl
├── test_onnx.py
├── tests
│ ├── __pycache__
@@ -48,15 +42,15 @@
- Output: `List[str]`: recognition result.
- Example:
```python
- from rapid_paraformer import RapidParaformer
+ from paraformer_onnx import Paraformer
config_path = 'resources/config.yaml'
- paraformer = RapidParaformer(config_path)
+ model = Paraformer(config_path)
- wav_path = ['test_wavs/0478_00017.wav']
+ wav_path = ['example/asr_example.wav']
- result = paraformer(wav_path)
+ result = model(wav_path)
print(result)
```
diff --git a/funasr/runtime/python/onnxruntime/paraformer/__init__.py b/funasr/runtime/python/onnxruntime/paraformer/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/funasr/runtime/python/onnxruntime/rapid_paraformer/__init__.py b/funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/__init__.py
similarity index 100%
rename from funasr/runtime/python/onnxruntime/rapid_paraformer/__init__.py
rename to funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/__init__.py
diff --git a/funasr/runtime/python/onnxruntime/rapid_paraformer/kaldifeat/LICENSE b/funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/kaldifeat/LICENSE
similarity index 100%
rename from funasr/runtime/python/onnxruntime/rapid_paraformer/kaldifeat/LICENSE
rename to funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/kaldifeat/LICENSE
diff --git a/funasr/runtime/python/onnxruntime/rapid_paraformer/kaldifeat/__init__.py b/funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/kaldifeat/__init__.py
similarity index 100%
rename from funasr/runtime/python/onnxruntime/rapid_paraformer/kaldifeat/__init__.py
rename to funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/kaldifeat/__init__.py
diff --git a/funasr/runtime/python/onnxruntime/rapid_paraformer/kaldifeat/feature.py b/funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/kaldifeat/feature.py
similarity index 100%
rename from funasr/runtime/python/onnxruntime/rapid_paraformer/kaldifeat/feature.py
rename to funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/kaldifeat/feature.py
diff --git a/funasr/runtime/python/onnxruntime/rapid_paraformer/kaldifeat/ivector.py b/funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/kaldifeat/ivector.py
similarity index 100%
rename from funasr/runtime/python/onnxruntime/rapid_paraformer/kaldifeat/ivector.py
rename to funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/kaldifeat/ivector.py
diff --git a/funasr/runtime/python/onnxruntime/rapid_paraformer/rapid_paraformer.py b/funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/paraformer_onnx.py
similarity index 76%
rename from funasr/runtime/python/onnxruntime/rapid_paraformer/rapid_paraformer.py
rename to funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/paraformer_onnx.py
index 10bfa8ae4..1fc3582ce 100644
--- a/funasr/runtime/python/onnxruntime/rapid_paraformer/rapid_paraformer.py
+++ b/funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/paraformer_onnx.py
@@ -1,6 +1,7 @@
# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: liekkaskono@163.com
+import os.path
import traceback
from pathlib import Path
from typing import List, Union, Tuple
@@ -11,25 +12,33 @@ import numpy as np
from .utils import (CharTokenizer, Hypothesis, ONNXRuntimeError,
OrtInferSession, TokenIDConverter, WavFrontend, get_logger,
read_yaml)
+from .postprocess_utils import sentence_postprocess
logging = get_logger()
-class RapidParaformer():
- def __init__(self, config_path: Union[str, Path]) -> None:
- if not Path(config_path).exists():
- raise FileNotFoundError(f'{config_path} does not exist.')
+class Paraformer():
+ def __init__(self, model_dir: Union[str, Path]=None,
+ batch_size: int = 1,
+ device_id: Union[str, int]="-1",
+ ):
+
+ if not Path(model_dir).exists():
+ raise FileNotFoundError(f'{model_dir} does not exist.')
- config = read_yaml(config_path)
+ model_file = os.path.join(model_dir, 'model.onnx')
+ config_file = os.path.join(model_dir, 'config.yaml')
+ cmvn_file = os.path.join(model_dir, 'am.mvn')
+ config = read_yaml(config_file)
- self.converter = TokenIDConverter(**config['TokenIDConverter'])
- self.tokenizer = CharTokenizer(**config['CharTokenizer'])
+ self.converter = TokenIDConverter(config['token_list'])
+ self.tokenizer = CharTokenizer()
self.frontend = WavFrontend(
- cmvn_file=config['WavFrontend']['cmvn_file'],
- **config['WavFrontend']['frontend_conf']
+ cmvn_file=cmvn_file,
+ **config['frontend_conf']
)
- self.ort_infer = OrtInferSession(config['Model'])
- self.batch_size = config['Model']['batch_size']
+ self.ort_infer = OrtInferSession(model_file, device_id)
+ self.batch_size = batch_size
def __call__(self, wav_content: Union[str, np.ndarray, List[str]]) -> List:
waveform_list = self.load_data(wav_content)
@@ -124,16 +133,19 @@ class RapidParaformer():
# Change integer-ids to tokens
token = self.converter.ids2tokens(token_int)
- text = self.tokenizer.tokens2text(token)
+ token = token[:valid_token_num-1]
+ texts = sentence_postprocess(token)
+ text = texts[0]
+ # text = self.tokenizer.tokens2text(token)
return text[:valid_token_num-1]
if __name__ == '__main__':
project_dir = Path(__file__).resolve().parent.parent
- cfg_path = project_dir / 'resources' / 'config.yaml'
- paraformer = RapidParaformer(cfg_path)
+ model_dir = "/home/zhifu.gzf/.cache/modelscope/hub/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
+ model = Paraformer(model_dir)
+
+ wav_file = os.path.join(model_dir, 'example/asr_example.wav')
+ result = model(wav_file)
+ print(result)
- wav_file = '0478_00017.wav'
- for i in range(1000):
- result = paraformer(wav_file)
- print(result)
diff --git a/funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/postprocess_utils.py b/funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/postprocess_utils.py
new file mode 100644
index 000000000..575fb90dd
--- /dev/null
+++ b/funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/postprocess_utils.py
@@ -0,0 +1,240 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import string
+import logging
+from typing import Any, List, Union
+
+
+def isChinese(ch: str):
+    """Return True when ``ch`` compares inside one of two character ranges.
+
+    The accepted ranges are U+4E00..U+9FFF and the ASCII digits
+    ``'0'``..``'9'``. Plain string comparison is used, so digits count as
+    "Chinese" here and a longer string is compared lexicographically.
+    """
+    in_cjk_range = '\u4e00' <= ch <= '\u9fff'
+    in_digit_range = '\u0030' <= ch <= '\u0039'
+    return in_cjk_range or in_digit_range
+
+
+def isAllChinese(word: Union[List[Any], str]):
+    """Return True when every item of ``word`` passes ``isChinese``.
+
+    Spaces are removed from each item before the check. An empty ``word``
+    yields False.
+    """
+    # NOTE(review): the two replace('', '') calls are no-ops as written; they
+    # look like special-token literals that were lost -- confirm upstream.
+    cleaned = [
+        item.replace(' ', '').replace('', '').replace('', '')
+        for item in word
+    ]
+    if not cleaned:
+        return False
+    return all(isChinese(ch) for ch in cleaned)
+
+
+def isAllAlpha(word: Union[List[Any], str]):
+    """Return True when every item of ``word`` is alphabetic and not Chinese.
+
+    Spaces are removed from each item first. An item equal to ``"'"`` is
+    tolerated; any other non-alphabetic item, or an alphabetic item accepted
+    by ``isChinese``, makes the result False. An empty ``word`` yields False.
+    """
+    # NOTE(review): the two replace('', '') calls are no-ops as written; they
+    # look like special-token literals that were lost -- confirm upstream.
+    cleaned = [
+        item.replace(' ', '').replace('', '').replace('', '')
+        for item in word
+    ]
+    if not cleaned:
+        return False
+
+    for ch in cleaned:
+        if ch.isalpha():
+            if isChinese(ch):
+                return False
+        elif ch != "'":
+            return False
+
+    return True
+
+
+# def abbr_dispose(words: List[Any]) -> List[Any]:
+def abbr_dispose(words: List[Any], time_stamp: List[List] = None) -> List[Any]:
+ """Merge spelled-out abbreviations into upper-case letters.
+
+ An abbreviation is a run of single ASCII letters separated by ' ' entries
+ (e.g. 'a', ' ', 'b', ' ', 'c'). Its letters are emitted upper-cased
+ without the separating ' ' entries; all other entries are copied as-is.
+
+ Args:
+ words: word list, with ' ' entries acting as separators.
+ time_stamp: optional list of [begin, end] pairs, one per non-' ' entry
+ of ``words`` (assumed from how ts_nums is indexed -- TODO confirm).
+
+ Returns:
+ ``word_lists`` when ``time_stamp`` is None, otherwise the tuple
+ ``(word_lists, ts_lists)`` where an abbreviation contributes a single
+ [begin, end] pair. Note the return annotation only covers the first case.
+ """
+ words_size = len(words)
+ word_lists = []
+ abbr_begin = []
+ abbr_end = []
+ last_num = -1
+ ts_lists = []
+ ts_nums = []
+ ts_index = 0
+ # Pass 1: record the first and last index of every abbreviation.
+ for num in range(words_size):
+ if num <= last_num:
+ continue
+
+ # bytes.isalpha() only accepts ASCII letters.
+ if len(words[num]) == 1 and words[num].encode('utf-8').isalpha():
+ if num + 1 < words_size and words[
+ num + 1] == ' ' and num + 2 < words_size and len(
+ words[num +
+ 2]) == 1 and words[num +
+ 2].encode('utf-8').isalpha():
+ # found the begin of abbr
+ abbr_begin.append(num)
+ num += 2
+ abbr_end.append(num)
+ # to find the end of abbr
+ while True:
+ num += 1
+ if num < words_size and words[num] == ' ':
+ num += 1
+ if num < words_size and len(
+ words[num]) == 1 and words[num].encode(
+ 'utf-8').isalpha():
+ abbr_end.pop()
+ abbr_end.append(num)
+ last_num = num
+ else:
+ break
+ else:
+ break
+
+ # Pass 2: map every word index to a time_stamp index; ' ' entries do not
+ # advance the index.
+ for num in range(words_size):
+ if words[num] == ' ':
+ ts_nums.append(ts_index)
+ else:
+ ts_nums.append(ts_index)
+ ts_index += 1
+ last_num = -1
+ # Pass 3: rebuild the word list (and time stamps), collapsing abbreviations.
+ for num in range(words_size):
+ if num <= last_num:
+ continue
+
+ if num in abbr_begin:
+ if time_stamp is not None:
+ begin = time_stamp[ts_nums[num]][0]
+ word_lists.append(words[num].upper())
+ num += 1
+ while num < words_size:
+ if num in abbr_end:
+ word_lists.append(words[num].upper())
+ last_num = num
+ break
+ else:
+ if words[num].encode('utf-8').isalpha():
+ word_lists.append(words[num].upper())
+ num += 1
+ # num now points at the last letter of the abbreviation.
+ if time_stamp is not None:
+ end = time_stamp[ts_nums[num]][1]
+ ts_lists.append([begin, end])
+ else:
+ word_lists.append(words[num])
+ if time_stamp is not None and words[num] != ' ':
+ begin = time_stamp[ts_nums[num]][0]
+ end = time_stamp[ts_nums[num]][1]
+ ts_lists.append([begin, end])
+ begin = end
+
+ if time_stamp is not None:
+ return word_lists, ts_lists
+ else:
+ return word_lists
+
+
+def sentence_postprocess(words: List[Any], time_stamp: List[List] = None):
+ """Turn a list of decoded tokens into a sentence and a word list.
+
+ Tokens are first normalised to ``str`` (bytes are decoded as UTF-8) and
+ filtered, then handled by one of three branches: all Chinese, all
+ alphabetic, or mixed. Tokens containing '@@' are glued to the following
+ token (presumably a sub-word continuation marker -- TODO confirm).
+ The result is finally passed through ``abbr_dispose``.
+
+ Args:
+ words: tokens as ``str`` or ``bytes``.
+ time_stamp: optional list of [begin, end] pairs indexed like the
+ filtered token list.
+
+ Returns:
+ ``(sentence, ts_lists, real_word_lists)`` when ``time_stamp`` is given,
+ otherwise ``(sentence, real_word_lists)``.
+
+ Raises:
+ ValueError: in the mixed branch, for a token that is neither Chinese,
+ an '@@' piece, nor alphabetic.
+ """
+ middle_lists = []
+ word_lists = []
+ word_item = ''
+ ts_lists = []
+
+ # wash words lists
+ for i in words:
+ word = ''
+ if isinstance(i, str):
+ word = i
+ else:
+ word = i.decode('utf-8')
+
+ # NOTE(review): as written this only drops empty strings; the three
+ # identical '' literals look like lost special tokens -- confirm.
+ if word in ['', '', '']:
+ continue
+ else:
+ middle_lists.append(word)
+
+ # all chinese characters
+ if isAllChinese(middle_lists):
+ for i, ch in enumerate(middle_lists):
+ word_lists.append(ch.replace(' ', ''))
+ if time_stamp is not None:
+ ts_lists = time_stamp
+
+ # all alpha characters
+ elif isAllAlpha(middle_lists):
+ # ts_flag is True when the next token starts a new word, so its begin
+ # time must be captured.
+ ts_flag = True
+ for i, ch in enumerate(middle_lists):
+ if ts_flag and time_stamp is not None:
+ begin = time_stamp[i][0]
+ end = time_stamp[i][1]
+ word = ''
+ if '@@' in ch:
+ word = ch.replace('@@', '')
+ word_item += word
+ if time_stamp is not None:
+ ts_flag = False
+ end = time_stamp[i][1]
+ else:
+ # Word complete: emit it followed by a ' ' separator entry.
+ word_item += ch
+ word_lists.append(word_item)
+ word_lists.append(' ')
+ word_item = ''
+ if time_stamp is not None:
+ ts_flag = True
+ end = time_stamp[i][1]
+ ts_lists.append([begin, end])
+ begin = end
+
+ # mix characters
+ else:
+ # alpha_blank is True when the last emitted entry is the ' ' separator
+ # that follows an alphabetic word.
+ alpha_blank = False
+ ts_flag = True
+ begin = -1
+ end = -1
+ for i, ch in enumerate(middle_lists):
+ if ts_flag and time_stamp is not None:
+ begin = time_stamp[i][0]
+ end = time_stamp[i][1]
+ word = ''
+ if isAllChinese(ch):
+ # Drop the separator emitted after a preceding alphabetic word.
+ if alpha_blank is True:
+ word_lists.pop()
+ word_lists.append(ch)
+ alpha_blank = False
+ if time_stamp is not None:
+ ts_flag = True
+ ts_lists.append([begin, end])
+ begin = end
+ elif '@@' in ch:
+ word = ch.replace('@@', '')
+ word_item += word
+ alpha_blank = False
+ if time_stamp is not None:
+ ts_flag = False
+ end = time_stamp[i][1]
+ elif isAllAlpha(ch):
+ word_item += ch
+ word_lists.append(word_item)
+ word_lists.append(' ')
+ word_item = ''
+ alpha_blank = True
+ if time_stamp is not None:
+ ts_flag = True
+ end = time_stamp[i][1]
+ ts_lists.append([begin, end])
+ begin = end
+ else:
+ raise ValueError('invalid character: {}'.format(ch))
+
+ if time_stamp is not None:
+ word_lists, ts_lists = abbr_dispose(word_lists, ts_lists)
+ real_word_lists = []
+ for ch in word_lists:
+ if ch != ' ':
+ real_word_lists.append(ch)
+ # With time stamps the sentence is the non-' ' words joined by spaces.
+ sentence = ' '.join(real_word_lists).strip()
+ return sentence, ts_lists, real_word_lists
+ else:
+ word_lists = abbr_dispose(word_lists)
+ real_word_lists = []
+ for ch in word_lists:
+ if ch != ' ':
+ real_word_lists.append(ch)
+ # Without time stamps the ' ' separator entries provide the spacing.
+ sentence = ''.join(word_lists).strip()
+ return sentence, real_word_lists
diff --git a/funasr/runtime/python/onnxruntime/rapid_paraformer/utils.py b/funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/utils.py
similarity index 90%
rename from funasr/runtime/python/onnxruntime/rapid_paraformer/utils.py
rename to funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/utils.py
index 839adb4c4..ea3c0b7f7 100644
--- a/funasr/runtime/python/onnxruntime/rapid_paraformer/utils.py
+++ b/funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/utils.py
@@ -14,6 +14,7 @@ from onnxruntime import (GraphOptimizationLevel, InferenceSession,
from typeguard import check_argument_types
from .kaldifeat import compute_fbank_feats
+import warnings
root_dir = Path(__file__).resolve().parent
@@ -21,24 +22,25 @@ logger_initialized = {}
class TokenIDConverter():
- def __init__(self, token_path: Union[Path, str],
+ def __init__(self, token_list: Union[Path, str],
unk_symbol: str = "",):
check_argument_types()
- self.token_list = self.load_token(token_path)
- self.unk_symbol = unk_symbol
+ # self.token_list = self.load_token(token_path)
+ self.token_list = token_list
+ self.unk_symbol = token_list[-1]
- @staticmethod
- def load_token(file_path: Union[Path, str]) -> List:
- if not Path(file_path).exists():
- raise TokenIDConverterError(f'The {file_path} does not exist.')
-
- with open(str(file_path), 'rb') as f:
- token_list = pickle.load(f)
-
- if len(token_list) != len(set(token_list)):
- raise TokenIDConverterError('The Token exists duplicated symbol.')
- return token_list
+ # @staticmethod
+ # def load_token(file_path: Union[Path, str]) -> List:
+ # if not Path(file_path).exists():
+ # raise TokenIDConverterError(f'The {file_path} does not exist.')
+ #
+ # with open(str(file_path), 'rb') as f:
+ # token_list = pickle.load(f)
+ #
+ # if len(token_list) != len(set(token_list)):
+ # raise TokenIDConverterError('The Token exists duplicated symbol.')
+ # return token_list
def get_num_vocabulary_size(self) -> int:
return len(self.token_list)
@@ -268,31 +270,36 @@ class ONNXRuntimeError(Exception):
class OrtInferSession():
- def __init__(self, config):
+ def __init__(self, model_file, device_id=-1):
sess_opt = SessionOptions()
sess_opt.log_severity_level = 4
sess_opt.enable_cpu_mem_arena = False
sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
cuda_ep = 'CUDAExecutionProvider'
+ cuda_provider_options = {
+ "device_id": device_id,
+ "arena_extend_strategy": "kNextPowerOfTwo",
+ "cudnn_conv_algo_search": "EXHAUSTIVE",
+ "do_copy_in_default_stream": "true",
+ }
cpu_ep = 'CPUExecutionProvider'
cpu_provider_options = {
"arena_extend_strategy": "kSameAsRequested",
}
EP_list = []
- if config['use_cuda'] and get_device() == 'GPU' \
+ if device_id != -1 and get_device() == 'GPU' \
and cuda_ep in get_available_providers():
- EP_list = [(cuda_ep, config[cuda_ep])]
+ EP_list = [(cuda_ep, cuda_provider_options)]
EP_list.append((cpu_ep, cpu_provider_options))
- config['model_path'] = config['model_path']
- self._verify_model(config['model_path'])
- self.session = InferenceSession(config['model_path'],
+ self._verify_model(model_file)
+ self.session = InferenceSession(model_file,
sess_options=sess_opt,
providers=EP_list)
- if config['use_cuda'] and cuda_ep not in self.session.get_providers():
+ if device_id != -1 and cuda_ep not in self.session.get_providers():
warnings.warn(f'{cuda_ep} is not avaiable for current env, the inference part is automatically shifted to be executed under {cpu_ep}.\n'
'Please ensure the installed onnxruntime-gpu version matches your cuda and cudnn version, '
'you can check their relations from the offical web site: '
diff --git a/funasr/runtime/python/onnxruntime/requirements.txt b/funasr/runtime/python/onnxruntime/paraformer/requirements.txt
similarity index 100%
rename from funasr/runtime/python/onnxruntime/requirements.txt
rename to funasr/runtime/python/onnxruntime/paraformer/requirements.txt
diff --git a/funasr/runtime/python/onnxruntime/resources/config.yaml b/funasr/runtime/python/onnxruntime/paraformer/resources/config.yaml
similarity index 97%
rename from funasr/runtime/python/onnxruntime/resources/config.yaml
rename to funasr/runtime/python/onnxruntime/paraformer/resources/config.yaml
index fd243c304..83736a422 100644
--- a/funasr/runtime/python/onnxruntime/resources/config.yaml
+++ b/funasr/runtime/python/onnxruntime/paraformer/resources/config.yaml
@@ -18,6 +18,7 @@ WavFrontend:
lfr_m: 7
lfr_n: 6
filter_length_max: -.inf
+ dither: 0.0
Model:
model_path: resources/models/model.onnx
diff --git a/funasr/runtime/python/onnxruntime/resources/models/am.mvn b/funasr/runtime/python/onnxruntime/paraformer/resources/models/am.mvn
similarity index 100%
rename from funasr/runtime/python/onnxruntime/resources/models/am.mvn
rename to funasr/runtime/python/onnxruntime/paraformer/resources/models/am.mvn
diff --git a/funasr/runtime/python/onnxruntime/resources/models/token_list.pkl b/funasr/runtime/python/onnxruntime/paraformer/resources/models/token_list.pkl
similarity index 100%
rename from funasr/runtime/python/onnxruntime/resources/models/token_list.pkl
rename to funasr/runtime/python/onnxruntime/paraformer/resources/models/token_list.pkl