From 865ae89f0a713f70dda16859638b25e7350275ec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?= Date: Mon, 13 Feb 2023 17:43:01 +0800 Subject: [PATCH] export model --- fbank.py | 123 +++++++++ funasr/models/frontend/wav_frontend.py | 5 +- funasr/runtime/__init__.py | 0 funasr/runtime/python/__init__.py | 0 funasr/runtime/python/onnxruntime/__init__.py | 0 .../onnxruntime/{ => paraformer}/.gitignore | 0 .../onnxruntime/{ => paraformer}/README.md | 14 +- .../python/onnxruntime/paraformer/__init__.py | 0 .../rapid_paraformer/__init__.py | 0 .../rapid_paraformer/kaldifeat/LICENSE | 0 .../rapid_paraformer/kaldifeat/__init__.py | 0 .../rapid_paraformer/kaldifeat/feature.py | 0 .../rapid_paraformer/kaldifeat/ivector.py | 0 .../rapid_paraformer/paraformer_onnx.py} | 48 ++-- .../rapid_paraformer/postprocess_utils.py | 240 ++++++++++++++++++ .../rapid_paraformer/utils.py | 49 ++-- .../{ => paraformer}/requirements.txt | 0 .../{ => paraformer}/resources/config.yaml | 1 + .../{ => paraformer}/resources/models/am.mvn | 0 .../resources/models/token_list.pkl | Bin 20 files changed, 427 insertions(+), 53 deletions(-) create mode 100644 fbank.py create mode 100644 funasr/runtime/__init__.py create mode 100644 funasr/runtime/python/__init__.py create mode 100644 funasr/runtime/python/onnxruntime/__init__.py rename funasr/runtime/python/onnxruntime/{ => paraformer}/.gitignore (100%) rename funasr/runtime/python/onnxruntime/{ => paraformer}/README.md (82%) create mode 100644 funasr/runtime/python/onnxruntime/paraformer/__init__.py rename funasr/runtime/python/onnxruntime/{ => paraformer}/rapid_paraformer/__init__.py (100%) rename funasr/runtime/python/onnxruntime/{ => paraformer}/rapid_paraformer/kaldifeat/LICENSE (100%) rename funasr/runtime/python/onnxruntime/{ => paraformer}/rapid_paraformer/kaldifeat/__init__.py (100%) rename funasr/runtime/python/onnxruntime/{ => paraformer}/rapid_paraformer/kaldifeat/feature.py (100%) rename funasr/runtime/python/onnxruntime/{ => paraformer}/rapid_paraformer/kaldifeat/ivector.py (100%) rename funasr/runtime/python/onnxruntime/{rapid_paraformer/rapid_paraformer.py => paraformer/rapid_paraformer/paraformer_onnx.py} (76%) create mode 100644 funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/postprocess_utils.py rename funasr/runtime/python/onnxruntime/{ => paraformer}/rapid_paraformer/utils.py (90%) rename funasr/runtime/python/onnxruntime/{ => paraformer}/requirements.txt (100%) rename funasr/runtime/python/onnxruntime/{ => paraformer}/resources/config.yaml (97%) rename funasr/runtime/python/onnxruntime/{ => paraformer}/resources/models/am.mvn (100%) rename funasr/runtime/python/onnxruntime/{ => paraformer}/resources/models/token_list.pkl (100%) diff --git a/fbank.py b/fbank.py new file mode 100644 index 000000000..26daa45f6 --- /dev/null +++ b/fbank.py @@ -0,0 +1,123 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +# Part of the implementation is borrowed from espnet/espnet. + +from typing import Tuple + +import numpy as np +import torch +import torchaudio.compliance.kaldi as kaldi +from funasr.models.frontend.abs_frontend import AbsFrontend +from typeguard import check_argument_types +from torch.nn.utils.rnn import pad_sequence +import kaldi_native_fbank as knf + +class WavFrontend(AbsFrontend): + """Conventional frontend structure for ASR. + """ + + def __init__( + self, + cmvn_file: str = None, + fs: int = 16000, + window: str = 'hamming', + n_mels: int = 80, + frame_length: int = 25, + frame_shift: int = 10, + filter_length_min: int = -1, + filter_length_max: int = -1, + lfr_m: int = 1, + lfr_n: int = 1, + dither: float = 1.0, + snip_edges: bool = True, + upsacle_samples: bool = True, + ): + assert check_argument_types() + super().__init__() + self.fs = fs + self.window = window + self.n_mels = n_mels + self.frame_length = frame_length + self.frame_shift = frame_shift + self.filter_length_min = filter_length_min + self.filter_length_max = filter_length_max + self.lfr_m = lfr_m + self.lfr_n = lfr_n + self.cmvn_file = cmvn_file + self.dither = dither + self.snip_edges = snip_edges + self.upsacle_samples = upsacle_samples + + def output_size(self) -> int: + return self.n_mels * self.lfr_m + + def forward( + self, + input: torch.Tensor, + input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + batch_size = input.size(0) + feats = [] + feats_lens = [] + for i in range(batch_size): + waveform_length = input_lengths[i] + waveform = input[i][:waveform_length] + waveform = waveform * (1 << 15) + waveform = waveform.unsqueeze(0) + mat = kaldi.fbank(waveform, + num_mel_bins=self.n_mels, + frame_length=self.frame_length, + frame_shift=self.frame_shift, + dither=self.dither, + energy_floor=0.0, + window_type=self.window, + sample_frequency=self.fs) + + feat_length = mat.size(0) + feats.append(mat) + feats_lens.append(feat_length) + + feats_lens = torch.as_tensor(feats_lens) + feats_pad = pad_sequence(feats, + batch_first=True, + padding_value=0.0) + return feats_pad, feats_lens + +import kaldi_native_fbank as knf + +def fbank_knf(waveform): + # sampling_rate = 16000 + # samples = torch.randn(16000 * 10) + + opts = knf.FbankOptions() + opts.frame_opts.samp_freq = 16000 + opts.frame_opts.dither = 0.0 + opts.frame_opts.window_type = "hamming" + opts.frame_opts.frame_shift_ms = 10.0 + opts.frame_opts.frame_length_ms = 25.0 + opts.mel_opts.num_bins = 80 + opts.energy_floor = 1 + opts.frame_opts.snip_edges = True + opts.mel_opts.debug_mel = False + + fbank = knf.OnlineFbank(opts) + waveform = waveform * (1 << 15) + fbank.accept_waveform(opts.frame_opts.samp_freq, waveform.tolist()) + frames = fbank.num_frames_ready + mat = np.empty([frames, opts.mel_opts.num_bins]) + for i in range(frames): + mat[i, :] = fbank.get_frame(i) + return mat + +if __name__ == '__main__': + import librosa + + path = "/home/zhifu.gzf/.cache/modelscope/hub/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/example/asr_example.wav" + waveform, fs = librosa.load(path, sr=None) + fbank = fbank_knf(waveform) + frontend = WavFrontend(dither=0.0) + waveform_tensor = torch.from_numpy(waveform)[None, :] + fbank_torch, _ = frontend.forward(waveform_tensor, [waveform_tensor.size(1)]) + fbank_torch = fbank_torch.cpu().numpy()[0, :, :] + diff = fbank - fbank_torch + diff_max = diff.max() + diff_sum = diff.abs().sum() + pass \ No newline at end of file diff --git a/funasr/models/frontend/wav_frontend.py b/funasr/models/frontend/wav_frontend.py index 7a6425be3..ed8cb3646 100644 --- a/funasr/models/frontend/wav_frontend.py +++ b/funasr/models/frontend/wav_frontend.py @@ -171,10 +171,7 @@ class WavFrontend(AbsFrontend): window_type=self.window, sample_frequency=self.fs) - # if self.lfr_m != 1 or self.lfr_n != 1: - # mat = apply_lfr(mat, self.lfr_m, self.lfr_n) - # if self.cmvn_file is not None: - # mat = apply_cmvn(mat, self.cmvn_file) + feat_length = mat.size(0) feats.append(mat) feats_lens.append(feat_length) diff --git a/funasr/runtime/__init__.py b/funasr/runtime/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/funasr/runtime/python/__init__.py b/funasr/runtime/python/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/funasr/runtime/python/onnxruntime/__init__.py b/funasr/runtime/python/onnxruntime/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/funasr/runtime/python/onnxruntime/.gitignore b/funasr/runtime/python/onnxruntime/paraformer/.gitignore similarity index 100% rename from funasr/runtime/python/onnxruntime/.gitignore rename to funasr/runtime/python/onnxruntime/paraformer/.gitignore diff --git a/funasr/runtime/python/onnxruntime/README.md b/funasr/runtime/python/onnxruntime/paraformer/README.md similarity index 82% rename from funasr/runtime/python/onnxruntime/README.md rename to funasr/runtime/python/onnxruntime/paraformer/README.md index ee3ce0a6a..d68600f6b 100644 --- a/funasr/runtime/python/onnxruntime/README.md +++ b/funasr/runtime/python/onnxruntime/paraformer/README.md @@ -29,12 +29,6 @@ │   └── utils.py ├── README.md ├── requirements.txt - ├── resources - │   ├── config.yaml - │   └── models - │   ├── am.mvn - │   ├── model.onnx # Put it here. - │   └── token_list.pkl ├── test_onnx.py ├── tests │   ├── __pycache__ @@ -48,15 +42,15 @@ - Output: `List[str]`: recognition result. - Example: ```python - from rapid_paraformer import RapidParaformer + from paraformer_onnx import Paraformer config_path = 'resources/config.yaml' - paraformer = RapidParaformer(config_path) + model = Paraformer(config_path) - wav_path = ['test_wavs/0478_00017.wav'] + wav_path = ['example/asr_example.wav'] - result = paraformer(wav_path) + result = model(wav_path) print(result) ``` diff --git a/funasr/runtime/python/onnxruntime/paraformer/__init__.py b/funasr/runtime/python/onnxruntime/paraformer/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/funasr/runtime/python/onnxruntime/rapid_paraformer/__init__.py b/funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/__init__.py similarity index 100% rename from funasr/runtime/python/onnxruntime/rapid_paraformer/__init__.py rename to funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/__init__.py diff --git a/funasr/runtime/python/onnxruntime/rapid_paraformer/kaldifeat/LICENSE b/funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/kaldifeat/LICENSE similarity index 100% rename from funasr/runtime/python/onnxruntime/rapid_paraformer/kaldifeat/LICENSE rename to funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/kaldifeat/LICENSE diff --git a/funasr/runtime/python/onnxruntime/rapid_paraformer/kaldifeat/__init__.py b/funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/kaldifeat/__init__.py similarity index 100% rename from funasr/runtime/python/onnxruntime/rapid_paraformer/kaldifeat/__init__.py rename to funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/kaldifeat/__init__.py diff --git a/funasr/runtime/python/onnxruntime/rapid_paraformer/kaldifeat/feature.py b/funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/kaldifeat/feature.py similarity index 100% rename from funasr/runtime/python/onnxruntime/rapid_paraformer/kaldifeat/feature.py rename to funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/kaldifeat/feature.py diff --git a/funasr/runtime/python/onnxruntime/rapid_paraformer/kaldifeat/ivector.py b/funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/kaldifeat/ivector.py similarity index 100% rename from funasr/runtime/python/onnxruntime/rapid_paraformer/kaldifeat/ivector.py rename to funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/kaldifeat/ivector.py diff --git a/funasr/runtime/python/onnxruntime/rapid_paraformer/rapid_paraformer.py b/funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/paraformer_onnx.py similarity index 76% rename from funasr/runtime/python/onnxruntime/rapid_paraformer/rapid_paraformer.py rename to funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/paraformer_onnx.py index 10bfa8ae4..1fc3582ce 100644 --- a/funasr/runtime/python/onnxruntime/rapid_paraformer/rapid_paraformer.py +++ b/funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/paraformer_onnx.py @@ -1,6 +1,7 @@ # -*- encoding: utf-8 -*- # @Author: SWHL # @Contact: liekkaskono@163.com +import os.path import traceback from pathlib import Path from typing import List, Union, Tuple @@ -11,25 +12,33 @@ import numpy as np from .utils import (CharTokenizer, Hypothesis, ONNXRuntimeError, OrtInferSession, TokenIDConverter, WavFrontend, get_logger, read_yaml) +from .postprocess_utils import sentence_postprocess logging = get_logger() -class RapidParaformer(): - def __init__(self, config_path: Union[str, Path]) -> None: - if not Path(config_path).exists(): - raise FileNotFoundError(f'{config_path} does not exist.') +class Paraformer(): + def __init__(self, model_dir: Union[str, Path]=None, + batch_size: int = 1, + device_id: Union[str, int]="-1", + ): + + if not Path(model_dir).exists(): + raise FileNotFoundError(f'{model_dir} does not exist.') - config = read_yaml(config_path) + model_file = os.path.join(model_dir, 'model.onnx') + config_file = os.path.join(model_dir, 'config.yaml') + cmvn_file = os.path.join(model_dir, 'am.mvn') + config = read_yaml(config_file) - self.converter = TokenIDConverter(**config['TokenIDConverter']) - self.tokenizer = CharTokenizer(**config['CharTokenizer']) + self.converter = TokenIDConverter(config['token_list']) + self.tokenizer = CharTokenizer() self.frontend = WavFrontend( - cmvn_file=config['WavFrontend']['cmvn_file'], - **config['WavFrontend']['frontend_conf'] + cmvn_file=cmvn_file, + **config['frontend_conf'] ) - self.ort_infer = OrtInferSession(config['Model']) - self.batch_size = config['Model']['batch_size'] + self.ort_infer = OrtInferSession(model_file, device_id) + self.batch_size = batch_size def __call__(self, wav_content: Union[str, np.ndarray, List[str]]) -> List: waveform_list = self.load_data(wav_content) @@ -124,16 +133,19 @@ class RapidParaformer(): # Change integer-ids to tokens token = self.converter.ids2tokens(token_int) - text = self.tokenizer.tokens2text(token) + token = token[:valid_token_num-1] + texts = sentence_postprocess(token) + text = texts[0] + # text = self.tokenizer.tokens2text(token) return text[:valid_token_num-1] if __name__ == '__main__': project_dir = Path(__file__).resolve().parent.parent - cfg_path = project_dir / 'resources' / 'config.yaml' - paraformer = RapidParaformer(cfg_path) + model_dir = "/home/zhifu.gzf/.cache/modelscope/hub/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" + model = Paraformer(model_dir) + + wav_file = os.path.join(model_dir, 'example/asr_example.wav') + result = model(wav_file) + print(result) - wav_file = '0478_00017.wav' - for i in range(1000): - result = paraformer(wav_file) - print(result) diff --git a/funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/postprocess_utils.py b/funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/postprocess_utils.py new file mode 100644 index 000000000..575fb90dd --- /dev/null +++ b/funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/postprocess_utils.py @@ -0,0 +1,240 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import string +import logging +from typing import Any, List, Union + + +def isChinese(ch: str): + if '\u4e00' <= ch <= '\u9fff' or '\u0030' <= ch <= '\u0039': + return True + return False + + +def isAllChinese(word: Union[List[Any], str]): + word_lists = [] + for i in word: + cur = i.replace(' ', '') + cur = cur.replace('', '') + cur = cur.replace('', '') + word_lists.append(cur) + + if len(word_lists) == 0: + return False + + for ch in word_lists: + if isChinese(ch) is False: + return False + return True + + +def isAllAlpha(word: Union[List[Any], str]): + word_lists = [] + for i in word: + cur = i.replace(' ', '') + cur = cur.replace('', '') + cur = cur.replace('', '') + word_lists.append(cur) + + if len(word_lists) == 0: + return False + + for ch in word_lists: + if ch.isalpha() is False and ch != "'": + return False + elif ch.isalpha() is True and isChinese(ch) is True: + return False + + return True + + +# def abbr_dispose(words: List[Any]) -> List[Any]: +def abbr_dispose(words: List[Any], time_stamp: List[List] = None) -> List[Any]: + words_size = len(words) + word_lists = [] + abbr_begin = [] + abbr_end = [] + last_num = -1 + ts_lists = [] + ts_nums = [] + ts_index = 0 + for num in range(words_size): + if num <= last_num: + continue + + if len(words[num]) == 1 and words[num].encode('utf-8').isalpha(): + if num + 1 < words_size and words[ + num + 1] == ' ' and num + 2 < words_size and len( + words[num + + 2]) == 1 and words[num + + 2].encode('utf-8').isalpha(): + # found the begin of abbr + abbr_begin.append(num) + num += 2 + abbr_end.append(num) + # to find the end of abbr + while True: + num += 1 + if num < words_size and words[num] == ' ': + num += 1 + if num < words_size and len( + words[num]) == 1 and words[num].encode( + 'utf-8').isalpha(): + abbr_end.pop() + abbr_end.append(num) + last_num = num + else: + break + else: + break + + for num in range(words_size): + if words[num] == ' ': + ts_nums.append(ts_index) + else: + ts_nums.append(ts_index) + ts_index += 1 + last_num = -1 + for num in range(words_size): + if num <= last_num: + continue + + if num in abbr_begin: + if time_stamp is not None: + begin = time_stamp[ts_nums[num]][0] + word_lists.append(words[num].upper()) + num += 1 + while num < words_size: + if num in abbr_end: + word_lists.append(words[num].upper()) + last_num = num + break + else: + if words[num].encode('utf-8').isalpha(): + word_lists.append(words[num].upper()) + num += 1 + if time_stamp is not None: + end = time_stamp[ts_nums[num]][1] + ts_lists.append([begin, end]) + else: + word_lists.append(words[num]) + if time_stamp is not None and words[num] != ' ': + begin = time_stamp[ts_nums[num]][0] + end = time_stamp[ts_nums[num]][1] + ts_lists.append([begin, end]) + begin = end + + if time_stamp is not None: + return word_lists, ts_lists + else: + return word_lists + + +def sentence_postprocess(words: List[Any], time_stamp: List[List] = None): + middle_lists = [] + word_lists = [] + word_item = '' + ts_lists = [] + + # wash words lists + for i in words: + word = '' + if isinstance(i, str): + word = i + else: + word = i.decode('utf-8') + + if word in ['', '', '']: + continue + else: + middle_lists.append(word) + + # all chinese characters + if isAllChinese(middle_lists): + for i, ch in enumerate(middle_lists): + word_lists.append(ch.replace(' ', '')) + if time_stamp is not None: + ts_lists = time_stamp + + # all alpha characters + elif isAllAlpha(middle_lists): + ts_flag = True + for i, ch in enumerate(middle_lists): + if ts_flag and time_stamp is not None: + begin = time_stamp[i][0] + end = time_stamp[i][1] + word = '' + if '@@' in ch: + word = ch.replace('@@', '') + word_item += word + if time_stamp is not None: + ts_flag = False + end = time_stamp[i][1] + else: + word_item += ch + word_lists.append(word_item) + word_lists.append(' ') + word_item = '' + if time_stamp is not None: + ts_flag = True + end = time_stamp[i][1] + ts_lists.append([begin, end]) + begin = end + + # mix characters + else: + alpha_blank = False + ts_flag = True + begin = -1 + end = -1 + for i, ch in enumerate(middle_lists): + if ts_flag and time_stamp is not None: + begin = time_stamp[i][0] + end = time_stamp[i][1] + word = '' + if isAllChinese(ch): + if alpha_blank is True: + word_lists.pop() + word_lists.append(ch) + alpha_blank = False + if time_stamp is not None: + ts_flag = True + ts_lists.append([begin, end]) + begin = end + elif '@@' in ch: + word = ch.replace('@@', '') + word_item += word + alpha_blank = False + if time_stamp is not None: + ts_flag = False + end = time_stamp[i][1] + elif isAllAlpha(ch): + word_item += ch + word_lists.append(word_item) + word_lists.append(' ') + word_item = '' + alpha_blank = True + if time_stamp is not None: + ts_flag = True + end = time_stamp[i][1] + ts_lists.append([begin, end]) + begin = end + else: + raise ValueError('invalid character: {}'.format(ch)) + + if time_stamp is not None: + word_lists, ts_lists = abbr_dispose(word_lists, ts_lists) + real_word_lists = [] + for ch in word_lists: + if ch != ' ': + real_word_lists.append(ch) + sentence = ' '.join(real_word_lists).strip() + return sentence, ts_lists, real_word_lists + else: + word_lists = abbr_dispose(word_lists) + real_word_lists = [] + for ch in word_lists: + if ch != ' ': + real_word_lists.append(ch) + sentence = ''.join(word_lists).strip() + return sentence, real_word_lists diff --git a/funasr/runtime/python/onnxruntime/rapid_paraformer/utils.py b/funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/utils.py similarity index 90% rename from funasr/runtime/python/onnxruntime/rapid_paraformer/utils.py rename to funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/utils.py index 839adb4c4..ea3c0b7f7 100644 --- a/funasr/runtime/python/onnxruntime/rapid_paraformer/utils.py +++ b/funasr/runtime/python/onnxruntime/paraformer/rapid_paraformer/utils.py @@ -14,6 +14,7 @@ from onnxruntime import (GraphOptimizationLevel, InferenceSession, from typeguard import check_argument_types from .kaldifeat import compute_fbank_feats +import warnings root_dir = Path(__file__).resolve().parent @@ -21,24 +22,25 @@ logger_initialized = {} class TokenIDConverter(): - def __init__(self, token_path: Union[Path, str], + def __init__(self, token_list: Union[Path, str], unk_symbol: str = "",): check_argument_types() - self.token_list = self.load_token(token_path) - self.unk_symbol = unk_symbol + # self.token_list = self.load_token(token_path) + self.token_list = token_list + self.unk_symbol = token_list[-1] - @staticmethod - def load_token(file_path: Union[Path, str]) -> List: - if not Path(file_path).exists(): - raise TokenIDConverterError(f'The {file_path} does not exist.') - - with open(str(file_path), 'rb') as f: - token_list = pickle.load(f) - - if len(token_list) != len(set(token_list)): - raise TokenIDConverterError('The Token exists duplicated symbol.') - return token_list + # @staticmethod + # def load_token(file_path: Union[Path, str]) -> List: + # if not Path(file_path).exists(): + # raise TokenIDConverterError(f'The {file_path} does not exist.') + # + # with open(str(file_path), 'rb') as f: + # token_list = pickle.load(f) + # + # if len(token_list) != len(set(token_list)): + # raise TokenIDConverterError('The Token exists duplicated symbol.') + # return token_list def get_num_vocabulary_size(self) -> int: return len(self.token_list) @@ -268,31 +270,36 @@ class ONNXRuntimeError(Exception): class OrtInferSession(): - def __init__(self, config): + def __init__(self, model_file, device_id=-1): sess_opt = SessionOptions() sess_opt.log_severity_level = 4 sess_opt.enable_cpu_mem_arena = False sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL cuda_ep = 'CUDAExecutionProvider' + cuda_provider_options = { + "device_id": device_id, + "arena_extend_strategy": "kNextPowerOfTwo", + "cudnn_conv_algo_search": "EXHAUSTIVE", + "do_copy_in_default_stream": "true", + } cpu_ep = 'CPUExecutionProvider' cpu_provider_options = { "arena_extend_strategy": "kSameAsRequested", } EP_list = [] - if config['use_cuda'] and get_device() == 'GPU' \ + if device_id != -1 and get_device() == 'GPU' \ and cuda_ep in get_available_providers(): - EP_list = [(cuda_ep, config[cuda_ep])] + EP_list = [(cuda_ep, cuda_provider_options)] EP_list.append((cpu_ep, cpu_provider_options)) - config['model_path'] = config['model_path'] - self._verify_model(config['model_path']) - self.session = InferenceSession(config['model_path'], + self._verify_model(model_file) + self.session = InferenceSession(model_file, sess_options=sess_opt, providers=EP_list) - if config['use_cuda'] and cuda_ep not in self.session.get_providers(): + if device_id != -1 and cuda_ep not in self.session.get_providers(): warnings.warn(f'{cuda_ep} is not avaiable for current env, the inference part is automatically shifted to be executed under {cpu_ep}.\n' 'Please ensure the installed onnxruntime-gpu version matches your cuda and cudnn version, ' 'you can check their relations from the offical web site: ' diff --git a/funasr/runtime/python/onnxruntime/requirements.txt b/funasr/runtime/python/onnxruntime/paraformer/requirements.txt similarity index 100% rename from funasr/runtime/python/onnxruntime/requirements.txt rename to funasr/runtime/python/onnxruntime/paraformer/requirements.txt diff --git a/funasr/runtime/python/onnxruntime/resources/config.yaml b/funasr/runtime/python/onnxruntime/paraformer/resources/config.yaml similarity index 97% rename from funasr/runtime/python/onnxruntime/resources/config.yaml rename to funasr/runtime/python/onnxruntime/paraformer/resources/config.yaml index fd243c304..83736a422 100644 --- a/funasr/runtime/python/onnxruntime/resources/config.yaml +++ b/funasr/runtime/python/onnxruntime/paraformer/resources/config.yaml @@ -18,6 +18,7 @@ WavFrontend: lfr_m: 7 lfr_n: 6 filter_length_max: -.inf + dither: 0.0 Model: model_path: resources/models/model.onnx diff --git a/funasr/runtime/python/onnxruntime/resources/models/am.mvn b/funasr/runtime/python/onnxruntime/paraformer/resources/models/am.mvn similarity index 100% rename from funasr/runtime/python/onnxruntime/resources/models/am.mvn rename to funasr/runtime/python/onnxruntime/paraformer/resources/models/am.mvn diff --git a/funasr/runtime/python/onnxruntime/resources/models/token_list.pkl b/funasr/runtime/python/onnxruntime/paraformer/resources/models/token_list.pkl similarity index 100% rename from funasr/runtime/python/onnxruntime/resources/models/token_list.pkl rename to funasr/runtime/python/onnxruntime/paraformer/resources/models/token_list.pkl