Mirror of https://github.com/FunAudioLLM/SenseVoice.git (synced 2025-09-15 15:08:35 +08:00)
Commit 0858308f36 "ONNX support"
Parent: ace1d34b69
demo1.py (6 changed lines)
@@ -1,9 +1,3 @@
-#!/usr/bin/env python3
-# -*- encoding: utf-8 -*-
-# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
-# MIT License (https://opensource.org/licenses/MIT)
-
-
 from funasr import AutoModel
 from funasr.utils.postprocess_utils import rich_transcription_postprocess
demo2.py (5 changed lines)
@@ -1,8 +1,3 @@
-#!/usr/bin/env python3
-# -*- encoding: utf-8 -*-
-# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
-# MIT License (https://opensource.org/licenses/MIT)
-
 from model import SenseVoiceSmall
 from funasr.utils.postprocess_utils import rich_transcription_postprocess
export.py (new file, 47 lines)
@@ -0,0 +1,47 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)

import os
import torch
from model import SenseVoiceSmall
from utils import export_utils
from utils.model_bin import SenseVoiceSmallONNX
from funasr.utils.postprocess_utils import rich_transcription_postprocess

quantize = False

model_dir = "iic/SenseVoiceSmall"
model, kwargs = SenseVoiceSmall.from_pretrained(model=model_dir, device="cuda:0")

rebuilt_model = model.export(type="onnx", quantize=False)
model_path = kwargs.get("output_dir", os.path.dirname(kwargs.get("init_param")))

model_file = os.path.join(model_path, "model.onnx")
if quantize:
    model_file = os.path.join(model_path, "model_quant.onnx")

# export model
if not os.path.exists(model_file):
    with torch.no_grad():
        del kwargs['model']
        export_dir = export_utils.export(model=rebuilt_model, **kwargs)
        print("Export model onnx to {}".format(model_file))

# export model init
model_bin = SenseVoiceSmallONNX(model_path)

# build tokenizer
try:
    from funasr.tokenizer.sentencepiece_tokenizer import SentencepiecesTokenizer
    tokenizer = SentencepiecesTokenizer(bpemodel=os.path.join(model_path, "chn_jpn_yue_eng_ko_spectok.bpe.model"))
except:
    tokenizer = None

# inference
wav_or_scp = "/Users/shixian/Downloads/asr_example_hotword.wav"
language_list = [0]
textnorm_list = [15]
res = model_bin(wav_or_scp, language_list, textnorm_list, tokenizer=tokenizer)
print(res)
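Not part of the commit: a minimal sketch of driving the exported graph directly with onnxruntime, using the input and output names registered in export_meta.py below and the same language/textnorm ids as export.py above (0 and 15). The 560-dim feature frames are stand-ins here; real ones come from the LFR+CMVN frontend added in utils/frontend.py.

    import numpy as np
    import onnxruntime as ort

    sess = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
    speech = np.random.randn(1, 30, 560).astype(np.float32)  # stand-in for real LFR features
    speech_lengths = np.array([30], dtype=np.int32)
    language = np.array([0], dtype=np.int32)    # same id as language_list in export.py
    textnorm = np.array([15], dtype=np.int32)   # same id as textnorm_list in export.py
    ctc_logits, encoder_out_lens = sess.run(
        ["ctc_logits", "encoder_out_lens"],
        {
            "speech": speech,
            "speech_lengths": speech_lengths,
            "language": language,
            "textnorm": textnorm,
        },
    )
    print(ctc_logits.shape, encoder_out_lens)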
export_meta.py
@@ -5,33 +5,20 @@
 import types
 import torch
-import torch.nn as nn
 from funasr.register import tables
-from funasr.utils.torch_function import sequence_mask
 
 
 def export_rebuild_model(model, **kwargs):
     model.device = kwargs.get("device")
     is_onnx = kwargs.get("type", "onnx") == "onnx"
     # encoder_class = tables.encoder_classes.get(kwargs["encoder"] + "Export")
     # model.encoder = encoder_class(model.encoder, onnx=is_onnx)
-
-
+    from funasr.utils.torch_function import sequence_mask
 
     model.make_pad_mask = sequence_mask(kwargs["max_seq_len"], flip=False)
 
     model.forward = types.MethodType(export_forward, model)
     model.export_dummy_inputs = types.MethodType(export_dummy_inputs, model)
     model.export_input_names = types.MethodType(export_input_names, model)
     model.export_output_names = types.MethodType(export_output_names, model)
     model.export_dynamic_axes = types.MethodType(export_dynamic_axes, model)
-    model.export_name = types.MethodType(export_name, model)
 
     model.export_name = "model"
     return model
 
 
 def export_forward(
     self,
     speech: torch.Tensor,
@@ -40,12 +27,11 @@ def export_forward(
     textnorm: torch.Tensor,
     **kwargs,
 ):
-    speech = speech.to(device=kwargs["device"])
-    speech_lengths = speech_lengths.to(device=kwargs["device"])
-
-    language_query = self.embed(language).to(speech.device)
-
-    textnorm_query = self.embed(textnorm).to(speech.device)
+    # speech = speech.to(device="cuda")
+    # speech_lengths = speech_lengths.to(device="cuda")
+    language_query = self.embed(language.to(speech.device)).unsqueeze(1)
+    textnorm_query = self.embed(textnorm.to(speech.device)).unsqueeze(1)
+    print(textnorm_query.shape, speech.shape)
     speech = torch.cat((textnorm_query, speech), dim=1)
     speech_lengths += 1
 
@@ -56,18 +42,14 @@ def export_forward(
     speech = torch.cat((input_query, speech), dim=1)
     speech_lengths += 3
 
     # Encoder
     encoder_out, encoder_out_lens = self.encoder(speech, speech_lengths)
     if isinstance(encoder_out, tuple):
         encoder_out = encoder_out[0]
 
-    # c. Passed the encoder result and the beam search
-    ctc_logits = self.ctc.log_softmax(encoder_out)
-
+    ctc_logits = self.ctc.ctc_lo(encoder_out)
 
     return ctc_logits, encoder_out_lens
 
 
 def export_dummy_inputs(self):
     speech = torch.randn(2, 30, 560)
     speech_lengths = torch.tensor([6, 30], dtype=torch.int32)
@@ -75,27 +57,22 @@ def export_dummy_inputs(self):
     textnorm = torch.tensor([15, 15], dtype=torch.int32)
     return (speech, speech_lengths, language, textnorm)
 
 
 def export_input_names(self):
     return ["speech", "speech_lengths", "language", "textnorm"]
 
 
 def export_output_names(self):
     return ["ctc_logits", "encoder_out_lens"]
 
 
 def export_dynamic_axes(self):
     return {
         "speech": {0: "batch_size", 1: "feats_length"},
-        "speech_lengths": {
-            0: "batch_size",
-        },
-        "logits": {0: "batch_size", 1: "logits_length"},
+        "speech_lengths": {0: "batch_size"},
+        "language": {0: "batch_size"},
+        "textnorm": {0: "batch_size"},
+        "ctc_logits": {0: "batch_size", 1: "logits_length"},
+        "encoder_out_lens": {0: "batch_size"},
     }
 
 
-def export_name(
-    self,
-):
+def export_name(self):
     return "model.onnx"
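A small eager-mode sanity check (a sketch, not in the commit) for the rebuilt graph before tracing: export_forward lengthens speech_lengths by 1 for the prepended textnorm query and by 3 for the language/event/emotion queries assembled between the hunks above. With rebuilt_model as produced in export.py:

    import torch

    speech, speech_lengths, language, textnorm = rebuilt_model.export_dummy_inputs()
    with torch.no_grad():
        ctc_logits, encoder_out_lens = rebuilt_model(speech, speech_lengths, language, textnorm)
    # encoder_out_lens should reflect the 4 prepended query frames
    print(ctc_logits.shape, encoder_out_lens)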
model.py (30 changed lines)
@@ -1,24 +1,18 @@
-from typing import Iterable, Optional
-import types
-import time
-import numpy as np
-import torch
-import torch.nn.functional as F
-from torch import Tensor
-from torch import nn
-from torch.cuda.amp import autocast
-from funasr.metrics.compute_acc import compute_accuracy, th_accuracy
-from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
-from funasr.train_utils.device_funcs import force_gatherable
-
-from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
-from funasr.utils.datadir_writer import DatadirWriter
-from funasr.models.ctc.ctc import CTC
+import time
+import torch
+from torch import nn
+import torch.nn.functional as F
+from typing import Iterable, Optional
+
+from funasr.register import tables
+
+from funasr.models.ctc.ctc import CTC
+from funasr.utils.datadir_writer import DatadirWriter
+from funasr.models.paraformer.search import Hypothesis
+from funasr.train_utils.device_funcs import force_gatherable
+from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
+from funasr.metrics.compute_acc import compute_accuracy, th_accuracy
+from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
 
 
 class SinusoidalPositionEncoder(torch.nn.Module):
@@ -890,7 +884,7 @@ class SenseVoiceSmall(nn.Module):
         return results, meta_data
 
     def export(self, **kwargs):
-        from .export_meta import export_rebuild_model
+        from export_meta import export_rebuild_model
 
         if "max_seq_len" not in kwargs:
             kwargs["max_seq_len"] = 512
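The one-line change in export() matters because export.py runs model.py as a top-level module, where a relative import has no parent package. A sketch of the failure mode, assuming the repo root is on sys.path:

    try:
        from .export_meta import export_rebuild_model  # ImportError outside a package
    except ImportError:
        from export_meta import export_rebuild_model   # resolves for top-level scripts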
utils/__init__.py (new empty file)
utils/export_utils.py (new file, 73 lines)
@@ -0,0 +1,73 @@
import os
import torch


def export(
    model, quantize: bool = False, opset_version: int = 14, type="onnx", **kwargs
):
    model_scripts = model.export(**kwargs)
    export_dir = kwargs.get("output_dir", os.path.dirname(kwargs.get("init_param")))
    os.makedirs(export_dir, exist_ok=True)

    if not isinstance(model_scripts, (list, tuple)):
        model_scripts = (model_scripts,)
    for m in model_scripts:
        m.eval()
        if type == "onnx":
            _onnx(
                m,
                quantize=quantize,
                opset_version=opset_version,
                export_dir=export_dir,
                **kwargs,
            )
        print("output dir: {}".format(export_dir))

    return export_dir


def _onnx(
    model,
    quantize: bool = False,
    opset_version: int = 14,
    export_dir: str = None,
    **kwargs,
):

    dummy_input = model.export_dummy_inputs()

    verbose = kwargs.get("verbose", False)

    export_name = model.export_name()
    model_path = os.path.join(export_dir, export_name)
    torch.onnx.export(
        model,
        dummy_input,
        model_path,
        verbose=verbose,
        opset_version=opset_version,
        input_names=model.export_input_names(),
        output_names=model.export_output_names(),
        dynamic_axes=model.export_dynamic_axes(),
    )

    if quantize:
        from onnxruntime.quantization import QuantType, quantize_dynamic
        import onnx

        quant_model_path = model_path.replace(".onnx", "_quant.onnx")
        if not os.path.exists(quant_model_path):
            onnx_model = onnx.load(model_path)
            nodes = [n.name for n in onnx_model.graph.node]
            nodes_to_exclude = [
                m for m in nodes if "output" in m or "bias_encoder" in m or "bias_decoder" in m
            ]
            quantize_dynamic(
                model_input=model_path,
                model_output=quant_model_path,
                op_types_to_quantize=["MatMul"],
                per_channel=True,
                reduce_range=False,
                weight_type=QuantType.QUInt8,
                nodes_to_exclude=nodes_to_exclude,
            )
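A quick way (a sketch, not part of the diff) to confirm that torch.onnx.export picked up the names and dynamic axes from export_meta.py is to inspect the saved graph with the onnx package:

    import onnx

    m = onnx.load("model.onnx")
    print([i.name for i in m.graph.input])    # expected: speech, speech_lengths, language, textnorm
    print([o.name for o in m.graph.output])   # expected: ctc_logits, encoder_out_lens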
utils/frontend.py (new file, 433 lines)
@@ -0,0 +1,433 @@
# -*- encoding: utf-8 -*-
from pathlib import Path
from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union
import copy

import numpy as np
import kaldi_native_fbank as knf

root_dir = Path(__file__).resolve().parent

logger_initialized = {}


class WavFrontend:
    """Conventional frontend structure for ASR."""

    def __init__(
        self,
        cmvn_file: str = None,
        fs: int = 16000,
        window: str = "hamming",
        n_mels: int = 80,
        frame_length: int = 25,
        frame_shift: int = 10,
        lfr_m: int = 1,
        lfr_n: int = 1,
        dither: float = 1.0,
        **kwargs,
    ) -> None:

        opts = knf.FbankOptions()
        opts.frame_opts.samp_freq = fs
        opts.frame_opts.dither = dither
        opts.frame_opts.window_type = window
        opts.frame_opts.frame_shift_ms = float(frame_shift)
        opts.frame_opts.frame_length_ms = float(frame_length)
        opts.mel_opts.num_bins = n_mels
        opts.energy_floor = 0
        opts.frame_opts.snip_edges = True
        opts.mel_opts.debug_mel = False
        self.opts = opts

        self.lfr_m = lfr_m
        self.lfr_n = lfr_n
        self.cmvn_file = cmvn_file

        if self.cmvn_file:
            self.cmvn = self.load_cmvn()
        self.fbank_fn = None
        self.fbank_beg_idx = 0
        self.reset_status()

    def fbank(self, waveform: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        waveform = waveform * (1 << 15)
        self.fbank_fn = knf.OnlineFbank(self.opts)
        self.fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist())
        frames = self.fbank_fn.num_frames_ready
        mat = np.empty([frames, self.opts.mel_opts.num_bins])
        for i in range(frames):
            mat[i, :] = self.fbank_fn.get_frame(i)
        feat = mat.astype(np.float32)
        feat_len = np.array(mat.shape[0]).astype(np.int32)
        return feat, feat_len

    def fbank_online(self, waveform: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        waveform = waveform * (1 << 15)
        # self.fbank_fn = knf.OnlineFbank(self.opts)
        self.fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist())
        frames = self.fbank_fn.num_frames_ready
        mat = np.empty([frames, self.opts.mel_opts.num_bins])
        for i in range(self.fbank_beg_idx, frames):
            mat[i, :] = self.fbank_fn.get_frame(i)
        # self.fbank_beg_idx += (frames - self.fbank_beg_idx)
        feat = mat.astype(np.float32)
        feat_len = np.array(mat.shape[0]).astype(np.int32)
        return feat, feat_len

    def reset_status(self):
        self.fbank_fn = knf.OnlineFbank(self.opts)
        self.fbank_beg_idx = 0

    def lfr_cmvn(self, feat: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        if self.lfr_m != 1 or self.lfr_n != 1:
            feat = self.apply_lfr(feat, self.lfr_m, self.lfr_n)

        if self.cmvn_file:
            feat = self.apply_cmvn(feat)

        feat_len = np.array(feat.shape[0]).astype(np.int32)
        return feat, feat_len

    @staticmethod
    def apply_lfr(inputs: np.ndarray, lfr_m: int, lfr_n: int) -> np.ndarray:
        LFR_inputs = []

        T = inputs.shape[0]
        T_lfr = int(np.ceil(T / lfr_n))
        left_padding = np.tile(inputs[0], ((lfr_m - 1) // 2, 1))
        inputs = np.vstack((left_padding, inputs))
        T = T + (lfr_m - 1) // 2
        for i in range(T_lfr):
            if lfr_m <= T - i * lfr_n:
                LFR_inputs.append((inputs[i * lfr_n : i * lfr_n + lfr_m]).reshape(1, -1))
            else:
                # process last LFR frame
                num_padding = lfr_m - (T - i * lfr_n)
                frame = inputs[i * lfr_n :].reshape(-1)
                for _ in range(num_padding):
                    frame = np.hstack((frame, inputs[-1]))

                LFR_inputs.append(frame)
        LFR_outputs = np.vstack(LFR_inputs).astype(np.float32)
        return LFR_outputs

    def apply_cmvn(self, inputs: np.ndarray) -> np.ndarray:
        """
        Apply CMVN with mvn data
        """
        frame, dim = inputs.shape
        means = np.tile(self.cmvn[0:1, :dim], (frame, 1))
        vars = np.tile(self.cmvn[1:2, :dim], (frame, 1))
        inputs = (inputs + means) * vars
        return inputs

    def load_cmvn(
        self,
    ) -> np.ndarray:
        with open(self.cmvn_file, "r", encoding="utf-8") as f:
            lines = f.readlines()

        means_list = []
        vars_list = []
        for i in range(len(lines)):
            line_item = lines[i].split()
            if line_item[0] == "<AddShift>":
                line_item = lines[i + 1].split()
                if line_item[0] == "<LearnRateCoef>":
                    add_shift_line = line_item[3 : (len(line_item) - 1)]
                    means_list = list(add_shift_line)
                    continue
            elif line_item[0] == "<Rescale>":
                line_item = lines[i + 1].split()
                if line_item[0] == "<LearnRateCoef>":
                    rescale_line = line_item[3 : (len(line_item) - 1)]
                    vars_list = list(rescale_line)
                    continue

        means = np.array(means_list).astype(np.float64)
        vars = np.array(vars_list).astype(np.float64)
        cmvn = np.array([means, vars])
        return cmvn


class WavFrontendOnline(WavFrontend):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # self.fbank_fn = knf.OnlineFbank(self.opts)
        # add variables
        self.frame_sample_length = int(
            self.opts.frame_opts.frame_length_ms * self.opts.frame_opts.samp_freq / 1000
        )
        self.frame_shift_sample_length = int(
            self.opts.frame_opts.frame_shift_ms * self.opts.frame_opts.samp_freq / 1000
        )
        self.waveform = None
        self.reserve_waveforms = None
        self.input_cache = None
        self.lfr_splice_cache = []

    @staticmethod
    # inputs has catted the cache
    def apply_lfr(
        inputs: np.ndarray, lfr_m: int, lfr_n: int, is_final: bool = False
    ) -> Tuple[np.ndarray, np.ndarray, int]:
        """
        Apply lfr with data
        """

        LFR_inputs = []
        T = inputs.shape[0]  # include the right context
        T_lfr = int(
            np.ceil((T - (lfr_m - 1) // 2) / lfr_n)
        )  # minus the right context: (lfr_m - 1) // 2
        splice_idx = T_lfr
        for i in range(T_lfr):
            if lfr_m <= T - i * lfr_n:
                LFR_inputs.append((inputs[i * lfr_n : i * lfr_n + lfr_m]).reshape(1, -1))
            else:  # process last LFR frame
                if is_final:
                    num_padding = lfr_m - (T - i * lfr_n)
                    frame = (inputs[i * lfr_n :]).reshape(-1)
                    for _ in range(num_padding):
                        frame = np.hstack((frame, inputs[-1]))
                    LFR_inputs.append(frame)
                else:
                    # update splice_idx and break the circle
                    splice_idx = i
                    break
        splice_idx = min(T - 1, splice_idx * lfr_n)
        lfr_splice_cache = inputs[splice_idx:, :]
        LFR_outputs = np.vstack(LFR_inputs)
        return LFR_outputs.astype(np.float32), lfr_splice_cache, splice_idx

    @staticmethod
    def compute_frame_num(
        sample_length: int, frame_sample_length: int, frame_shift_sample_length: int
    ) -> int:
        frame_num = int((sample_length - frame_sample_length) / frame_shift_sample_length + 1)
        return frame_num if frame_num >= 1 and sample_length >= frame_sample_length else 0

    def fbank(
        self, input: np.ndarray, input_lengths: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        self.fbank_fn = knf.OnlineFbank(self.opts)
        batch_size = input.shape[0]
        if self.input_cache is None:
            self.input_cache = np.empty((batch_size, 0), dtype=np.float32)
        input = np.concatenate((self.input_cache, input), axis=1)
        frame_num = self.compute_frame_num(
            input.shape[-1], self.frame_sample_length, self.frame_shift_sample_length
        )
        # update self.in_cache
        self.input_cache = input[
            :, -(input.shape[-1] - frame_num * self.frame_shift_sample_length) :
        ]
        waveforms = np.empty(0, dtype=np.float32)
        feats_pad = np.empty(0, dtype=np.float32)
        feats_lens = np.empty(0, dtype=np.int32)
        if frame_num:
            waveforms = []
            feats = []
            feats_lens = []
            for i in range(batch_size):
                waveform = input[i]
                waveforms.append(
                    waveform[
                        : (
                            (frame_num - 1) * self.frame_shift_sample_length
                            + self.frame_sample_length
                        )
                    ]
                )
                waveform = waveform * (1 << 15)

                self.fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist())
                frames = self.fbank_fn.num_frames_ready
                mat = np.empty([frames, self.opts.mel_opts.num_bins])
                for i in range(frames):
                    mat[i, :] = self.fbank_fn.get_frame(i)
                feat = mat.astype(np.float32)
                feat_len = np.array(mat.shape[0]).astype(np.int32)
                feats.append(feat)
                feats_lens.append(feat_len)

            waveforms = np.stack(waveforms)
            feats_lens = np.array(feats_lens)
            feats_pad = np.array(feats)
        self.fbanks = feats_pad
        self.fbanks_lens = copy.deepcopy(feats_lens)
        return waveforms, feats_pad, feats_lens

    def get_fbank(self) -> Tuple[np.ndarray, np.ndarray]:
        return self.fbanks, self.fbanks_lens

    def lfr_cmvn(
        self, input: np.ndarray, input_lengths: np.ndarray, is_final: bool = False
    ) -> Tuple[np.ndarray, np.ndarray, List[int]]:
        batch_size = input.shape[0]
        feats = []
        feats_lens = []
        lfr_splice_frame_idxs = []
        for i in range(batch_size):
            mat = input[i, : input_lengths[i], :]
            lfr_splice_frame_idx = -1
            if self.lfr_m != 1 or self.lfr_n != 1:
                # update self.lfr_splice_cache in self.apply_lfr
                mat, self.lfr_splice_cache[i], lfr_splice_frame_idx = self.apply_lfr(
                    mat, self.lfr_m, self.lfr_n, is_final
                )
            if self.cmvn_file is not None:
                mat = self.apply_cmvn(mat)
            feat_length = mat.shape[0]
            feats.append(mat)
            feats_lens.append(feat_length)
            lfr_splice_frame_idxs.append(lfr_splice_frame_idx)

        feats_lens = np.array(feats_lens)
        feats_pad = np.array(feats)
        return feats_pad, feats_lens, lfr_splice_frame_idxs

    def extract_fbank(
        self, input: np.ndarray, input_lengths: np.ndarray, is_final: bool = False
    ) -> Tuple[np.ndarray, np.ndarray]:
        batch_size = input.shape[0]
        assert (
            batch_size == 1
        ), "we support to extract feature online only when the batch size is equal to 1 now"
        waveforms, feats, feats_lengths = self.fbank(input, input_lengths)  # input shape: B T D
        if feats.shape[0]:
            self.waveforms = (
                waveforms
                if self.reserve_waveforms is None
                else np.concatenate((self.reserve_waveforms, waveforms), axis=1)
            )
            if not self.lfr_splice_cache:
                for i in range(batch_size):
                    self.lfr_splice_cache.append(
                        np.expand_dims(feats[i][0, :], axis=0).repeat((self.lfr_m - 1) // 2, axis=0)
                    )

            if feats_lengths[0] + self.lfr_splice_cache[0].shape[0] >= self.lfr_m:
                lfr_splice_cache_np = np.stack(self.lfr_splice_cache)  # B T D
                feats = np.concatenate((lfr_splice_cache_np, feats), axis=1)
                feats_lengths += lfr_splice_cache_np[0].shape[0]
                frame_from_waveforms = int(
                    (self.waveforms.shape[1] - self.frame_sample_length)
                    / self.frame_shift_sample_length
                    + 1
                )
                minus_frame = (self.lfr_m - 1) // 2 if self.reserve_waveforms is None else 0
                feats, feats_lengths, lfr_splice_frame_idxs = self.lfr_cmvn(
                    feats, feats_lengths, is_final
                )
                if self.lfr_m == 1:
                    self.reserve_waveforms = None
                else:
                    reserve_frame_idx = lfr_splice_frame_idxs[0] - minus_frame
                    # print('reserve_frame_idx: ' + str(reserve_frame_idx))
                    # print('frame_frame: ' + str(frame_from_waveforms))
                    self.reserve_waveforms = self.waveforms[
                        :,
                        reserve_frame_idx
                        * self.frame_shift_sample_length : frame_from_waveforms
                        * self.frame_shift_sample_length,
                    ]
                    sample_length = (
                        frame_from_waveforms - 1
                    ) * self.frame_shift_sample_length + self.frame_sample_length
                    self.waveforms = self.waveforms[:, :sample_length]
            else:
                # update self.reserve_waveforms and self.lfr_splice_cache
                self.reserve_waveforms = self.waveforms[
                    :, : -(self.frame_sample_length - self.frame_shift_sample_length)
                ]
                for i in range(batch_size):
                    self.lfr_splice_cache[i] = np.concatenate(
                        (self.lfr_splice_cache[i], feats[i]), axis=0
                    )
                return np.empty(0, dtype=np.float32), feats_lengths
        else:
            if is_final:
                self.waveforms = (
                    waveforms if self.reserve_waveforms is None else self.reserve_waveforms
                )
                feats = np.stack(self.lfr_splice_cache)
                feats_lengths = np.zeros(batch_size, dtype=np.int32) + feats.shape[1]
                feats, feats_lengths, _ = self.lfr_cmvn(feats, feats_lengths, is_final)
        if is_final:
            self.cache_reset()
        return feats, feats_lengths

    def get_waveforms(self):
        return self.waveforms

    def cache_reset(self):
        self.fbank_fn = knf.OnlineFbank(self.opts)
        self.reserve_waveforms = None
        self.input_cache = None
        self.lfr_splice_cache = []


def load_bytes(input):
    middle_data = np.frombuffer(input, dtype=np.int16)
    middle_data = np.asarray(middle_data)
    if middle_data.dtype.kind not in "iu":
        raise TypeError("'middle_data' must be an array of integers")
    dtype = np.dtype("float32")
    if dtype.kind != "f":
        raise TypeError("'dtype' must be a floating point type")

    i = np.iinfo(middle_data.dtype)
    abs_max = 2 ** (i.bits - 1)
    offset = i.min + abs_max
    array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32)
    return array


class SinusoidalPositionEncoderOnline:
    """Streaming Positional encoding."""

    def encode(self, positions: np.ndarray = None, depth: int = None, dtype: np.dtype = np.float32):
        batch_size = positions.shape[0]
        positions = positions.astype(dtype)
        log_timescale_increment = np.log(np.array([10000], dtype=dtype)) / (depth / 2 - 1)
        inv_timescales = np.exp(np.arange(depth / 2).astype(dtype) * (-log_timescale_increment))
        inv_timescales = np.reshape(inv_timescales, [batch_size, -1])
        scaled_time = np.reshape(positions, [1, -1, 1]) * np.reshape(inv_timescales, [1, 1, -1])
        encoding = np.concatenate((np.sin(scaled_time), np.cos(scaled_time)), axis=2)
        return encoding.astype(dtype)

    def forward(self, x, start_idx=0):
        batch_size, timesteps, input_dim = x.shape
        positions = np.arange(1, timesteps + 1 + start_idx)[None, :]
        position_encoding = self.encode(positions, input_dim, x.dtype)

        return x + position_encoding[:, start_idx : start_idx + timesteps]


def test():
    path = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/example/asr_example.wav"
    import librosa

    cmvn_file = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/am.mvn"
    config_file = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/config.yaml"
    from funasr.runtime.python.onnxruntime.rapid_paraformer.utils.utils import read_yaml

    config = read_yaml(config_file)
    waveform, _ = librosa.load(path, sr=None)
    frontend = WavFrontend(
        cmvn_file=cmvn_file,
        **config["frontend_conf"],
    )
    speech, _ = frontend.fbank_online(waveform)  # 1d, (sample,), numpy
    feat, feat_len = frontend.lfr_cmvn(
        speech
    )  # 2d, (frame, 450), np.float32 -> torch, torch.from_numpy(), dtype, (1, frame, 450)

    frontend.reset_status()  # clear cache
    return feat, feat_len


if __name__ == "__main__":
    test()
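For orientation (a sketch, not in the commit): with the lfr_m=7, lfr_n=6 setting typically used for this model family, WavFrontend.apply_lfr stacks 7 consecutive 80-dim fbank frames with a hop of 6, turning a (T, 80) matrix into roughly (ceil(T/6), 560), which matches the 560-dim speech input of the exported graph.

    import numpy as np

    fbank = np.random.randn(100, 80).astype(np.float32)   # stand-in fbank features
    lfr = WavFrontend.apply_lfr(fbank, lfr_m=7, lfr_n=6)  # static method, callable directly
    print(fbank.shape, "->", lfr.shape)                   # (100, 80) -> (17, 560)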
utils/infer_utils.py (new file, 395 lines)
@@ -0,0 +1,395 @@
# -*- encoding: utf-8 -*-

import functools
import logging
from pathlib import Path
from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union

import re
import numpy as np
import yaml

try:
    from onnxruntime import (
        GraphOptimizationLevel,
        InferenceSession,
        SessionOptions,
        get_available_providers,
        get_device,
    )
except:
    print("please pip3 install onnxruntime")
import jieba
import warnings

root_dir = Path(__file__).resolve().parent

logger_initialized = {}


def pad_list(xs, pad_value, max_len=None):
    n_batch = len(xs)
    if max_len is None:
        max_len = max(x.size(0) for x in xs)
    # pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value)
    # numpy format
    pad = (np.zeros((n_batch, max_len)) + pad_value).astype(np.int32)
    for i in range(n_batch):
        pad[i, : xs[i].shape[0]] = xs[i]

    return pad


"""
def make_pad_mask(lengths, xs=None, length_dim=-1, maxlen=None):
    if length_dim == 0:
        raise ValueError("length_dim cannot be 0: {}".format(length_dim))

    if not isinstance(lengths, list):
        lengths = lengths.tolist()
    bs = int(len(lengths))
    if maxlen is None:
        if xs is None:
            maxlen = int(max(lengths))
        else:
            maxlen = xs.size(length_dim)
    else:
        assert xs is None
        assert maxlen >= int(max(lengths))

    seq_range = torch.arange(0, maxlen, dtype=torch.int64)
    seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen)
    seq_length_expand = seq_range_expand.new(lengths).unsqueeze(-1)
    mask = seq_range_expand >= seq_length_expand

    if xs is not None:
        assert xs.size(0) == bs, (xs.size(0), bs)

        if length_dim < 0:
            length_dim = xs.dim() + length_dim
        # ind = (:, None, ..., None, :, , None, ..., None)
        ind = tuple(
            slice(None) if i in (0, length_dim) else None for i in range(xs.dim())
        )
        mask = mask[ind].expand_as(xs).to(xs.device)
    return mask
"""


class TokenIDConverter:
    def __init__(
        self,
        token_list: Union[List, str],
    ):

        self.token_list = token_list
        self.unk_symbol = token_list[-1]
        self.token2id = {v: i for i, v in enumerate(self.token_list)}
        self.unk_id = self.token2id[self.unk_symbol]

    def get_num_vocabulary_size(self) -> int:
        return len(self.token_list)

    def ids2tokens(self, integers: Union[np.ndarray, Iterable[int]]) -> List[str]:
        if isinstance(integers, np.ndarray) and integers.ndim != 1:
            raise TokenIDConverterError(f"Must be 1 dim ndarray, but got {integers.ndim}")
        return [self.token_list[i] for i in integers]

    def tokens2ids(self, tokens: Iterable[str]) -> List[int]:

        return [self.token2id.get(i, self.unk_id) for i in tokens]


class CharTokenizer:
    def __init__(
        self,
        symbol_value: Union[Path, str, Iterable[str]] = None,
        space_symbol: str = "<space>",
        remove_non_linguistic_symbols: bool = False,
    ):

        self.space_symbol = space_symbol
        self.non_linguistic_symbols = self.load_symbols(symbol_value)
        self.remove_non_linguistic_symbols = remove_non_linguistic_symbols

    @staticmethod
    def load_symbols(value: Union[Path, str, Iterable[str]] = None) -> Set:
        if value is None:
            return set()

        if isinstance(value, Iterable[str]):
            return set(value)

        file_path = Path(value)
        if not file_path.exists():
            logging.warning("%s doesn't exist.", file_path)
            return set()

        with file_path.open("r", encoding="utf-8") as f:
            return set(line.rstrip() for line in f)

    def text2tokens(self, line: Union[str, list]) -> List[str]:
        tokens = []
        while len(line) != 0:
            for w in self.non_linguistic_symbols:
                if line.startswith(w):
                    if not self.remove_non_linguistic_symbols:
                        tokens.append(line[: len(w)])
                    line = line[len(w) :]
                    break
            else:
                t = line[0]
                if t == " ":
                    t = "<space>"
                tokens.append(t)
                line = line[1:]
        return tokens

    def tokens2text(self, tokens: Iterable[str]) -> str:
        tokens = [t if t != self.space_symbol else " " for t in tokens]
        return "".join(tokens)

    def __repr__(self):
        return (
            f"{self.__class__.__name__}("
            f'space_symbol="{self.space_symbol}"'
            f'non_linguistic_symbols="{self.non_linguistic_symbols}"'
            f")"
        )


class Hypothesis(NamedTuple):
    """Hypothesis data type."""

    yseq: np.ndarray
    score: Union[float, np.ndarray] = 0
    scores: Dict[str, Union[float, np.ndarray]] = dict()
    states: Dict[str, Any] = dict()

    def asdict(self) -> dict:
        """Convert data to JSON-friendly dict."""
        return self._replace(
            yseq=self.yseq.tolist(),
            score=float(self.score),
            scores={k: float(v) for k, v in self.scores.items()},
        )._asdict()


class TokenIDConverterError(Exception):
    pass


class ONNXRuntimeError(Exception):
    pass


class OrtInferSession:
    def __init__(self, model_file, device_id=-1, intra_op_num_threads=4):
        device_id = str(device_id)
        sess_opt = SessionOptions()
        sess_opt.intra_op_num_threads = intra_op_num_threads
        sess_opt.log_severity_level = 4
        sess_opt.enable_cpu_mem_arena = False
        sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL

        cuda_ep = "CUDAExecutionProvider"
        cuda_provider_options = {
            "device_id": device_id,
            "arena_extend_strategy": "kNextPowerOfTwo",
            "cudnn_conv_algo_search": "EXHAUSTIVE",
            "do_copy_in_default_stream": "true",
        }
        cpu_ep = "CPUExecutionProvider"
        cpu_provider_options = {
            "arena_extend_strategy": "kSameAsRequested",
        }

        EP_list = []
        if device_id != "-1" and get_device() == "GPU" and cuda_ep in get_available_providers():
            EP_list = [(cuda_ep, cuda_provider_options)]
        EP_list.append((cpu_ep, cpu_provider_options))

        self._verify_model(model_file)
        self.session = InferenceSession(model_file, sess_options=sess_opt, providers=EP_list)

        if device_id != "-1" and cuda_ep not in self.session.get_providers():
            warnings.warn(
                f"{cuda_ep} is not available for current env, the inference part is automatically shifted to be executed under {cpu_ep}.\n"
                "Please ensure the installed onnxruntime-gpu version matches your cuda and cudnn version, "
                "you can check their relations from the official web site: "
                "https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html",
                RuntimeWarning,
            )

    def __call__(self, input_content: List[Union[np.ndarray, np.ndarray]]) -> np.ndarray:
        input_dict = dict(zip(self.get_input_names(), input_content))
        try:
            return self.session.run(self.get_output_names(), input_dict)
        except Exception as e:
            raise ONNXRuntimeError("ONNXRuntime inference failed.") from e

    def get_input_names(
        self,
    ):
        return [v.name for v in self.session.get_inputs()]

    def get_output_names(
        self,
    ):
        return [v.name for v in self.session.get_outputs()]

    def get_character_list(self, key: str = "character"):
        return self.meta_dict[key].splitlines()

    def have_key(self, key: str = "character") -> bool:
        self.meta_dict = self.session.get_modelmeta().custom_metadata_map
        if key in self.meta_dict.keys():
            return True
        return False

    @staticmethod
    def _verify_model(model_path):
        model_path = Path(model_path)
        if not model_path.exists():
            raise FileNotFoundError(f"{model_path} does not exist.")
        if not model_path.is_file():
            raise FileExistsError(f"{model_path} is not a file.")


def split_to_mini_sentence(words: list, word_limit: int = 20):
    assert word_limit > 1
    if len(words) <= word_limit:
        return [words]
    sentences = []
    length = len(words)
    sentence_len = length // word_limit
    for i in range(sentence_len):
        sentences.append(words[i * word_limit : (i + 1) * word_limit])
    if length % word_limit > 0:
        sentences.append(words[sentence_len * word_limit :])
    return sentences


def code_mix_split_words(text: str):
    words = []
    segs = text.split()
    for seg in segs:
        # There is no space in seg.
        current_word = ""
        for c in seg:
            if len(c.encode()) == 1:
                # This is an ASCII char.
                current_word += c
            else:
                # This is a Chinese char.
                if len(current_word) > 0:
                    words.append(current_word)
                    current_word = ""
                words.append(c)
        if len(current_word) > 0:
            words.append(current_word)
    return words


def isEnglish(text: str):
    if re.search("^[a-zA-Z']+$", text):
        return True
    else:
        return False


def join_chinese_and_english(input_list):
    line = ""
    for token in input_list:
        if isEnglish(token):
            line = line + " " + token
        else:
            line = line + token

    line = line.strip()
    return line


def code_mix_split_words_jieba(seg_dict_file: str):
    jieba.load_userdict(seg_dict_file)

    def _fn(text: str):
        input_list = text.split()
        token_list_all = []
        language_list = []
        token_list_tmp = []
        language_flag = None
        for token in input_list:
            if isEnglish(token) and language_flag == "Chinese":
                token_list_all.append(token_list_tmp)
                language_list.append("Chinese")
                token_list_tmp = []
            elif not isEnglish(token) and language_flag == "English":
                token_list_all.append(token_list_tmp)
                language_list.append("English")
                token_list_tmp = []

            token_list_tmp.append(token)

            if isEnglish(token):
                language_flag = "English"
            else:
                language_flag = "Chinese"

        if token_list_tmp:
            token_list_all.append(token_list_tmp)
            language_list.append(language_flag)

        result_list = []
        for token_list_tmp, language_flag in zip(token_list_all, language_list):
            if language_flag == "English":
                result_list.extend(token_list_tmp)
            else:
                seg_list = jieba.cut(join_chinese_and_english(token_list_tmp), HMM=False)
                result_list.extend(seg_list)

        return result_list

    return _fn


def read_yaml(yaml_path: Union[str, Path]) -> Dict:
    if not Path(yaml_path).exists():
        raise FileExistsError(f"The {yaml_path} does not exist.")

    with open(str(yaml_path), "rb") as f:
        data = yaml.load(f, Loader=yaml.Loader)
    return data


@functools.lru_cache()
def get_logger(name="funasr_onnx"):
    """Initialize and get a logger by name.
    If the logger has not been initialized, this method will initialize the
    logger by adding one or two handlers, otherwise the initialized logger will
    be directly returned. During initialization, a StreamHandler will always be
    added.
    Args:
        name (str): Logger name.
    Returns:
        logging.Logger: The expected logger.
    """
    logger = logging.getLogger(name)
    if name in logger_initialized:
        return logger

    for logger_name in logger_initialized:
        if name.startswith(logger_name):
            return logger

    formatter = logging.Formatter(
        "[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%Y/%m/%d %H:%M:%S"
    )

    sh = logging.StreamHandler()
    sh.setFormatter(formatter)
    logger.addHandler(sh)
    logger_initialized[name] = True
    logger.propagate = False
    logging.basicConfig(level=logging.ERROR)
    return logger
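Usage sketch (not in the diff): OrtInferSession prefers CUDAExecutionProvider when device_id is non-negative and the GPU build is installed, and always appends CPUExecutionProvider as a fallback; inputs are passed positionally and zipped against the graph's input names. The arrays below are hypothetical and would come from WavFrontend as in utils/model_bin.py.

    session = OrtInferSession("model.onnx", device_id=-1, intra_op_num_threads=4)
    ctc_logits, encoder_out_lens = session(
        [feats, feats_len, language_ids, textnorm_ids]  # hypothetical arrays, in input-name order
    )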
utils/model_bin.py (new file, 146 lines)
@@ -0,0 +1,146 @@
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)

import os.path
from pathlib import Path
from typing import List, Union, Tuple

import torch
import librosa
import numpy as np

from utils.infer_utils import (
    CharTokenizer,
    Hypothesis,
    ONNXRuntimeError,
    OrtInferSession,
    TokenIDConverter,
    get_logger,
    read_yaml,
)
# from .utils.postprocess_utils import sentence_postprocess, sentence_postprocess_sentencepiece
from utils.frontend import WavFrontend
# from .utils.timestamp_utils import time_stamp_lfr6_onnx
from utils.infer_utils import pad_list

logging = get_logger()


class SenseVoiceSmallONNX:
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2206.08317
    """

    def __init__(
        self,
        model_dir: Union[str, Path] = None,
        batch_size: int = 1,
        device_id: Union[str, int] = "-1",
        plot_timestamp_to: str = "",
        quantize: bool = False,
        intra_op_num_threads: int = 4,
        cache_dir: str = None,
        **kwargs,
    ):
        if quantize:
            model_file = os.path.join(model_dir, "model_quant.onnx")
        else:
            model_file = os.path.join(model_dir, "model.onnx")

        config_file = os.path.join(model_dir, "config.yaml")
        cmvn_file = os.path.join(model_dir, "am.mvn")
        config = read_yaml(config_file)
        # token_list = os.path.join(model_dir, "tokens.json")
        # with open(token_list, "r", encoding="utf-8") as f:
        #     token_list = json.load(f)

        # self.converter = TokenIDConverter(token_list)
        self.tokenizer = CharTokenizer()
        config["frontend_conf"]["cmvn_file"] = cmvn_file
        self.frontend = WavFrontend(**config["frontend_conf"])
        self.ort_infer = OrtInferSession(
            model_file, device_id, intra_op_num_threads=intra_op_num_threads
        )
        self.batch_size = batch_size
        self.blank_id = 0

    def __call__(self,
                 wav_content: Union[str, np.ndarray, List[str]],
                 language: List,
                 textnorm: List,
                 tokenizer=None,
                 **kwargs) -> List:
        waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq)
        waveform_nums = len(waveform_list)
        asr_res = []
        for beg_idx in range(0, waveform_nums, self.batch_size):
            end_idx = min(waveform_nums, beg_idx + self.batch_size)
            feats, feats_len = self.extract_feat(waveform_list[beg_idx:end_idx])
            ctc_logits, encoder_out_lens = self.infer(feats,
                                                      feats_len,
                                                      np.array(language, dtype=np.int32),
                                                      np.array(textnorm, dtype=np.int32))
            # back to torch.Tensor
            ctc_logits = torch.from_numpy(ctc_logits).float()
            # support batch_size=1 only currently
            x = ctc_logits[0, : encoder_out_lens[0].item(), :]
            yseq = x.argmax(dim=-1)
            yseq = torch.unique_consecutive(yseq, dim=-1)

            mask = yseq != self.blank_id
            token_int = yseq[mask].tolist()

            if tokenizer is not None:
                asr_res.append(tokenizer.tokens2text(token_int))
            else:
                asr_res.append(token_int)
        return asr_res

    def load_data(self, wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List:
        def load_wav(path: str) -> np.ndarray:
            waveform, _ = librosa.load(path, sr=fs)
            return waveform

        if isinstance(wav_content, np.ndarray):
            return [wav_content]

        if isinstance(wav_content, str):
            return [load_wav(wav_content)]

        if isinstance(wav_content, list):
            return [load_wav(path) for path in wav_content]

        raise TypeError(f"The type of {wav_content} is not in [str, np.ndarray, list]")

    def extract_feat(self, waveform_list: List[np.ndarray]) -> Tuple[np.ndarray, np.ndarray]:
        feats, feats_len = [], []
        for waveform in waveform_list:
            speech, _ = self.frontend.fbank(waveform)
            feat, feat_len = self.frontend.lfr_cmvn(speech)
            feats.append(feat)
            feats_len.append(feat_len)

        feats = self.pad_feats(feats, np.max(feats_len))
        feats_len = np.array(feats_len).astype(np.int32)
        return feats, feats_len

    @staticmethod
    def pad_feats(feats: List[np.ndarray], max_feat_len: int) -> np.ndarray:
        def pad_feat(feat: np.ndarray, cur_len: int) -> np.ndarray:
            pad_width = ((0, max_feat_len - cur_len), (0, 0))
            return np.pad(feat, pad_width, "constant", constant_values=0)

        feat_res = [pad_feat(feat, feat.shape[0]) for feat in feats]
        feats = np.array(feat_res).astype(np.float32)
        return feats

    def infer(self,
              feats: np.ndarray,
              feats_len: np.ndarray,
              language: np.ndarray,
              textnorm: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
        outputs = self.ort_infer([feats, feats_len, language, textnorm])
        return outputs
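The decoding path in __call__ is plain greedy CTC: per-frame argmax, merge repeats, drop blanks. The same collapse written out in numpy (a sketch for clarity, equivalent to the torch.unique_consecutive route for batch size 1):

    import numpy as np

    def greedy_ctc(ctc_logits: np.ndarray, valid_len: int, blank_id: int = 0) -> list:
        ids = ctc_logits[0, :valid_len, :].argmax(axis=-1)  # best token per frame
        collapsed = [int(t) for k, t in enumerate(ids) if k == 0 or t != ids[k - 1]]  # merge repeats
        return [t for t in collapsed if t != blank_id]  # drop CTC blanks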