diff --git a/funasr/utils/export_utils.py b/funasr/utils/export_utils.py index 0691ef2d5..ca04d75ce 100644 --- a/funasr/utils/export_utils.py +++ b/funasr/utils/export_utils.py @@ -1,12 +1,10 @@ import os import torch import functools -import onnx -from onnxconverter_common import float16 import warnings -warnings.filterwarnings("ignore") +warnings.filterwarnings("ignore") def export( @@ -44,14 +42,13 @@ def export( print(f"export_dir: {export_dir}") _torchscripts(m, path=export_dir, device="cuda") - - elif type=='onnx_fp16': + elif type == "onnx_fp16": assert ( torch.cuda.is_available() - ), "Currently onnx_fp16 optimization for FunASR only supports GPU" + ), "Currently onnx_fp16 optimization for FunASR only supports GPU" if hasattr(m, "encoder") and hasattr(m, "decoder"): - _onnx_opt_for_encdec(m, path=export_dir, enable_fp16=True) + _onnx_opt_for_encdec(m, path=export_dir, enable_fp16=True) return export_dir @@ -73,7 +70,6 @@ def _onnx( else: dummy_input = tuple([input.to(device) for input in dummy_input]) - verbose = kwargs.get("verbose", False) if isinstance(model.export_name, str): @@ -94,8 +90,13 @@ def _onnx( ) if quantize: - from onnxruntime.quantization import QuantType, quantize_dynamic - import onnx + try: + from onnxruntime.quantization import QuantType, quantize_dynamic + import onnx + except: + raise RuntimeError( + "You are quantizing the onnx model, please install onnxruntime first. via \n`pip install onnx`\n`pip install onnxruntime`." + ) quant_model_path = model_path.replace(".onnx", "_quant.onnx") onnx_model = onnx.load(model_path) @@ -117,19 +118,21 @@ def _onnx( def _torchscripts(model, path, device="cuda"): dummy_input = model.export_dummy_inputs() - + if device == "cuda": model = model.cuda() if isinstance(dummy_input, torch.Tensor): dummy_input = dummy_input.cuda() else: dummy_input = tuple([i.cuda() for i in dummy_input]) - + model_script = torch.jit.trace(model, dummy_input) if isinstance(model.export_name, str): model_script.save(os.path.join(path, f"{model.export_name}".replace("onnx", "torchscript"))) else: - model_script.save(os.path.join(path, f"{model.export_name()}".replace("onnx", "torchscript"))) + model_script.save( + os.path.join(path, f"{model.export_name()}".replace("onnx", "torchscript")) + ) def _bladedisc_opt(model, model_inputs, enable_fp16=True): @@ -225,7 +228,6 @@ def _bladedisc_opt_for_encdec(model, path, enable_fp16): model_script.save(os.path.join(path, f"{model.export_name}_blade.torchscript")) - def _onnx_opt_for_encdec(model, path, enable_fp16): # Get input data @@ -267,16 +269,19 @@ def _onnx_opt_for_encdec(model, path, enable_fp16): input_names=model.export_input_names(), output_names=model.export_output_names(), dynamic_axes=model.export_dynamic_axes(), - ) - + ) # fp32 to fp16 fp16_model_path = f"{path}/{model.export_name}_hook_fp16.onnx" print("*" * 50) print(f"[_onnx_opt_for_encdec(fp16)]: {fp16_model_path}\n\n") if os.path.exists(fp32_model_path) and not os.path.exists(fp16_model_path): + try: + from onnxconverter_common import float16 + except: + raise RuntimeError( + "You are converting the onnx model to fp16, please install onnxconverter-common first. via `pip install onnxconverter-common`." + ) fp32_onnx_model = onnx.load(fp32_model_path) fp16_onnx_model = float16.convert_float_to_float16(fp32_onnx_model, keep_io_types=True) - onnx.save( - fp16_onnx_model, fp16_model_path - ) + onnx.save(fp16_onnx_model, fp16_model_path) diff --git a/funasr/utils/load_utils.py b/funasr/utils/load_utils.py index 52753c4e2..1d80fcf36 100644 --- a/funasr/utils/load_utils.py +++ b/funasr/utils/load_utils.py @@ -10,7 +10,6 @@ import torchaudio import time import logging from torch.nn.utils.rnn import pad_sequence -from pydub import AudioSegment try: from funasr.download.file import download_from_url @@ -20,6 +19,11 @@ import pdb import subprocess from subprocess import CalledProcessError, run +try: + from pydub import AudioSegment +except: + pass + def is_ffmpeg_installed(): try: @@ -166,7 +170,12 @@ def validate_frame_rate( byte_data = BytesIO(input) # 使用 pydub 加载音频 - audio = AudioSegment.from_file(byte_data) + try: + audio = AudioSegment.from_file(byte_data) + except: + raise RuntimeError( + "You are decoding the pcm data, please install pydub first. via `pip install pydub`." + ) # 确保采样率为 16000 Hz if audio.frame_rate != fs: