diff --git a/.gitignore b/.gitignore index 7a434073f..825837795 100644 --- a/.gitignore +++ b/.gitignore @@ -6,4 +6,5 @@ .DS_Store init_model/ *.tar.gz -test_local/ \ No newline at end of file +test_local/ +RapidASR \ No newline at end of file diff --git a/funasr/export/export_model.py b/funasr/export/export_model.py index 933a9271c..8d41462b0 100644 --- a/funasr/export/export_model.py +++ b/funasr/export/export_model.py @@ -7,11 +7,13 @@ import os import logging import torch -from funasr.bin.asr_inference_paraformer import Speech2Text from funasr.export.models import get_model import numpy as np import random +torch_version = float(".".join(torch.__version__.split(".")[:2])) +assert torch_version > 1.9 + class ASRModelExportParaformer: def __init__(self, cache_dir: Union[Path, str] = None, onnx: bool = True): assert check_argument_types() @@ -30,7 +32,7 @@ class ASRModelExportParaformer: def _export( self, - model: Speech2Text, + model, tag_name: str = None, verbose: bool = False, ): @@ -118,110 +120,6 @@ class ASRModelExportParaformer: ) -class ASRModelExport: - def __init__(self, cache_dir: Union[Path, str] = None, onnx: bool = True): - assert check_argument_types() - self.set_all_random_seed(0) - if cache_dir is None: - cache_dir = Path.home() / ".cache" / "export" - - self.cache_dir = Path(cache_dir) - self.export_config = dict( - feats_dim=560, - onnx=False, - ) - print("output dir: {}".format(self.cache_dir)) - self.onnx = onnx - - def _export( - self, - model: Speech2Text, - tag_name: str = None, - verbose: bool = False, - ): - - export_dir = self.cache_dir / tag_name.replace(' ', '-') - os.makedirs(export_dir, exist_ok=True) - - # export encoder1 - self.export_config["model_name"] = "model" - model = get_model( - model, - self.export_config, - ) - model.eval() - # self._export_onnx(model, verbose, export_dir) - if self.onnx: - self._export_onnx(model, verbose, export_dir) - else: - self._export_torchscripts(model, verbose, export_dir) - - print("output dir: {}".format(export_dir)) - - def _export_torchscripts(self, model, verbose, path, enc_size=None): - if enc_size: - dummy_input = model.get_dummy_inputs(enc_size) - else: - dummy_input = model.get_dummy_inputs_txt() - - # model_script = torch.jit.script(model) - model_script = torch.jit.trace(model, dummy_input) - model_script.save(os.path.join(path, f'{model.model_name}.torchscripts')) - - def set_all_random_seed(self, seed: int): - random.seed(seed) - np.random.seed(seed) - torch.random.manual_seed(seed) - - def export(self, - tag_name: str = 'damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch', - mode: str = 'paraformer', - ): - - model_dir = tag_name - if model_dir.startswith('damo/'): - from modelscope.hub.snapshot_download import snapshot_download - model_dir = snapshot_download(model_dir, cache_dir=self.cache_dir) - asr_train_config = os.path.join(model_dir, 'config.yaml') - asr_model_file = os.path.join(model_dir, 'model.pb') - cmvn_file = os.path.join(model_dir, 'am.mvn') - json_file = os.path.join(model_dir, 'configuration.json') - if mode is None: - import json - with open(json_file, 'r') as f: - config_data = json.load(f) - mode = config_data['model']['model_config']['mode'] - if mode.startswith('paraformer'): - from funasr.tasks.asr import ASRTaskParaformer as ASRTask - elif mode.startswith('uniasr'): - from funasr.tasks.asr import ASRTaskUniASR as ASRTask - - model, asr_train_args = ASRTask.build_model_from_file( - asr_train_config, asr_model_file, cmvn_file, 'cpu' - ) - self._export(model, tag_name) - - def _export_onnx(self, model, verbose, path, enc_size=None): - if enc_size: - dummy_input = model.get_dummy_inputs(enc_size) - else: - dummy_input = model.get_dummy_inputs() - - # model_script = torch.jit.script(model) - model_script = model # torch.jit.trace(model) - - torch.onnx.export( - model_script, - dummy_input, - os.path.join(path, f'{model.model_name}.onnx'), - verbose=verbose, - opset_version=12, - input_names=model.get_input_names(), - output_names=model.get_output_names(), - dynamic_axes=model.get_dynamic_axes() - ) - - if __name__ == '__main__': import sys diff --git a/funasr/export/models/predictor/cif.py b/funasr/export/models/predictor/cif.py index c8df7f381..cb26862bf 100644 --- a/funasr/export/models/predictor/cif.py +++ b/funasr/export/models/predictor/cif.py @@ -77,6 +77,53 @@ class CifPredictorV2(nn.Module): return hidden, alphas, token_num_floor +# @torch.jit.script +# def cif(hidden, alphas, threshold: float): +# batch_size, len_time, hidden_size = hidden.size() +# threshold = torch.tensor([threshold], dtype=alphas.dtype).to(alphas.device) +# +# # loop varss +# integrate = torch.zeros([batch_size], device=hidden.device) +# frame = torch.zeros([batch_size, hidden_size], device=hidden.device) +# # intermediate vars along time +# list_fires = [] +# list_frames = [] +# +# for t in range(len_time): +# alpha = alphas[:, t] +# distribution_completion = torch.ones([batch_size], device=hidden.device) - integrate +# +# integrate += alpha +# list_fires.append(integrate) +# +# fire_place = integrate >= threshold +# integrate = torch.where(fire_place, +# integrate - torch.ones([batch_size], device=hidden.device), +# integrate) +# cur = torch.where(fire_place, +# distribution_completion, +# alpha) +# remainds = alpha - cur +# +# frame += cur[:, None] * hidden[:, t, :] +# list_frames.append(frame) +# frame = torch.where(fire_place[:, None].repeat(1, hidden_size), +# remainds[:, None] * hidden[:, t, :], +# frame) +# +# fires = torch.stack(list_fires, 1) +# frames = torch.stack(list_frames, 1) +# list_ls = [] +# len_labels = torch.floor(alphas.sum(-1)).int() +# max_label_len = len_labels.max() +# for b in range(batch_size): +# fire = fires[b, :] +# l = torch.index_select(frames[b, :, :], 0, torch.nonzero(fire >= threshold).squeeze()) +# pad_l = torch.zeros([int(max_label_len - l.size(0)), int(hidden_size)], device=hidden.device) +# list_ls.append(torch.cat([l, pad_l], 0)) +# return torch.stack(list_ls, 0), fires + + @torch.jit.script def cif(hidden, alphas, threshold: float): batch_size, len_time, hidden_size = hidden.size() @@ -113,15 +160,11 @@ def cif(hidden, alphas, threshold: float): fires = torch.stack(list_fires, 1) frames = torch.stack(list_frames, 1) - # list_ls = [] - len_labels = torch.round(alphas.sum(-1)).type(torch.int32) - # max_label_len = int(torch.max(len_labels).item()) - # print("type: {}".format(type(max_label_len))) + fire_idxs = fires >= threshold frame_fires = torch.zeros_like(hidden) max_label_len = frames[0, fire_idxs[0]].size(0) for b in range(batch_size): - # fire = fires[b, :] frame_fire = frames[b, fire_idxs[b]] frame_len = frame_fire.size(0) frame_fires[b, :frame_len, :] = frame_fire diff --git a/funasr/runtime/python/onnxruntime/rapid_paraformer/utils/utils.py b/funasr/runtime/python/onnxruntime/rapid_paraformer/utils/utils.py index 8e220e098..7943abbf9 100644 --- a/funasr/runtime/python/onnxruntime/rapid_paraformer/utils/utils.py +++ b/funasr/runtime/python/onnxruntime/rapid_paraformer/utils/utils.py @@ -148,6 +148,7 @@ class ONNXRuntimeError(Exception): class OrtInferSession(): def __init__(self, model_file, device_id=-1): + device_id = str(device_id) sess_opt = SessionOptions() sess_opt.log_severity_level = 4 sess_opt.enable_cpu_mem_arena = False @@ -166,7 +167,7 @@ class OrtInferSession(): } EP_list = [] - if device_id != -1 and get_device() == 'GPU' \ + if device_id != "-1" and get_device() == 'GPU' \ and cuda_ep in get_available_providers(): EP_list = [(cuda_ep, cuda_provider_options)] EP_list.append((cpu_ep, cpu_provider_options)) @@ -176,7 +177,7 @@ class OrtInferSession(): sess_options=sess_opt, providers=EP_list) - if device_id != -1 and cuda_ep not in self.session.get_providers(): + if device_id != "-1" and cuda_ep not in self.session.get_providers(): warnings.warn(f'{cuda_ep} is not avaiable for current env, the inference part is automatically shifted to be executed under {cpu_ep}.\n' 'Please ensure the installed onnxruntime-gpu version matches your cuda and cudnn version, ' 'you can check their relations from the offical web site: ' diff --git a/scan.py b/scan.py deleted file mode 100644 index e69de29bb..000000000