ONNX support

This commit is contained in:
维石 2024-07-22 15:27:56 +08:00
parent ace1d34b69
commit 0858308f36
10 changed files with 1120 additions and 66 deletions

View File

@ -1,9 +1,3 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
from funasr import AutoModel
from funasr.utils.postprocess_utils import rich_transcription_postprocess

View File

@ -1,8 +1,3 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
from model import SenseVoiceSmall
from funasr.utils.postprocess_utils import rich_transcription_postprocess

47
export.py Normal file
View File

@ -0,0 +1,47 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
import os
import torch
from model import SenseVoiceSmall
from utils import export_utils
from utils.model_bin import SenseVoiceSmallONNX
from funasr.utils.postprocess_utils import rich_transcription_postprocess
quantize = False
model_dir = "iic/SenseVoiceSmall"
model, kwargs = SenseVoiceSmall.from_pretrained(model=model_dir, device="cuda:0")
rebuilt_model = model.export(type="onnx", quantize=False)
model_path = kwargs.get("output_dir", os.path.dirname(kwargs.get("init_param")))
model_file = os.path.join(model_path, "model.onnx")
if quantize:
model_file = os.path.join(model_path, "model_quant.onnx")
# export model
if not os.path.exists(model_file):
with torch.no_grad():
del kwargs['model']
export_dir = export_utils.export(model=rebuilt_model, **kwargs)
print("Export model onnx to {}".format(model_file))
# export model init
model_bin = SenseVoiceSmallONNX(model_path)
# build tokenizer
try:
from funasr.tokenizer.sentencepiece_tokenizer import SentencepiecesTokenizer
tokenizer = SentencepiecesTokenizer(bpemodel=os.path.join(model_path, "chn_jpn_yue_eng_ko_spectok.bpe.model"))
except:
tokenizer = None
# inference
wav_or_scp = "/Users/shixian/Downloads/asr_example_hotword.wav"
language_list = [0]
textnorm_list = [15]
res = model_bin(wav_or_scp, language_list, textnorm_list, tokenizer=tokenizer)
print(res)

View File

