mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
bug fix
This commit is contained in:
parent
fcb2102a60
commit
1e5ef6ed9a
@ -1,12 +1,10 @@
|
||||
import os
|
||||
import torch
|
||||
import functools
|
||||
import onnx
|
||||
from onnxconverter_common import float16
|
||||
|
||||
import warnings
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
|
||||
def export(
|
||||
@ -44,14 +42,13 @@ def export(
|
||||
print(f"export_dir: {export_dir}")
|
||||
_torchscripts(m, path=export_dir, device="cuda")
|
||||
|
||||
|
||||
elif type=='onnx_fp16':
|
||||
elif type == "onnx_fp16":
|
||||
assert (
|
||||
torch.cuda.is_available()
|
||||
), "Currently onnx_fp16 optimization for FunASR only supports GPU"
|
||||
), "Currently onnx_fp16 optimization for FunASR only supports GPU"
|
||||
|
||||
if hasattr(m, "encoder") and hasattr(m, "decoder"):
|
||||
_onnx_opt_for_encdec(m, path=export_dir, enable_fp16=True)
|
||||
_onnx_opt_for_encdec(m, path=export_dir, enable_fp16=True)
|
||||
|
||||
return export_dir
|
||||
|
||||
@ -73,7 +70,6 @@ def _onnx(
|
||||
else:
|
||||
dummy_input = tuple([input.to(device) for input in dummy_input])
|
||||
|
||||
|
||||
verbose = kwargs.get("verbose", False)
|
||||
|
||||
if isinstance(model.export_name, str):
|
||||
@ -94,8 +90,13 @@ def _onnx(
|
||||
)
|
||||
|
||||
if quantize:
|
||||
from onnxruntime.quantization import QuantType, quantize_dynamic
|
||||
import onnx
|
||||
try:
|
||||
from onnxruntime.quantization import QuantType, quantize_dynamic
|
||||
import onnx
|
||||
except:
|
||||
raise RuntimeError(
|
||||
"You are quantizing the onnx model, please install onnxruntime first. via \n`pip install onnx`\n`pip install onnxruntime`."
|
||||
)
|
||||
|
||||
quant_model_path = model_path.replace(".onnx", "_quant.onnx")
|
||||
onnx_model = onnx.load(model_path)
|
||||
@ -117,19 +118,21 @@ def _onnx(
|
||||
|
||||
def _torchscripts(model, path, device="cuda"):
|
||||
dummy_input = model.export_dummy_inputs()
|
||||
|
||||
|
||||
if device == "cuda":
|
||||
model = model.cuda()
|
||||
if isinstance(dummy_input, torch.Tensor):
|
||||
dummy_input = dummy_input.cuda()
|
||||
else:
|
||||
dummy_input = tuple([i.cuda() for i in dummy_input])
|
||||
|
||||
|
||||
model_script = torch.jit.trace(model, dummy_input)
|
||||
if isinstance(model.export_name, str):
|
||||
model_script.save(os.path.join(path, f"{model.export_name}".replace("onnx", "torchscript")))
|
||||
else:
|
||||
model_script.save(os.path.join(path, f"{model.export_name()}".replace("onnx", "torchscript")))
|
||||
model_script.save(
|
||||
os.path.join(path, f"{model.export_name()}".replace("onnx", "torchscript"))
|
||||
)
|
||||
|
||||
|
||||
def _bladedisc_opt(model, model_inputs, enable_fp16=True):
|
||||
@ -225,7 +228,6 @@ def _bladedisc_opt_for_encdec(model, path, enable_fp16):
|
||||
model_script.save(os.path.join(path, f"{model.export_name}_blade.torchscript"))
|
||||
|
||||
|
||||
|
||||
def _onnx_opt_for_encdec(model, path, enable_fp16):
|
||||
|
||||
# Get input data
|
||||
@ -267,16 +269,19 @@ def _onnx_opt_for_encdec(model, path, enable_fp16):
|
||||
input_names=model.export_input_names(),
|
||||
output_names=model.export_output_names(),
|
||||
dynamic_axes=model.export_dynamic_axes(),
|
||||
)
|
||||
|
||||
)
|
||||
|
||||
# fp32 to fp16
|
||||
fp16_model_path = f"{path}/{model.export_name}_hook_fp16.onnx"
|
||||
print("*" * 50)
|
||||
print(f"[_onnx_opt_for_encdec(fp16)]: {fp16_model_path}\n\n")
|
||||
if os.path.exists(fp32_model_path) and not os.path.exists(fp16_model_path):
|
||||
try:
|
||||
from onnxconverter_common import float16
|
||||
except:
|
||||
raise RuntimeError(
|
||||
"You are converting the onnx model to fp16, please install onnxconverter-common first. via `pip install onnxconverter-common`."
|
||||
)
|
||||
fp32_onnx_model = onnx.load(fp32_model_path)
|
||||
fp16_onnx_model = float16.convert_float_to_float16(fp32_onnx_model, keep_io_types=True)
|
||||
onnx.save(
|
||||
fp16_onnx_model, fp16_model_path
|
||||
)
|
||||
onnx.save(fp16_onnx_model, fp16_model_path)
|
||||
|
||||
@ -10,7 +10,6 @@ import torchaudio
|
||||
import time
|
||||
import logging
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from pydub import AudioSegment
|
||||
|
||||
try:
|
||||
from funasr.download.file import download_from_url
|
||||
@ -20,6 +19,11 @@ import pdb
|
||||
import subprocess
|
||||
from subprocess import CalledProcessError, run
|
||||
|
||||
try:
|
||||
from pydub import AudioSegment
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
def is_ffmpeg_installed():
|
||||
try:
|
||||
@ -166,7 +170,12 @@ def validate_frame_rate(
|
||||
byte_data = BytesIO(input)
|
||||
|
||||
# 使用 pydub 加载音频
|
||||
audio = AudioSegment.from_file(byte_data)
|
||||
try:
|
||||
audio = AudioSegment.from_file(byte_data)
|
||||
except:
|
||||
raise RuntimeError(
|
||||
"You are decoding the pcm data, please install pydub first. via `pip install pydub`."
|
||||
)
|
||||
|
||||
# 确保采样率为 16000 Hz
|
||||
if audio.frame_rate != fs:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user