@ -5,33 +5,20 @@
import types
import torch
import torch.nn as nn
from funasr.register import tables
from funasr.utils.torch_function import sequence_mask
def export_rebuild_model(model, **kwargs):
model.device = kwargs.get("device")
is_onnx = kwargs.get("type", "onnx") == "onnx"
# encoder_class = tables.encoder_classes.get(kwargs["encoder"] + "Export")
# model.encoder = encoder_class(model.encoder, onnx=is_onnx)
from funasr.utils.torch_function import sequence_mask
model.make_pad_mask = sequence_mask(kwargs["max_seq_len"], flip=False)
model.forward = types.MethodType(export_forward, model)
model.export_dummy_inputs = types.MethodType(export_dummy_inputs, model)
model.export_input_names = types.MethodType(export_input_names, model)
model.export_output_names = types.MethodType(export_output_names, model)
model.export_dynamic_axes = types.MethodType(export_dynamic_axes, model)
model.export_name = types.MethodType(export_name, model)
model.export_name = "model"
return model
def export_forward(
self,
speech: torch.Tensor,
@ -40,12 +27,11 @@ def export_forward(
textnorm: torch.Tensor,
**kwargs,
):
speech = speech.to(device=kwargs["device"])
speech_lengths = speech_lengths.to(device=kwargs["device"])
language_query = self.embed(language).to(speech.device)
textnorm_query = self.embed(textnorm).to(speech.device)
# speech = speech.to(device="cuda")
# speech_lengths = speech_lengths.to(device="cuda")
language_query = self.embed(language.to(speech.device)).unsqueeze(1)
textnorm_query = self.embed(textnorm.to(speech.device)).unsqueeze(1)
print(textnorm_query.shape, speech.shape)
speech = torch.cat((textnorm_query, speech), dim=1)
speech_lengths += 1
@ -56,18 +42,14 @@ def export_forward(
speech = torch.cat((input_query, speech), dim=1)
speech_lengths += 3
# Encoder
encoder_out, encoder_out_lens = self.encoder(speech, speech_lengths)
if isinstance(encoder_out, tuple):
encoder_out = encoder_out[0]
# c. Passed the encoder result and the beam search
ctc_logits = self.ctc.log_softmax(encoder_out)
ctc_logits = self.ctc.ctc_lo(encoder_out)
return ctc_logits, encoder_out_lens
def export_dummy_inputs(self):
speech = torch.randn(2, 30, 560)
speech_lengths = torch.tensor([6, 30], dtype=torch.int32)
@ -75,27 +57,22 @@ def export_dummy_inputs(self):
textnorm = torch.tensor([15, 15], dtype=torch.int32)
return (speech, speech_lengths, language, textnorm)
def export_input_names(self):
return ["speech", "speech_lengths", "language", "textnorm"]
def export_output_names(self):
return ["ctc_logits", "encoder_out_lens"]
def export_dynamic_axes(self):
return {
"speech": {0: "batch_size", 1: "feats_length"},
"speech_lengths": {
0: "batch_size",
},
"logits": {0: "batch_size", 1: "logits_length"},
"speech_lengths": {0: "batch_size"},
"language": {0: "batch_size"},
"textnorm": {0: "batch_size"},
"ctc_logits": {0: "batch_size", 1: "logits_length"},
"encoder_out_lens": {0: "batch_size"},
}
def export_name(
self,
):
def export_name(self):
return "model.onnx"

View File

@ -1,24 +1,18 @@
from typing import Iterable, Optional
import types
import time
import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor
from torch import nn
from torch.cuda.amp import autocast
from funasr.metrics.compute_acc import compute_accuracy, th_accuracy
from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
from funasr.train_utils.device_funcs import force_gatherable
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
from funasr.utils.datadir_writer import DatadirWriter
from funasr.models.ctc.ctc import CTC
import time
import torch
from torch import nn
import torch.nn.functional as F
from typing import Iterable, Optional
from funasr.register import tables
from funasr.models.ctc.ctc import CTC
from funasr.utils.datadir_writer import DatadirWriter
from funasr.models.paraformer.search import Hypothesis
from funasr.train_utils.device_funcs import force_gatherable
from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
from funasr.metrics.compute_acc import compute_accuracy, th_accuracy
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
class SinusoidalPositionEncoder(torch.nn.Module):
@ -890,7 +884,7 @@ class SenseVoiceSmall(nn.Module):
return results, meta_data
def export(self, **kwargs):
from .export_meta import export_rebuild_model
from export_meta import export_rebuild_model
if "max_seq_len" not in kwargs:
kwargs["max_seq_len"] = 512

0
utils/__init__.py Normal file
View File

73
utils/export_utils.py Normal file
View File

@ -0,0 +1,73 @@
import os
import torch
def export(
model, quantize: bool = False, opset_version: int = 14, type="onnx", **kwargs
):
model_scripts = model.export(**kwargs)
export_dir = kwargs.get("output_dir", os.path.dirname(kwargs.get("init_param")))
os.makedirs(export_dir, exist_ok=True)
if not isinstance(model_scripts, (list, tuple)):
model_scripts = (model_scripts,)
for m in model_scripts:
m.eval()
if type == "onnx":
_onnx(
m,
quantize=quantize,
opset_version=opset_version,
export_dir=export_dir,
**kwargs,
)
print("output dir: {}".format(export_dir))
return export_dir
def _onnx(
model,
quantize: bool = False,
opset_version: int = 14,
export_dir: str = None,
**kwargs,
):
dummy_input = model.export_dummy_inputs()
verbose = kwargs.get("verbose", False)
export_name = model.export_name()
model_path = os.path.join(export_dir, export_name)
torch.onnx.export(
model,
dummy_input,
model_path,
verbose=verbose,
opset_version=opset_version,
input_names=model.export_input_names(),
output_names=model.export_output_names(),
dynamic_axes=model.export_dynamic_axes(),
)
if quantize:
from onnxruntime.quantization import QuantType, quantize_dynamic
import onnx
quant_model_path = model_path.replace(".onnx", "_quant.onnx")
if not os.path.exists(quant_model_path):
onnx_model = onnx.load(model_path)
nodes = [n.name for n in onnx_model.graph.node]
nodes_to_exclude = [
m for m in nodes if "output" in m or "bias_encoder" in m or "bias_decoder" in m
]
quantize_dynamic(
model_input=model_path,
model_output=quant_model_path,
op_types_to_quantize=["MatMul"],
per_channel=True,
reduce_range=False,
weight_type=QuantType.QUInt8,
nodes_to_exclude=nodes_to_exclude,
)

433
utils/frontend.py Normal file
View File

@ -0,0 +1,433 @@
# -*- encoding: utf-8 -*-
from pathlib import Path
from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union
import copy
import numpy as np
import kaldi_native_fbank as knf
root_dir = Path(__file__).resolve().parent
logger_initialized = {}
class WavFrontend:
"""Conventional frontend structure for ASR."""
def __init__(
self,
cmvn_file: str = None,
fs: int = 16000,
window: str = "hamming",
n_mels: int = 80,
frame_length: int = 25,
frame_shift: int = 10,
lfr_m: int = 1,
lfr_n: int = 1,
dither: float = 1.0,
**kwargs,
) -> None:
opts = knf.FbankOptions()
opts.frame_opts.samp_freq = fs
opts.frame_opts.dither = dither
opts.frame_opts.window_type = window
opts.frame_opts.frame_shift_ms = float(frame_shift)
opts.frame_opts.frame_length_ms = float(frame_length)
opts.mel_opts.num_bins = n_mels
opts.energy_floor = 0
opts.frame_opts.snip_edges = True
opts.mel_opts.debug_mel = False
self.opts = opts
self.lfr_m = lfr_m
self.lfr_n = lfr_n
self.cmvn_file = cmvn_file
if self.cmvn_file:
self.cmvn = self.load_cmvn()
self.fbank_fn = None
self.fbank_beg_idx = 0
self.reset_status()
def fbank(self, waveform: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
waveform = waveform * (1 << 15)
self.fbank_fn = knf.OnlineFbank(self.opts)
self.fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist())
frames = self.fbank_fn.num_frames_ready
mat = np.empty([frames, self.opts.mel_opts.num_bins])
for i in range(frames):
mat[i, :] = self.fbank_fn.get_frame(i)
feat = mat.astype(np.float32)
feat_len = np.array(mat.shape[0]).astype(np.int32)
return feat, feat_len
def fbank_online(self, waveform: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
waveform = waveform * (1 << 15)
# self.fbank_fn = knf.OnlineFbank(self.opts)
self.fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist())
frames = self.fbank_fn.num_frames_ready
mat = np.empty([frames, self.opts.mel_opts.num_bins])
for i in range(self.fbank_beg_idx, frames):
mat[i, :] = self.fbank_fn.get_frame(i)
# self.fbank_beg_idx += (frames-self.fbank_beg_idx)
feat = mat.astype(np.float32)
feat_len = np.array(mat.shape[0]).astype(np.int32)
return feat, feat_len
def reset_status(self):
self.fbank_fn = knf.OnlineFbank(self.opts)
self.fbank_beg_idx = 0
def lfr_cmvn(self, feat: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
if self.lfr_m != 1 or self.lfr_n != 1:
feat = self.apply_lfr(feat, self.lfr_m, self.lfr_n)
if self.cmvn_file:
feat = self.apply_cmvn(feat)
feat_len = np.array(feat.shape[0]).astype(np.int32)
return feat, feat_len
@staticmethod
def apply_lfr(inputs: np.ndarray, lfr_m: int, lfr_n: int) -> np.ndarray:
LFR_inputs = []
T = inputs.shape[0]
T_lfr = int(np.ceil(T / lfr_n))
left_padding = np.tile(inputs[0], ((lfr_m - 1) // 2, 1))
inputs = np.vstack((left_padding, inputs))
T = T + (lfr_m - 1) // 2
for i in range(T_lfr):
if lfr_m <= T - i * lfr_n:
LFR_inputs.append((inputs[i * lfr_n : i * lfr_n + lfr_m]).reshape(1, -1))
else:
# process last LFR frame
num_padding = lfr_m - (T - i * lfr_n)
frame = inputs[i * lfr_n :].reshape(-1)
for _ in range(num_padding):
frame = np.hstack((frame, inputs[-1]))
LFR_inputs.append(frame)
LFR_outputs = np.vstack(LFR_inputs).astype(np.float32)
return LFR_outputs
def apply_cmvn(self, inputs: np.ndarray) -> np.ndarray:
"""
Apply CMVN with mvn data
"""
frame, dim = inputs.shape
means = np.tile(self.cmvn[0:1, :dim], (frame, 1))
vars = np.tile(self.cmvn[1:2, :dim], (frame, 1))
inputs = (inputs + means) * vars
return inputs
def load_cmvn(
self,
) -> np.ndarray:
with open(self.cmvn_file, "r", encoding="utf-8") as f:
lines = f.readlines()
means_list = []
vars_list = []
for i in range(len(lines)):
line_item = lines[i].split()
if line_item[0] == "<AddShift>":
line_item = lines[i + 1].split()
if line_item[0] == "<LearnRateCoef>":
add_shift_line = line_item[3 : (len(line_item) - 1)]
means_list = list(add_shift_line)
continue
elif line_item[0] == "<Rescale>":
line_item = lines[i + 1].split()
if line_item[0] == "<LearnRateCoef>":
rescale_line = line_item[3 : (len(line_item) - 1)]
vars_list = list(rescale_line)
continue
means = np.array(means_list).astype(np.float64)
vars = np.array(vars_list).astype(np.float64)
cmvn = np.array([means, vars])
return cmvn
class WavFrontendOnline(WavFrontend):
def __init__(self, **kwargs):
super().__init__(**kwargs)
# self.fbank_fn = knf.OnlineFbank(self.opts)
# add variables
self.frame_sample_length = int(
self.opts.frame_opts.frame_length_ms * self.opts.frame_opts.samp_freq / 1000
)
self.frame_shift_sample_length = int(
self.opts.frame_opts.frame_shift_ms * self.opts.frame_opts.samp_freq / 1000
)
self.waveform = None
self.reserve_waveforms = None
self.input_cache = None
self.lfr_splice_cache = []
@staticmethod
# inputs has catted the cache
def apply_lfr(
inputs: np.ndarray, lfr_m: int, lfr_n: int, is_final: bool = False
) -> Tuple[np.ndarray, np.ndarray, int]:
"""
Apply lfr with data
"""
LFR_inputs = []
T = inputs.shape[0] # include the right context
T_lfr = int(
np.ceil((T - (lfr_m - 1) // 2) / lfr_n)
) # minus the right context: (lfr_m - 1) // 2
splice_idx = T_lfr
for i in range(T_lfr):
if lfr_m <= T - i * lfr_n:
LFR_inputs.append((inputs[i * lfr_n : i * lfr_n + lfr_m]).reshape(1, -1))
else: # process last LFR frame
if is_final:
num_padding = lfr_m - (T - i * lfr_n)
frame = (inputs[i * lfr_n :]).reshape(-1)
for _ in range(num_padding):
frame = np.hstack((frame, inputs[-1]))
LFR_inputs.append(frame)
else:
# update splice_idx and break the circle
splice_idx = i
break
splice_idx = min(T - 1, splice_idx * lfr_n)
lfr_splice_cache = inputs[splice_idx:, :]
LFR_outputs = np.vstack(LFR_inputs)
return LFR_outputs.astype(np.float32), lfr_splice_cache, splice_idx
@staticmethod
def compute_frame_num(
sample_length: int, frame_sample_length: int, frame_shift_sample_length: int
) -> int:
frame_num = int((sample_length - frame_sample_length) / frame_shift_sample_length + 1)
return frame_num if frame_num >= 1 and sample_length >= frame_sample_length else 0
def fbank(
self, input: np.ndarray, input_lengths: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
self.fbank_fn = knf.OnlineFbank(self.opts)
batch_size = input.shape[0]
if self.input_cache is None:
self.input_cache = np.empty((batch_size, 0), dtype=np.float32)
input = np.concatenate((self.input_cache, input), axis=1)
frame_num = self.compute_frame_num(
input.shape[-1], self.frame_sample_length, self.frame_shift_sample_length
)
# update self.in_cache
self.input_cache = input[
:, -(input.shape[-1] - frame_num * self.frame_shift_sample_length) :
]
waveforms = np.empty(0, dtype=np.float32)
feats_pad = np.empty(0, dtype=np.float32)
feats_lens = np.empty(0, dtype=np.int32)
if frame_num:
waveforms = []
feats = []
feats_lens = []
for i in range(batch_size):
waveform = input[i]
waveforms.append(
waveform[
: (
(frame_num - 1) * self.frame_shift_sample_length
+ self.frame_sample_length
)
]
)
waveform = waveform * (1 << 15)
self.fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist())
frames = self.fbank_fn.num_frames_ready
mat = np.empty([frames, self.opts.mel_opts.num_bins])
for i in range(frames):
mat[i, :] = self.fbank_fn.get_frame(i)
feat = mat.astype(np.float32)
feat_len = np.array(mat.shape[0]).astype(np.int32)
feats.append(feat)
feats_lens.append(feat_len)
waveforms = np.stack(waveforms)
feats_lens = np.array(feats_lens)
feats_pad = np.array(feats)
self.fbanks = feats_pad
self.fbanks_lens = copy.deepcopy(feats_lens)
return waveforms, feats_pad, feats_lens
def get_fbank(self) -> Tuple[np.ndarray, np.ndarray]:
return self.fbanks, self.fbanks_lens
def lfr_cmvn(
self, input: np.ndarray, input_lengths: np.ndarray, is_final: bool = False
) -> Tuple[np.ndarray, np.ndarray, List[int]]:
batch_size = input.shape[0]
feats = []
feats_lens = []
lfr_splice_frame_idxs = []
for i in range(batch_size):
mat = input[i, : input_lengths[i], :]
lfr_splice_frame_idx = -1
if self.lfr_m != 1 or self.lfr_n != 1:
# update self.lfr_splice_cache in self.apply_lfr
mat, self.lfr_splice_cache[i], lfr_splice_frame_idx = self.apply_lfr(
mat, self.lfr_m, self.lfr_n, is_final
)
if self.cmvn_file is not None:
mat = self.apply_cmvn(mat)
feat_length = mat.shape[0]
feats.append(mat)
feats_lens.append(feat_length)
lfr_splice_frame_idxs.append(lfr_splice_frame_idx)
feats_lens = np.array(feats_lens)
feats_pad = np.array(feats)
return feats_pad, feats_lens, lfr_splice_frame_idxs
def extract_fbank(
self, input: np.ndarray, input_lengths: np.ndarray, is_final: bool = False
) -> Tuple[np.ndarray, np.ndarray]:
batch_size = input.shape[0]
assert (
batch_size == 1
), "we support to extract feature online only when the batch size is equal to 1 now"
waveforms, feats, feats_lengths = self.fbank(input, input_lengths) # input shape: B T D
if feats.shape[0]:
self.waveforms = (
waveforms
if self.reserve_waveforms is None
else np.concatenate((self.reserve_waveforms, waveforms), axis=1)
)
if not self.lfr_splice_cache:
for i in range(batch_size):
self.lfr_splice_cache.append(
np.expand_dims(feats[i][0, :], axis=0).repeat((self.lfr_m - 1) // 2, axis=0)
)
if feats_lengths[0] + self.lfr_splice_cache[0].shape[0] >= self.lfr_m:
lfr_splice_cache_np = np.stack(self.lfr_splice_cache) # B T D
feats = np.concatenate((lfr_splice_cache_np, feats), axis=1)
feats_lengths += lfr_splice_cache_np[0].shape[0]
frame_from_waveforms = int(
(self.waveforms.shape[1] - self.frame_sample_length)
/ self.frame_shift_sample_length
+ 1
)
minus_frame = (self.lfr_m - 1) // 2 if self.reserve_waveforms is None else 0
feats, feats_lengths, lfr_splice_frame_idxs = self.lfr_cmvn(
feats, feats_lengths, is_final
)
if self.lfr_m == 1:
self.reserve_waveforms = None
else:
reserve_frame_idx = lfr_splice_frame_idxs[0] - minus_frame
# print('reserve_frame_idx: ' + str(reserve_frame_idx))
# print('frame_frame: ' + str(frame_from_waveforms))
self.reserve_waveforms = self.waveforms[
:,
reserve_frame_idx
* self.frame_shift_sample_length : frame_from_waveforms
* self.frame_shift_sample_length,
]
sample_length = (
frame_from_waveforms - 1
) * self.frame_shift_sample_length + self.frame_sample_length
self.waveforms = self.waveforms[:, :sample_length]
else:
# update self.reserve_waveforms and self.lfr_splice_cache
self.reserve_waveforms = self.waveforms[
:, : -(self.frame_sample_length - self.frame_shift_sample_length)
]
for i in range(batch_size):
self.lfr_splice_cache[i] = np.concatenate(
(self.lfr_splice_cache[i], feats[i]), axis=0
)
return np.empty(0, dtype=np.float32), feats_lengths
else:
if is_final:
self.waveforms = (
waveforms if self.reserve_waveforms is None else self.reserve_waveforms
)
feats = np.stack(self.lfr_splice_cache)
feats_lengths = np.zeros(batch_size, dtype=np.int32) + feats.shape[1]
feats, feats_lengths, _ = self.lfr_cmvn(feats, feats_lengths, is_final)
if is_final:
self.cache_reset()
return feats, feats_lengths
def get_waveforms(self):
return self.waveforms
def cache_reset(self):
self.fbank_fn = knf.OnlineFbank(self.opts)
self.reserve_waveforms = None
self.input_cache = None
self.lfr_splice_cache = []
def load_bytes(input):
middle_data = np.frombuffer(input, dtype=np.int16)
middle_data = np.asarray(middle_data)
if middle_data.dtype.kind not in "iu":
raise TypeError("'middle_data' must be an array of integers")
dtype = np.dtype("float32")
if dtype.kind != "f":
raise TypeError("'dtype' must be a floating point type")
i = np.iinfo(middle_data.dtype)
abs_max = 2 ** (i.bits - 1)
offset = i.min + abs_max
array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32)
return array
class SinusoidalPositionEncoderOnline:
"""Streaming Positional encoding."""
def encode(self, positions: np.ndarray = None, depth: int = None, dtype: np.dtype = np.float32):
batch_size = positions.shape[0]
positions = positions.astype(dtype)
log_timescale_increment = np.log(np.array([10000], dtype=dtype)) / (depth / 2 - 1)
inv_timescales = np.exp(np.arange(depth / 2).astype(dtype) * (-log_timescale_increment))
inv_timescales = np.reshape(inv_timescales, [batch_size, -1])
scaled_time = np.reshape(positions, [1, -1, 1]) * np.reshape(inv_timescales, [1, 1, -1])
encoding = np.concatenate((np.sin(scaled_time), np.cos(scaled_time)), axis=2)
return encoding.astype(dtype)
def forward(self, x, start_idx=0):
batch_size, timesteps, input_dim = x.shape
positions = np.arange(1, timesteps + 1 + start_idx)[None, :]
position_encoding = self.encode(positions, input_dim, x.dtype)
return x + position_encoding[:, start_idx : start_idx + timesteps]
def test():
path = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/example/asr_example.wav"
import librosa
cmvn_file = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/am.mvn"
config_file = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/config.yaml"
from funasr.runtime.python.onnxruntime.rapid_paraformer.utils.utils import read_yaml
config = read_yaml(config_file)
waveform, _ = librosa.load(path, sr=None)
frontend = WavFrontend(
cmvn_file=cmvn_file,
**config["frontend_conf"],
)
speech, _ = frontend.fbank_online(waveform) # 1d, (sample,), numpy
feat, feat_len = frontend.lfr_cmvn(
speech
) # 2d, (frame, 450), np.float32 -> torch, torch.from_numpy(), dtype, (1, frame, 450)
frontend.reset_status() # clear cache
return feat, feat_len
if __name__ == "__main__":
test()

395
utils/infer_utils.py Normal file
View File

@ -0,0 +1,395 @@
# -*- encoding: utf-8 -*-
import functools
import logging
from pathlib import Path
from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union
import re
import numpy as np
import yaml
try:
from onnxruntime import (
GraphOptimizationLevel,
InferenceSession,
SessionOptions,
get_available_providers,
get_device,
)
except:
print("please pip3 install onnxruntime")
import jieba
import warnings
root_dir = Path(__file__).resolve().parent
logger_initialized = {}
def pad_list(xs, pad_value, max_len=None):
n_batch = len(xs)
if max_len is None:
max_len = max(x.size(0) for x in xs)
# pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value)
# numpy format
pad = (np.zeros((n_batch, max_len)) + pad_value).astype(np.int32)
for i in range(n_batch):
pad[i, : xs[i].shape[0]] = xs[i]
return pad
"""
def make_pad_mask(lengths, xs=None, length_dim=-1, maxlen=None):
if length_dim == 0:
raise ValueError("length_dim cannot be 0: {}".format(length_dim))
if not isinstance(lengths, list):
lengths = lengths.tolist()
bs = int(len(lengths))
if maxlen is None:
if xs is None:
maxlen = int(max(lengths))
else:
maxlen = xs.size(length_dim)
else:
assert xs is None
assert maxlen >= int(max(lengths))
seq_range = torch.arange(0, maxlen, dtype=torch.int64)
seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen)
seq_length_expand = seq_range_expand.new(lengths).unsqueeze(-1)
mask = seq_range_expand >= seq_length_expand
if xs is not None:
assert xs.size(0) == bs, (xs.size(0), bs)
if length_dim < 0:
length_dim = xs.dim() + length_dim
# ind = (:, None, ..., None, :, , None, ..., None)
ind = tuple(
slice(None) if i in (0, length_dim) else None for i in range(xs.dim())
)
mask = mask[ind].expand_as(xs).to(xs.device)
return mask
"""
class TokenIDConverter:
def __init__(
self,
token_list: Union[List, str],
):
self.token_list = token_list
self.unk_symbol = token_list[-1]
self.token2id = {v: i for i, v in enumerate(self.token_list)}
self.unk_id = self.token2id[self.unk_symbol]
def get_num_vocabulary_size(self) -> int:
return len(self.token_list)
def ids2tokens(self, integers: Union[np.ndarray, Iterable[int]]) -> List[str]:
if isinstance(integers, np.ndarray) and integers.ndim != 1:
raise TokenIDConverterError(f"Must be 1 dim ndarray, but got {integers.ndim}")
return [self.token_list[i] for i in integers]
def tokens2ids(self, tokens: Iterable[str]) -> List[int]:
return [self.token2id.get(i, self.unk_id) for i in tokens]
class CharTokenizer:
def __init__(
self,
symbol_value: Union[Path, str, Iterable[str]] = None,
space_symbol: str = "<space>",
remove_non_linguistic_symbols: bool = False,
):
self.space_symbol = space_symbol
self.non_linguistic_symbols = self.load_symbols(symbol_value)
self.remove_non_linguistic_symbols = remove_non_linguistic_symbols
@staticmethod
def load_symbols(value: Union[Path, str, Iterable[str]] = None) -> Set:
if value is None:
return set()
if isinstance(value, Iterable[str]):
return set(value)
file_path = Path(value)
if not file_path.exists():
logging.warning("%s doesn't exist.", file_path)
return set()
with file_path.open("r", encoding="utf-8") as f:
return set(line.rstrip() for line in f)
def text2tokens(self, line: Union[str, list]) -> List[str]:
tokens = []
while len(line) != 0:
for w in self.non_linguistic_symbols:
if line.startswith(w):
if not self.remove_non_linguistic_symbols:
tokens.append(line[: len(w)])
line = line[len(w) :]
break
else:
t = line[0]
if t == " ":
t = "<space>"
tokens.append(t)
line = line[1:]
return tokens
def tokens2text(self, tokens: Iterable[str]) -> str:
tokens = [t if t != self.space_symbol else " " for t in tokens]
return "".join(tokens)
def __repr__(self):
return (
f"{self.__class__.__name__}("
f'space_symbol="{self.space_symbol}"'
f'non_linguistic_symbols="{self.non_linguistic_symbols}"'
f")"
)
class Hypothesis(NamedTuple):
"""Hypothesis data type."""
yseq: np.ndarray
score: Union[float, np.ndarray] = 0
scores: Dict[str, Union[float, np.ndarray]] = dict()
states: Dict[str, Any] = dict()
def asdict(self) -> dict:
"""Convert data to JSON-friendly dict."""
return self._replace(
yseq=self.yseq.tolist(),
score=float(self.score),
scores={k: float(v) for k, v in self.scores.items()},
)._asdict()
class TokenIDConverterError(Exception):
pass
class ONNXRuntimeError(Exception):
pass
class OrtInferSession:
def __init__(self, model_file, device_id=-1, intra_op_num_threads=4):
device_id = str(device_id)
sess_opt = SessionOptions()
sess_opt.intra_op_num_threads = intra_op_num_threads
sess_opt.log_severity_level = 4
sess_opt.enable_cpu_mem_arena = False
sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
cuda_ep = "CUDAExecutionProvider"
cuda_provider_options = {
"device_id": device_id,
"arena_extend_strategy": "kNextPowerOfTwo",
"cudnn_conv_algo_search": "EXHAUSTIVE",
"do_copy_in_default_stream": "true",
}
cpu_ep = "CPUExecutionProvider"
cpu_provider_options = {
"arena_extend_strategy": "kSameAsRequested",
}
EP_list = []
if device_id != "-1" and get_device() == "GPU" and cuda_ep in get_available_providers():
EP_list = [(cuda_ep, cuda_provider_options)]
EP_list.append((cpu_ep, cpu_provider_options))
self._verify_model(model_file)
self.session = InferenceSession(model_file, sess_options=sess_opt, providers=EP_list)
if device_id != "-1" and cuda_ep not in self.session.get_providers():
warnings.warn(
f"{cuda_ep} is not avaiable for current env, the inference part is automatically shifted to be executed under {cpu_ep}.\n"
"Please ensure the installed onnxruntime-gpu version matches your cuda and cudnn version, "
"you can check their relations from the offical web site: "
"https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html",
RuntimeWarning,
)
def __call__(self, input_content: List[Union[np.ndarray, np.ndarray]]) -> np.ndarray:
input_dict = dict(zip(self.get_input_names(), input_content))
try:
return self.session.run(self.get_output_names(), input_dict)
except Exception as e:
raise ONNXRuntimeError("ONNXRuntime inferece failed.") from e
def get_input_names(
self,
):
return [v.name for v in self.session.get_inputs()]
def get_output_names(
self,
):
return [v.name for v in self.session.get_outputs()]
def get_character_list(self, key: str = "character"):
return self.meta_dict[key].splitlines()
def have_key(self, key: str = "character") -> bool:
self.meta_dict = self.session.get_modelmeta().custom_metadata_map
if key in self.meta_dict.keys():
return True
return False
@staticmethod
def _verify_model(model_path):
model_path = Path(model_path)
if not model_path.exists():
raise FileNotFoundError(f"{model_path} does not exists.")
if not model_path.is_file():
raise FileExistsError(f"{model_path} is not a file.")
def split_to_mini_sentence(words: list, word_limit: int = 20):
assert word_limit > 1
if len(words) <= word_limit:
return [words]
sentences = []
length = len(words)
sentence_len = length // word_limit
for i in range(sentence_len):
sentences.append(words[i * word_limit : (i + 1) * word_limit])
if length % word_limit > 0:
sentences.append(words[sentence_len * word_limit :])
return sentences
def code_mix_split_words(text: str):
words = []
segs = text.split()
for seg in segs:
# There is no space in seg.
current_word = ""
for c in seg:
if len(c.encode()) == 1:
# This is an ASCII char.
current_word += c
else:
# This is a Chinese char.
if len(current_word) > 0:
words.append(current_word)
current_word = ""
words.append(c)
if len(current_word) > 0:
words.append(current_word)
return words
def isEnglish(text: str):
if re.search("^[a-zA-Z']+$", text):
return True
else:
return False
def join_chinese_and_english(input_list):
line = ""
for token in input_list:
if isEnglish(token):
line = line + " " + token
else:
line = line + token
line = line.strip()
return line
def code_mix_split_words_jieba(seg_dict_file: str):
jieba.load_userdict(seg_dict_file)
def _fn(text: str):
input_list = text.split()
token_list_all = []
langauge_list = []
token_list_tmp = []
language_flag = None
for token in input_list:
if isEnglish(token) and language_flag == "Chinese":
token_list_all.append(token_list_tmp)
langauge_list.append("Chinese")
token_list_tmp = []
elif not isEnglish(token) and language_flag == "English":
token_list_all.append(token_list_tmp)
langauge_list.append("English")
token_list_tmp = []
token_list_tmp.append(token)
if isEnglish(token):
language_flag = "English"
else:
language_flag = "Chinese"
if token_list_tmp:
token_list_all.append(token_list_tmp)
langauge_list.append(language_flag)
result_list = []
for token_list_tmp, language_flag in zip(token_list_all, langauge_list):
if language_flag == "English":
result_list.extend(token_list_tmp)
else:
seg_list = jieba.cut(join_chinese_and_english(token_list_tmp), HMM=False)
result_list.extend(seg_list)
return result_list
return _fn
def read_yaml(yaml_path: Union[str, Path]) -> Dict:
if not Path(yaml_path).exists():
raise FileExistsError(f"The {yaml_path} does not exist.")
with open(str(yaml_path), "rb") as f:
data = yaml.load(f, Loader=yaml.Loader)
return data
@functools.lru_cache()
def get_logger(name="funasr_onnx"):
"""Initialize and get a logger by name.
If the logger has not been initialized, this method will initialize the
logger by adding one or two handlers, otherwise the initialized logger will
be directly returned. During initialization, a StreamHandler will always be
added.
Args:
name (str): Logger name.
Returns:
logging.Logger: The expected logger.
"""
logger = logging.getLogger(name)
if name in logger_initialized:
return logger
for logger_name in logger_initialized:
if name.startswith(logger_name):
return logger
formatter = logging.Formatter(
"[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%Y/%m/%d %H:%M:%S"
)
sh = logging.StreamHandler()
sh.setFormatter(formatter)
logger.addHandler(sh)
logger_initialized[name] = True
logger.propagate = False
logging.basicConfig(level=logging.ERROR)
return logger

146
utils/model_bin.py Normal file
View File

@ -0,0 +1,146 @@
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
import os.path
from pathlib import Path
from typing import List, Union, Tuple
import torch
import librosa
import numpy as np
from utils.infer_utils import (
CharTokenizer,
Hypothesis,
ONNXRuntimeError,
OrtInferSession,
TokenIDConverter,
get_logger,
read_yaml,
)
# from .utils.postprocess_utils import sentence_postprocess, sentence_postprocess_sentencepiece
from utils.frontend import WavFrontend
# from .utils.timestamp_utils import time_stamp_lfr6_onnx
from utils.infer_utils import pad_list
logging = get_logger()
class SenseVoiceSmallONNX:
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
https://arxiv.org/abs/2206.08317
"""
def __init__(
self,
model_dir: Union[str, Path] = None,
batch_size: int = 1,
device_id: Union[str, int] = "-1",
plot_timestamp_to: str = "",
quantize: bool = False,
intra_op_num_threads: int = 4,
cache_dir: str = None,
**kwargs,
):
if quantize:
model_file = os.path.join(model_dir, "model_quant.onnx")
else:
model_file = os.path.join(model_dir, "model.onnx")
config_file = os.path.join(model_dir, "config.yaml")
cmvn_file = os.path.join(model_dir, "am.mvn")
config = read_yaml(config_file)
# token_list = os.path.join(model_dir, "tokens.json")
# with open(token_list, "r", encoding="utf-8") as f:
# token_list = json.load(f)
# self.converter = TokenIDConverter(token_list)
self.tokenizer = CharTokenizer()
config["frontend_conf"]['cmvn_file'] = cmvn_file
self.frontend = WavFrontend(**config["frontend_conf"])
self.ort_infer = OrtInferSession(
model_file, device_id, intra_op_num_threads=intra_op_num_threads
)
self.batch_size = batch_size
self.blank_id = 0
def __call__(self,
wav_content: Union[str, np.ndarray, List[str]],
language: List,
textnorm: List,
tokenizer=None,
**kwargs) -> List:
waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq)
waveform_nums = len(waveform_list)
asr_res = []
for beg_idx in range(0, waveform_nums, self.batch_size):
end_idx = min(waveform_nums, beg_idx + self.batch_size)
feats, feats_len = self.extract_feat(waveform_list[beg_idx:end_idx])
ctc_logits, encoder_out_lens = self.infer(feats,
feats_len,
np.array(language, dtype=np.int32),
np.array(textnorm, dtype=np.int32)
)
# back to torch.Tensor
ctc_logits = torch.from_numpy(ctc_logits).float()
# support batch_size=1 only currently
x = ctc_logits[0, : encoder_out_lens[0].item(), :]
yseq = x.argmax(dim=-1)
yseq = torch.unique_consecutive(yseq, dim=-1)
mask = yseq != self.blank_id
token_int = yseq[mask].tolist()
if tokenizer is not None:
asr_res.append(tokenizer.tokens2text(token_int))
else:
asr_res.append(token_int)
return asr_res
def load_data(self, wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List:
def load_wav(path: str) -> np.ndarray:
waveform, _ = librosa.load(path, sr=fs)
return waveform
if isinstance(wav_content, np.ndarray):
return [wav_content]
if isinstance(wav_content, str):
return [load_wav(wav_content)]
if isinstance(wav_content, list):
return [load_wav(path) for path in wav_content]
raise TypeError(f"The type of {wav_content} is not in [str, np.ndarray, list]")
def extract_feat(self, waveform_list: List[np.ndarray]) -> Tuple[np.ndarray, np.ndarray]:
feats, feats_len = [], []
for waveform in waveform_list:
speech, _ = self.frontend.fbank(waveform)
feat, feat_len = self.frontend.lfr_cmvn(speech)
feats.append(feat)
feats_len.append(feat_len)
feats = self.pad_feats(feats, np.max(feats_len))
feats_len = np.array(feats_len).astype(np.int32)
return feats, feats_len
@staticmethod
def pad_feats(feats: List[np.ndarray], max_feat_len: int) -> np.ndarray:
def pad_feat(feat: np.ndarray, cur_len: int) -> np.ndarray:
pad_width = ((0, max_feat_len - cur_len), (0, 0))
return np.pad(feat, pad_width, "constant", constant_values=0)
feat_res = [pad_feat(feat, feat.shape[0]) for feat in feats]
feats = np.array(feat_res).astype(np.float32)
return feats
def infer(self,
feats: np.ndarray,
feats_len: np.ndarray,
language: np.ndarray,
textnorm: np.ndarray,) -> Tuple[np.ndarray, np.ndarray]:
outputs = self.ort_infer([feats, feats_len, language, textnorm])
return outputs