diff --git a/demo1.py b/demo1.py index 5a1bdc8..ec8797f 100644 --- a/demo1.py +++ b/demo1.py @@ -1,9 +1,3 @@ -#!/usr/bin/env python3 -# -*- encoding: utf-8 -*- -# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. -# MIT License (https://opensource.org/licenses/MIT) - - from funasr import AutoModel from funasr.utils.postprocess_utils import rich_transcription_postprocess diff --git a/demo2.py b/demo2.py index 0ffbf13..6f08579 100644 --- a/demo2.py +++ b/demo2.py @@ -1,8 +1,3 @@ -#!/usr/bin/env python3 -# -*- encoding: utf-8 -*- -# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. -# MIT License (https://opensource.org/licenses/MIT) - from model import SenseVoiceSmall from funasr.utils.postprocess_utils import rich_transcription_postprocess diff --git a/export.py b/export.py new file mode 100644 index 0000000..9114ff1 --- /dev/null +++ b/export.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python3 +# -*- encoding: utf-8 -*- +# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. +# MIT License (https://opensource.org/licenses/MIT) + +import os +import torch +from model import SenseVoiceSmall +from utils import export_utils +from utils.model_bin import SenseVoiceSmallONNX +from funasr.utils.postprocess_utils import rich_transcription_postprocess + +quantize = False + +model_dir = "iic/SenseVoiceSmall" +model, kwargs = SenseVoiceSmall.from_pretrained(model=model_dir, device="cuda:0") + +rebuilt_model = model.export(type="onnx", quantize=False) +model_path = kwargs.get("output_dir", os.path.dirname(kwargs.get("init_param"))) + +model_file = os.path.join(model_path, "model.onnx") +if quantize: + model_file = os.path.join(model_path, "model_quant.onnx") + +# export model +if not os.path.exists(model_file): + with torch.no_grad(): + del kwargs['model'] + export_dir = export_utils.export(model=rebuilt_model, **kwargs) + print("Export model onnx to {}".format(model_file)) + +# export model init +model_bin = SenseVoiceSmallONNX(model_path) + +# build tokenizer +try: + from funasr.tokenizer.sentencepiece_tokenizer import SentencepiecesTokenizer + tokenizer = SentencepiecesTokenizer(bpemodel=os.path.join(model_path, "chn_jpn_yue_eng_ko_spectok.bpe.model")) +except: + tokenizer = None + +# inference +wav_or_scp = "/Users/shixian/Downloads/asr_example_hotword.wav" +language_list = [0] +textnorm_list = [15] +res = model_bin(wav_or_scp, language_list, textnorm_list, tokenizer=tokenizer) +print(res) diff --git a/export_meta.py b/export_meta.py index 575dafe..920ca64 100644 --- a/export_meta.py +++ b/export_meta.py @@ -5,33 +5,20 @@ import types import torch -import torch.nn as nn -from funasr.register import tables +from funasr.utils.torch_function import sequence_mask def export_rebuild_model(model, **kwargs): model.device = kwargs.get("device") - is_onnx = kwargs.get("type", "onnx") == "onnx" - # encoder_class = tables.encoder_classes.get(kwargs["encoder"] + "Export") - # model.encoder = encoder_class(model.encoder, onnx=is_onnx) - - - - from funasr.utils.torch_function import sequence_mask - model.make_pad_mask = sequence_mask(kwargs["max_seq_len"], flip=False) - model.forward = types.MethodType(export_forward, model) model.export_dummy_inputs = types.MethodType(export_dummy_inputs, model) model.export_input_names = types.MethodType(export_input_names, model) model.export_output_names = types.MethodType(export_output_names, model) model.export_dynamic_axes = types.MethodType(export_dynamic_axes, model) 
model.export_name = types.MethodType(export_name, model) - - model.export_name = "model" return model - def export_forward( self, speech: torch.Tensor, @@ -40,12 +27,11 @@ def export_forward( textnorm: torch.Tensor, **kwargs, ): - speech = speech.to(device=kwargs["device"]) - speech_lengths = speech_lengths.to(device=kwargs["device"]) - - language_query = self.embed(language).to(speech.device) - - textnorm_query = self.embed(textnorm).to(speech.device) + # speech = speech.to(device="cuda") + # speech_lengths = speech_lengths.to(device="cuda") + language_query = self.embed(language.to(speech.device)).unsqueeze(1) + textnorm_query = self.embed(textnorm.to(speech.device)).unsqueeze(1) + print(textnorm_query.shape, speech.shape) speech = torch.cat((textnorm_query, speech), dim=1) speech_lengths += 1 @@ -56,18 +42,14 @@ def export_forward( speech = torch.cat((input_query, speech), dim=1) speech_lengths += 3 - # Encoder encoder_out, encoder_out_lens = self.encoder(speech, speech_lengths) if isinstance(encoder_out, tuple): encoder_out = encoder_out[0] - - # c. Passed the encoder result and the beam search - ctc_logits = self.ctc.log_softmax(encoder_out) - + + ctc_logits = self.ctc.ctc_lo(encoder_out) return ctc_logits, encoder_out_lens - def export_dummy_inputs(self): speech = torch.randn(2, 30, 560) speech_lengths = torch.tensor([6, 30], dtype=torch.int32) @@ -75,27 +57,22 @@ def export_dummy_inputs(self): textnorm = torch.tensor([15, 15], dtype=torch.int32) return (speech, speech_lengths, language, textnorm) - def export_input_names(self): return ["speech", "speech_lengths", "language", "textnorm"] - def export_output_names(self): return ["ctc_logits", "encoder_out_lens"] - def export_dynamic_axes(self): return { "speech": {0: "batch_size", 1: "feats_length"}, - "speech_lengths": { - 0: "batch_size", - }, - "logits": {0: "batch_size", 1: "logits_length"}, + "speech_lengths": {0: "batch_size"}, + "language": {0: "batch_size"}, + "textnorm": {0: "batch_size"}, + "ctc_logits": {0: "batch_size", 1: "logits_length"}, + "encoder_out_lens": {0: "batch_size"}, } - -def export_name( - self, -): +def export_name(self): return "model.onnx" diff --git a/model.py b/model.py index 5bca0c2..5ac1107 100644 --- a/model.py +++ b/model.py @@ -1,24 +1,18 @@ -from typing import Iterable, Optional -import types -import time -import numpy as np -import torch -import torch.nn.functional as F -from torch import Tensor -from torch import nn -from torch.cuda.amp import autocast -from funasr.metrics.compute_acc import compute_accuracy, th_accuracy -from funasr.losses.label_smoothing_loss import LabelSmoothingLoss -from funasr.train_utils.device_funcs import force_gatherable -from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank -from funasr.utils.datadir_writer import DatadirWriter -from funasr.models.ctc.ctc import CTC +import time +import torch +from torch import nn +import torch.nn.functional as F +from typing import Iterable, Optional from funasr.register import tables - - +from funasr.models.ctc.ctc import CTC +from funasr.utils.datadir_writer import DatadirWriter from funasr.models.paraformer.search import Hypothesis +from funasr.train_utils.device_funcs import force_gatherable +from funasr.losses.label_smoothing_loss import LabelSmoothingLoss +from funasr.metrics.compute_acc import compute_accuracy, th_accuracy +from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank class SinusoidalPositionEncoder(torch.nn.Module): @@ -890,7 +884,7 @@ class 
SenseVoiceSmall(nn.Module): return results, meta_data def export(self, **kwargs): - from .export_meta import export_rebuild_model + from export_meta import export_rebuild_model if "max_seq_len" not in kwargs: kwargs["max_seq_len"] = 512 diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/utils/export_utils.py b/utils/export_utils.py new file mode 100644 index 0000000..f070218 --- /dev/null +++ b/utils/export_utils.py @@ -0,0 +1,73 @@ +import os +import torch + + +def export( + model, quantize: bool = False, opset_version: int = 14, type="onnx", **kwargs +): + model_scripts = model.export(**kwargs) + export_dir = kwargs.get("output_dir", os.path.dirname(kwargs.get("init_param"))) + os.makedirs(export_dir, exist_ok=True) + + if not isinstance(model_scripts, (list, tuple)): + model_scripts = (model_scripts,) + for m in model_scripts: + m.eval() + if type == "onnx": + _onnx( + m, + quantize=quantize, + opset_version=opset_version, + export_dir=export_dir, + **kwargs, + ) + print("output dir: {}".format(export_dir)) + + return export_dir + + +def _onnx( + model, + quantize: bool = False, + opset_version: int = 14, + export_dir: str = None, + **kwargs, +): + + dummy_input = model.export_dummy_inputs() + + verbose = kwargs.get("verbose", False) + + export_name = model.export_name() + model_path = os.path.join(export_dir, export_name) + torch.onnx.export( + model, + dummy_input, + model_path, + verbose=verbose, + opset_version=opset_version, + input_names=model.export_input_names(), + output_names=model.export_output_names(), + dynamic_axes=model.export_dynamic_axes(), + ) + + if quantize: + from onnxruntime.quantization import QuantType, quantize_dynamic + import onnx + + quant_model_path = model_path.replace(".onnx", "_quant.onnx") + if not os.path.exists(quant_model_path): + onnx_model = onnx.load(model_path) + nodes = [n.name for n in onnx_model.graph.node] + nodes_to_exclude = [ + m for m in nodes if "output" in m or "bias_encoder" in m or "bias_decoder" in m + ] + quantize_dynamic( + model_input=model_path, + model_output=quant_model_path, + op_types_to_quantize=["MatMul"], + per_channel=True, + reduce_range=False, + weight_type=QuantType.QUInt8, + nodes_to_exclude=nodes_to_exclude, + ) diff --git a/utils/frontend.py b/utils/frontend.py new file mode 100644 index 0000000..7b38f8b --- /dev/null +++ b/utils/frontend.py @@ -0,0 +1,433 @@ +# -*- encoding: utf-8 -*- +from pathlib import Path +from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union +import copy + +import numpy as np +import kaldi_native_fbank as knf + +root_dir = Path(__file__).resolve().parent + +logger_initialized = {} + + +class WavFrontend: + """Conventional frontend structure for ASR.""" + + def __init__( + self, + cmvn_file: str = None, + fs: int = 16000, + window: str = "hamming", + n_mels: int = 80, + frame_length: int = 25, + frame_shift: int = 10, + lfr_m: int = 1, + lfr_n: int = 1, + dither: float = 1.0, + **kwargs, + ) -> None: + + opts = knf.FbankOptions() + opts.frame_opts.samp_freq = fs + opts.frame_opts.dither = dither + opts.frame_opts.window_type = window + opts.frame_opts.frame_shift_ms = float(frame_shift) + opts.frame_opts.frame_length_ms = float(frame_length) + opts.mel_opts.num_bins = n_mels + opts.energy_floor = 0 + opts.frame_opts.snip_edges = True + opts.mel_opts.debug_mel = False + self.opts = opts + + self.lfr_m = lfr_m + self.lfr_n = lfr_n + self.cmvn_file = cmvn_file + + if self.cmvn_file: + self.cmvn = self.load_cmvn() + 
self.fbank_fn = None + self.fbank_beg_idx = 0 + self.reset_status() + + def fbank(self, waveform: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + waveform = waveform * (1 << 15) + self.fbank_fn = knf.OnlineFbank(self.opts) + self.fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist()) + frames = self.fbank_fn.num_frames_ready + mat = np.empty([frames, self.opts.mel_opts.num_bins]) + for i in range(frames): + mat[i, :] = self.fbank_fn.get_frame(i) + feat = mat.astype(np.float32) + feat_len = np.array(mat.shape[0]).astype(np.int32) + return feat, feat_len + + def fbank_online(self, waveform: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + waveform = waveform * (1 << 15) + # self.fbank_fn = knf.OnlineFbank(self.opts) + self.fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist()) + frames = self.fbank_fn.num_frames_ready + mat = np.empty([frames, self.opts.mel_opts.num_bins]) + for i in range(self.fbank_beg_idx, frames): + mat[i, :] = self.fbank_fn.get_frame(i) + # self.fbank_beg_idx += (frames-self.fbank_beg_idx) + feat = mat.astype(np.float32) + feat_len = np.array(mat.shape[0]).astype(np.int32) + return feat, feat_len + + def reset_status(self): + self.fbank_fn = knf.OnlineFbank(self.opts) + self.fbank_beg_idx = 0 + + def lfr_cmvn(self, feat: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + if self.lfr_m != 1 or self.lfr_n != 1: + feat = self.apply_lfr(feat, self.lfr_m, self.lfr_n) + + if self.cmvn_file: + feat = self.apply_cmvn(feat) + + feat_len = np.array(feat.shape[0]).astype(np.int32) + return feat, feat_len + + @staticmethod + def apply_lfr(inputs: np.ndarray, lfr_m: int, lfr_n: int) -> np.ndarray: + LFR_inputs = [] + + T = inputs.shape[0] + T_lfr = int(np.ceil(T / lfr_n)) + left_padding = np.tile(inputs[0], ((lfr_m - 1) // 2, 1)) + inputs = np.vstack((left_padding, inputs)) + T = T + (lfr_m - 1) // 2 + for i in range(T_lfr): + if lfr_m <= T - i * lfr_n: + LFR_inputs.append((inputs[i * lfr_n : i * lfr_n + lfr_m]).reshape(1, -1)) + else: + # process last LFR frame + num_padding = lfr_m - (T - i * lfr_n) + frame = inputs[i * lfr_n :].reshape(-1) + for _ in range(num_padding): + frame = np.hstack((frame, inputs[-1])) + + LFR_inputs.append(frame) + LFR_outputs = np.vstack(LFR_inputs).astype(np.float32) + return LFR_outputs + + def apply_cmvn(self, inputs: np.ndarray) -> np.ndarray: + """ + Apply CMVN with mvn data + """ + frame, dim = inputs.shape + means = np.tile(self.cmvn[0:1, :dim], (frame, 1)) + vars = np.tile(self.cmvn[1:2, :dim], (frame, 1)) + inputs = (inputs + means) * vars + return inputs + + def load_cmvn( + self, + ) -> np.ndarray: + with open(self.cmvn_file, "r", encoding="utf-8") as f: + lines = f.readlines() + + means_list = [] + vars_list = [] + for i in range(len(lines)): + line_item = lines[i].split() + if line_item[0] == "": + line_item = lines[i + 1].split() + if line_item[0] == "": + add_shift_line = line_item[3 : (len(line_item) - 1)] + means_list = list(add_shift_line) + continue + elif line_item[0] == "": + line_item = lines[i + 1].split() + if line_item[0] == "": + rescale_line = line_item[3 : (len(line_item) - 1)] + vars_list = list(rescale_line) + continue + + means = np.array(means_list).astype(np.float64) + vars = np.array(vars_list).astype(np.float64) + cmvn = np.array([means, vars]) + return cmvn + + +class WavFrontendOnline(WavFrontend): + def __init__(self, **kwargs): + super().__init__(**kwargs) + # self.fbank_fn = knf.OnlineFbank(self.opts) + # add variables + self.frame_sample_length = int( + 
self.opts.frame_opts.frame_length_ms * self.opts.frame_opts.samp_freq / 1000 + ) + self.frame_shift_sample_length = int( + self.opts.frame_opts.frame_shift_ms * self.opts.frame_opts.samp_freq / 1000 + ) + self.waveform = None + self.reserve_waveforms = None + self.input_cache = None + self.lfr_splice_cache = [] + + @staticmethod + # inputs has catted the cache + def apply_lfr( + inputs: np.ndarray, lfr_m: int, lfr_n: int, is_final: bool = False + ) -> Tuple[np.ndarray, np.ndarray, int]: + """ + Apply lfr with data + """ + + LFR_inputs = [] + T = inputs.shape[0] # include the right context + T_lfr = int( + np.ceil((T - (lfr_m - 1) // 2) / lfr_n) + ) # minus the right context: (lfr_m - 1) // 2 + splice_idx = T_lfr + for i in range(T_lfr): + if lfr_m <= T - i * lfr_n: + LFR_inputs.append((inputs[i * lfr_n : i * lfr_n + lfr_m]).reshape(1, -1)) + else: # process last LFR frame + if is_final: + num_padding = lfr_m - (T - i * lfr_n) + frame = (inputs[i * lfr_n :]).reshape(-1) + for _ in range(num_padding): + frame = np.hstack((frame, inputs[-1])) + LFR_inputs.append(frame) + else: + # update splice_idx and break the circle + splice_idx = i + break + splice_idx = min(T - 1, splice_idx * lfr_n) + lfr_splice_cache = inputs[splice_idx:, :] + LFR_outputs = np.vstack(LFR_inputs) + return LFR_outputs.astype(np.float32), lfr_splice_cache, splice_idx + + @staticmethod + def compute_frame_num( + sample_length: int, frame_sample_length: int, frame_shift_sample_length: int + ) -> int: + frame_num = int((sample_length - frame_sample_length) / frame_shift_sample_length + 1) + return frame_num if frame_num >= 1 and sample_length >= frame_sample_length else 0 + + def fbank( + self, input: np.ndarray, input_lengths: np.ndarray + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + self.fbank_fn = knf.OnlineFbank(self.opts) + batch_size = input.shape[0] + if self.input_cache is None: + self.input_cache = np.empty((batch_size, 0), dtype=np.float32) + input = np.concatenate((self.input_cache, input), axis=1) + frame_num = self.compute_frame_num( + input.shape[-1], self.frame_sample_length, self.frame_shift_sample_length + ) + # update self.in_cache + self.input_cache = input[ + :, -(input.shape[-1] - frame_num * self.frame_shift_sample_length) : + ] + waveforms = np.empty(0, dtype=np.float32) + feats_pad = np.empty(0, dtype=np.float32) + feats_lens = np.empty(0, dtype=np.int32) + if frame_num: + waveforms = [] + feats = [] + feats_lens = [] + for i in range(batch_size): + waveform = input[i] + waveforms.append( + waveform[ + : ( + (frame_num - 1) * self.frame_shift_sample_length + + self.frame_sample_length + ) + ] + ) + waveform = waveform * (1 << 15) + + self.fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist()) + frames = self.fbank_fn.num_frames_ready + mat = np.empty([frames, self.opts.mel_opts.num_bins]) + for i in range(frames): + mat[i, :] = self.fbank_fn.get_frame(i) + feat = mat.astype(np.float32) + feat_len = np.array(mat.shape[0]).astype(np.int32) + feats.append(feat) + feats_lens.append(feat_len) + + waveforms = np.stack(waveforms) + feats_lens = np.array(feats_lens) + feats_pad = np.array(feats) + self.fbanks = feats_pad + self.fbanks_lens = copy.deepcopy(feats_lens) + return waveforms, feats_pad, feats_lens + + def get_fbank(self) -> Tuple[np.ndarray, np.ndarray]: + return self.fbanks, self.fbanks_lens + + def lfr_cmvn( + self, input: np.ndarray, input_lengths: np.ndarray, is_final: bool = False + ) -> Tuple[np.ndarray, np.ndarray, List[int]]: + batch_size = input.shape[0] + feats = 
[] + feats_lens = [] + lfr_splice_frame_idxs = [] + for i in range(batch_size): + mat = input[i, : input_lengths[i], :] + lfr_splice_frame_idx = -1 + if self.lfr_m != 1 or self.lfr_n != 1: + # update self.lfr_splice_cache in self.apply_lfr + mat, self.lfr_splice_cache[i], lfr_splice_frame_idx = self.apply_lfr( + mat, self.lfr_m, self.lfr_n, is_final + ) + if self.cmvn_file is not None: + mat = self.apply_cmvn(mat) + feat_length = mat.shape[0] + feats.append(mat) + feats_lens.append(feat_length) + lfr_splice_frame_idxs.append(lfr_splice_frame_idx) + + feats_lens = np.array(feats_lens) + feats_pad = np.array(feats) + return feats_pad, feats_lens, lfr_splice_frame_idxs + + def extract_fbank( + self, input: np.ndarray, input_lengths: np.ndarray, is_final: bool = False + ) -> Tuple[np.ndarray, np.ndarray]: + batch_size = input.shape[0] + assert ( + batch_size == 1 + ), "we support to extract feature online only when the batch size is equal to 1 now" + waveforms, feats, feats_lengths = self.fbank(input, input_lengths) # input shape: B T D + if feats.shape[0]: + self.waveforms = ( + waveforms + if self.reserve_waveforms is None + else np.concatenate((self.reserve_waveforms, waveforms), axis=1) + ) + if not self.lfr_splice_cache: + for i in range(batch_size): + self.lfr_splice_cache.append( + np.expand_dims(feats[i][0, :], axis=0).repeat((self.lfr_m - 1) // 2, axis=0) + ) + + if feats_lengths[0] + self.lfr_splice_cache[0].shape[0] >= self.lfr_m: + lfr_splice_cache_np = np.stack(self.lfr_splice_cache) # B T D + feats = np.concatenate((lfr_splice_cache_np, feats), axis=1) + feats_lengths += lfr_splice_cache_np[0].shape[0] + frame_from_waveforms = int( + (self.waveforms.shape[1] - self.frame_sample_length) + / self.frame_shift_sample_length + + 1 + ) + minus_frame = (self.lfr_m - 1) // 2 if self.reserve_waveforms is None else 0 + feats, feats_lengths, lfr_splice_frame_idxs = self.lfr_cmvn( + feats, feats_lengths, is_final + ) + if self.lfr_m == 1: + self.reserve_waveforms = None + else: + reserve_frame_idx = lfr_splice_frame_idxs[0] - minus_frame + # print('reserve_frame_idx: ' + str(reserve_frame_idx)) + # print('frame_frame: ' + str(frame_from_waveforms)) + self.reserve_waveforms = self.waveforms[ + :, + reserve_frame_idx + * self.frame_shift_sample_length : frame_from_waveforms + * self.frame_shift_sample_length, + ] + sample_length = ( + frame_from_waveforms - 1 + ) * self.frame_shift_sample_length + self.frame_sample_length + self.waveforms = self.waveforms[:, :sample_length] + else: + # update self.reserve_waveforms and self.lfr_splice_cache + self.reserve_waveforms = self.waveforms[ + :, : -(self.frame_sample_length - self.frame_shift_sample_length) + ] + for i in range(batch_size): + self.lfr_splice_cache[i] = np.concatenate( + (self.lfr_splice_cache[i], feats[i]), axis=0 + ) + return np.empty(0, dtype=np.float32), feats_lengths + else: + if is_final: + self.waveforms = ( + waveforms if self.reserve_waveforms is None else self.reserve_waveforms + ) + feats = np.stack(self.lfr_splice_cache) + feats_lengths = np.zeros(batch_size, dtype=np.int32) + feats.shape[1] + feats, feats_lengths, _ = self.lfr_cmvn(feats, feats_lengths, is_final) + if is_final: + self.cache_reset() + return feats, feats_lengths + + def get_waveforms(self): + return self.waveforms + + def cache_reset(self): + self.fbank_fn = knf.OnlineFbank(self.opts) + self.reserve_waveforms = None + self.input_cache = None + self.lfr_splice_cache = [] + + +def load_bytes(input): + middle_data = np.frombuffer(input, dtype=np.int16) + 
middle_data = np.asarray(middle_data) + if middle_data.dtype.kind not in "iu": + raise TypeError("'middle_data' must be an array of integers") + dtype = np.dtype("float32") + if dtype.kind != "f": + raise TypeError("'dtype' must be a floating point type") + + i = np.iinfo(middle_data.dtype) + abs_max = 2 ** (i.bits - 1) + offset = i.min + abs_max + array = np.frombuffer((middle_data.astype(dtype) - offset) / abs_max, dtype=np.float32) + return array + + +class SinusoidalPositionEncoderOnline: + """Streaming Positional encoding.""" + + def encode(self, positions: np.ndarray = None, depth: int = None, dtype: np.dtype = np.float32): + batch_size = positions.shape[0] + positions = positions.astype(dtype) + log_timescale_increment = np.log(np.array([10000], dtype=dtype)) / (depth / 2 - 1) + inv_timescales = np.exp(np.arange(depth / 2).astype(dtype) * (-log_timescale_increment)) + inv_timescales = np.reshape(inv_timescales, [batch_size, -1]) + scaled_time = np.reshape(positions, [1, -1, 1]) * np.reshape(inv_timescales, [1, 1, -1]) + encoding = np.concatenate((np.sin(scaled_time), np.cos(scaled_time)), axis=2) + return encoding.astype(dtype) + + def forward(self, x, start_idx=0): + batch_size, timesteps, input_dim = x.shape + positions = np.arange(1, timesteps + 1 + start_idx)[None, :] + position_encoding = self.encode(positions, input_dim, x.dtype) + + return x + position_encoding[:, start_idx : start_idx + timesteps] + + +def test(): + path = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/example/asr_example.wav" + import librosa + + cmvn_file = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/am.mvn" + config_file = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/config.yaml" + from funasr.runtime.python.onnxruntime.rapid_paraformer.utils.utils import read_yaml + + config = read_yaml(config_file) + waveform, _ = librosa.load(path, sr=None) + frontend = WavFrontend( + cmvn_file=cmvn_file, + **config["frontend_conf"], + ) + speech, _ = frontend.fbank_online(waveform) # 1d, (sample,), numpy + feat, feat_len = frontend.lfr_cmvn( + speech + ) # 2d, (frame, 450), np.float32 -> torch, torch.from_numpy(), dtype, (1, frame, 450) + + frontend.reset_status() # clear cache + return feat, feat_len + + +if __name__ == "__main__": + test() diff --git a/utils/infer_utils.py b/utils/infer_utils.py new file mode 100644 index 0000000..c39d433 --- /dev/null +++ b/utils/infer_utils.py @@ -0,0 +1,395 @@ +# -*- encoding: utf-8 -*- + +import functools +import logging +from pathlib import Path +from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union + +import re +import numpy as np +import yaml + +try: + from onnxruntime import ( + GraphOptimizationLevel, + InferenceSession, + SessionOptions, + get_available_providers, + get_device, + ) +except: + print("please pip3 install onnxruntime") +import jieba +import warnings + +root_dir = Path(__file__).resolve().parent + +logger_initialized = {} + + +def pad_list(xs, pad_value, max_len=None): + n_batch = len(xs) + if max_len is None: + max_len = max(x.size(0) for x in xs) + # pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value) + # numpy format + pad = (np.zeros((n_batch, max_len)) + pad_value).astype(np.int32) + for i in range(n_batch): + pad[i, : xs[i].shape[0]] = xs[i] + + return pad + + +""" +def make_pad_mask(lengths, xs=None, length_dim=-1, maxlen=None): + if length_dim == 0: + raise 
ValueError("length_dim cannot be 0: {}".format(length_dim)) + + if not isinstance(lengths, list): + lengths = lengths.tolist() + bs = int(len(lengths)) + if maxlen is None: + if xs is None: + maxlen = int(max(lengths)) + else: + maxlen = xs.size(length_dim) + else: + assert xs is None + assert maxlen >= int(max(lengths)) + + seq_range = torch.arange(0, maxlen, dtype=torch.int64) + seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen) + seq_length_expand = seq_range_expand.new(lengths).unsqueeze(-1) + mask = seq_range_expand >= seq_length_expand + + if xs is not None: + assert xs.size(0) == bs, (xs.size(0), bs) + + if length_dim < 0: + length_dim = xs.dim() + length_dim + # ind = (:, None, ..., None, :, , None, ..., None) + ind = tuple( + slice(None) if i in (0, length_dim) else None for i in range(xs.dim()) + ) + mask = mask[ind].expand_as(xs).to(xs.device) + return mask +""" + + +class TokenIDConverter: + def __init__( + self, + token_list: Union[List, str], + ): + + self.token_list = token_list + self.unk_symbol = token_list[-1] + self.token2id = {v: i for i, v in enumerate(self.token_list)} + self.unk_id = self.token2id[self.unk_symbol] + + def get_num_vocabulary_size(self) -> int: + return len(self.token_list) + + def ids2tokens(self, integers: Union[np.ndarray, Iterable[int]]) -> List[str]: + if isinstance(integers, np.ndarray) and integers.ndim != 1: + raise TokenIDConverterError(f"Must be 1 dim ndarray, but got {integers.ndim}") + return [self.token_list[i] for i in integers] + + def tokens2ids(self, tokens: Iterable[str]) -> List[int]: + + return [self.token2id.get(i, self.unk_id) for i in tokens] + + +class CharTokenizer: + def __init__( + self, + symbol_value: Union[Path, str, Iterable[str]] = None, + space_symbol: str = "", + remove_non_linguistic_symbols: bool = False, + ): + + self.space_symbol = space_symbol + self.non_linguistic_symbols = self.load_symbols(symbol_value) + self.remove_non_linguistic_symbols = remove_non_linguistic_symbols + + @staticmethod + def load_symbols(value: Union[Path, str, Iterable[str]] = None) -> Set: + if value is None: + return set() + + if isinstance(value, Iterable[str]): + return set(value) + + file_path = Path(value) + if not file_path.exists(): + logging.warning("%s doesn't exist.", file_path) + return set() + + with file_path.open("r", encoding="utf-8") as f: + return set(line.rstrip() for line in f) + + def text2tokens(self, line: Union[str, list]) -> List[str]: + tokens = [] + while len(line) != 0: + for w in self.non_linguistic_symbols: + if line.startswith(w): + if not self.remove_non_linguistic_symbols: + tokens.append(line[: len(w)]) + line = line[len(w) :] + break + else: + t = line[0] + if t == " ": + t = "" + tokens.append(t) + line = line[1:] + return tokens + + def tokens2text(self, tokens: Iterable[str]) -> str: + tokens = [t if t != self.space_symbol else " " for t in tokens] + return "".join(tokens) + + def __repr__(self): + return ( + f"{self.__class__.__name__}(" + f'space_symbol="{self.space_symbol}"' + f'non_linguistic_symbols="{self.non_linguistic_symbols}"' + f")" + ) + + +class Hypothesis(NamedTuple): + """Hypothesis data type.""" + + yseq: np.ndarray + score: Union[float, np.ndarray] = 0 + scores: Dict[str, Union[float, np.ndarray]] = dict() + states: Dict[str, Any] = dict() + + def asdict(self) -> dict: + """Convert data to JSON-friendly dict.""" + return self._replace( + yseq=self.yseq.tolist(), + score=float(self.score), + scores={k: float(v) for k, v in self.scores.items()}, + )._asdict() + + +class 
TokenIDConverterError(Exception): + pass + + +class ONNXRuntimeError(Exception): + pass + + +class OrtInferSession: + def __init__(self, model_file, device_id=-1, intra_op_num_threads=4): + device_id = str(device_id) + sess_opt = SessionOptions() + sess_opt.intra_op_num_threads = intra_op_num_threads + sess_opt.log_severity_level = 4 + sess_opt.enable_cpu_mem_arena = False + sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL + + cuda_ep = "CUDAExecutionProvider" + cuda_provider_options = { + "device_id": device_id, + "arena_extend_strategy": "kNextPowerOfTwo", + "cudnn_conv_algo_search": "EXHAUSTIVE", + "do_copy_in_default_stream": "true", + } + cpu_ep = "CPUExecutionProvider" + cpu_provider_options = { + "arena_extend_strategy": "kSameAsRequested", + } + + EP_list = [] + if device_id != "-1" and get_device() == "GPU" and cuda_ep in get_available_providers(): + EP_list = [(cuda_ep, cuda_provider_options)] + EP_list.append((cpu_ep, cpu_provider_options)) + + self._verify_model(model_file) + self.session = InferenceSession(model_file, sess_options=sess_opt, providers=EP_list) + + if device_id != "-1" and cuda_ep not in self.session.get_providers(): + warnings.warn( + f"{cuda_ep} is not avaiable for current env, the inference part is automatically shifted to be executed under {cpu_ep}.\n" + "Please ensure the installed onnxruntime-gpu version matches your cuda and cudnn version, " + "you can check their relations from the offical web site: " + "https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html", + RuntimeWarning, + ) + + def __call__(self, input_content: List[Union[np.ndarray, np.ndarray]]) -> np.ndarray: + input_dict = dict(zip(self.get_input_names(), input_content)) + try: + return self.session.run(self.get_output_names(), input_dict) + except Exception as e: + raise ONNXRuntimeError("ONNXRuntime inferece failed.") from e + + def get_input_names( + self, + ): + return [v.name for v in self.session.get_inputs()] + + def get_output_names( + self, + ): + return [v.name for v in self.session.get_outputs()] + + def get_character_list(self, key: str = "character"): + return self.meta_dict[key].splitlines() + + def have_key(self, key: str = "character") -> bool: + self.meta_dict = self.session.get_modelmeta().custom_metadata_map + if key in self.meta_dict.keys(): + return True + return False + + @staticmethod + def _verify_model(model_path): + model_path = Path(model_path) + if not model_path.exists(): + raise FileNotFoundError(f"{model_path} does not exists.") + if not model_path.is_file(): + raise FileExistsError(f"{model_path} is not a file.") + + +def split_to_mini_sentence(words: list, word_limit: int = 20): + assert word_limit > 1 + if len(words) <= word_limit: + return [words] + sentences = [] + length = len(words) + sentence_len = length // word_limit + for i in range(sentence_len): + sentences.append(words[i * word_limit : (i + 1) * word_limit]) + if length % word_limit > 0: + sentences.append(words[sentence_len * word_limit :]) + return sentences + + +def code_mix_split_words(text: str): + words = [] + segs = text.split() + for seg in segs: + # There is no space in seg. + current_word = "" + for c in seg: + if len(c.encode()) == 1: + # This is an ASCII char. + current_word += c + else: + # This is a Chinese char. 
+ if len(current_word) > 0: + words.append(current_word) + current_word = "" + words.append(c) + if len(current_word) > 0: + words.append(current_word) + return words + + +def isEnglish(text: str): + if re.search("^[a-zA-Z']+$", text): + return True + else: + return False + + +def join_chinese_and_english(input_list): + line = "" + for token in input_list: + if isEnglish(token): + line = line + " " + token + else: + line = line + token + + line = line.strip() + return line + + +def code_mix_split_words_jieba(seg_dict_file: str): + jieba.load_userdict(seg_dict_file) + + def _fn(text: str): + input_list = text.split() + token_list_all = [] + langauge_list = [] + token_list_tmp = [] + language_flag = None + for token in input_list: + if isEnglish(token) and language_flag == "Chinese": + token_list_all.append(token_list_tmp) + langauge_list.append("Chinese") + token_list_tmp = [] + elif not isEnglish(token) and language_flag == "English": + token_list_all.append(token_list_tmp) + langauge_list.append("English") + token_list_tmp = [] + + token_list_tmp.append(token) + + if isEnglish(token): + language_flag = "English" + else: + language_flag = "Chinese" + + if token_list_tmp: + token_list_all.append(token_list_tmp) + langauge_list.append(language_flag) + + result_list = [] + for token_list_tmp, language_flag in zip(token_list_all, langauge_list): + if language_flag == "English": + result_list.extend(token_list_tmp) + else: + seg_list = jieba.cut(join_chinese_and_english(token_list_tmp), HMM=False) + result_list.extend(seg_list) + + return result_list + + return _fn + + +def read_yaml(yaml_path: Union[str, Path]) -> Dict: + if not Path(yaml_path).exists(): + raise FileExistsError(f"The {yaml_path} does not exist.") + + with open(str(yaml_path), "rb") as f: + data = yaml.load(f, Loader=yaml.Loader) + return data + + +@functools.lru_cache() +def get_logger(name="funasr_onnx"): + """Initialize and get a logger by name. + If the logger has not been initialized, this method will initialize the + logger by adding one or two handlers, otherwise the initialized logger will + be directly returned. During initialization, a StreamHandler will always be + added. + Args: + name (str): Logger name. + Returns: + logging.Logger: The expected logger. + """ + logger = logging.getLogger(name) + if name in logger_initialized: + return logger + + for logger_name in logger_initialized: + if name.startswith(logger_name): + return logger + + formatter = logging.Formatter( + "[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%Y/%m/%d %H:%M:%S" + ) + + sh = logging.StreamHandler() + sh.setFormatter(formatter) + logger.addHandler(sh) + logger_initialized[name] = True + logger.propagate = False + logging.basicConfig(level=logging.ERROR) + return logger diff --git a/utils/model_bin.py b/utils/model_bin.py new file mode 100644 index 0000000..f67be65 --- /dev/null +++ b/utils/model_bin.py @@ -0,0 +1,146 @@ +# -*- encoding: utf-8 -*- +# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. 
+# MIT License (https://opensource.org/licenses/MIT) + +import os.path +from pathlib import Path +from typing import List, Union, Tuple +import torch +import librosa +import numpy as np + +from utils.infer_utils import ( + CharTokenizer, + Hypothesis, + ONNXRuntimeError, + OrtInferSession, + TokenIDConverter, + get_logger, + read_yaml, +) +# from .utils.postprocess_utils import sentence_postprocess, sentence_postprocess_sentencepiece +from utils.frontend import WavFrontend +# from .utils.timestamp_utils import time_stamp_lfr6_onnx +from utils.infer_utils import pad_list + +logging = get_logger() + + +class SenseVoiceSmallONNX: + """ + Author: Speech Lab of DAMO Academy, Alibaba Group + Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition + https://arxiv.org/abs/2206.08317 + """ + + def __init__( + self, + model_dir: Union[str, Path] = None, + batch_size: int = 1, + device_id: Union[str, int] = "-1", + plot_timestamp_to: str = "", + quantize: bool = False, + intra_op_num_threads: int = 4, + cache_dir: str = None, + **kwargs, + ): + if quantize: + model_file = os.path.join(model_dir, "model_quant.onnx") + else: + model_file = os.path.join(model_dir, "model.onnx") + + config_file = os.path.join(model_dir, "config.yaml") + cmvn_file = os.path.join(model_dir, "am.mvn") + config = read_yaml(config_file) + # token_list = os.path.join(model_dir, "tokens.json") + # with open(token_list, "r", encoding="utf-8") as f: + # token_list = json.load(f) + + # self.converter = TokenIDConverter(token_list) + self.tokenizer = CharTokenizer() + config["frontend_conf"]['cmvn_file'] = cmvn_file + self.frontend = WavFrontend(**config["frontend_conf"]) + self.ort_infer = OrtInferSession( + model_file, device_id, intra_op_num_threads=intra_op_num_threads + ) + self.batch_size = batch_size + self.blank_id = 0 + + def __call__(self, + wav_content: Union[str, np.ndarray, List[str]], + language: List, + textnorm: List, + tokenizer=None, + **kwargs) -> List: + waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq) + waveform_nums = len(waveform_list) + asr_res = [] + for beg_idx in range(0, waveform_nums, self.batch_size): + end_idx = min(waveform_nums, beg_idx + self.batch_size) + feats, feats_len = self.extract_feat(waveform_list[beg_idx:end_idx]) + ctc_logits, encoder_out_lens = self.infer(feats, + feats_len, + np.array(language, dtype=np.int32), + np.array(textnorm, dtype=np.int32) + ) + # back to torch.Tensor + ctc_logits = torch.from_numpy(ctc_logits).float() + # support batch_size=1 only currently + x = ctc_logits[0, : encoder_out_lens[0].item(), :] + yseq = x.argmax(dim=-1) + yseq = torch.unique_consecutive(yseq, dim=-1) + + mask = yseq != self.blank_id + token_int = yseq[mask].tolist() + + if tokenizer is not None: + asr_res.append(tokenizer.tokens2text(token_int)) + else: + asr_res.append(token_int) + return asr_res + + def load_data(self, wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List: + def load_wav(path: str) -> np.ndarray: + waveform, _ = librosa.load(path, sr=fs) + return waveform + + if isinstance(wav_content, np.ndarray): + return [wav_content] + + if isinstance(wav_content, str): + return [load_wav(wav_content)] + + if isinstance(wav_content, list): + return [load_wav(path) for path in wav_content] + + raise TypeError(f"The type of {wav_content} is not in [str, np.ndarray, list]") + + def extract_feat(self, waveform_list: List[np.ndarray]) -> Tuple[np.ndarray, np.ndarray]: + feats, feats_len = [], 
[] + for waveform in waveform_list: + speech, _ = self.frontend.fbank(waveform) + feat, feat_len = self.frontend.lfr_cmvn(speech) + feats.append(feat) + feats_len.append(feat_len) + + feats = self.pad_feats(feats, np.max(feats_len)) + feats_len = np.array(feats_len).astype(np.int32) + return feats, feats_len + + @staticmethod + def pad_feats(feats: List[np.ndarray], max_feat_len: int) -> np.ndarray: + def pad_feat(feat: np.ndarray, cur_len: int) -> np.ndarray: + pad_width = ((0, max_feat_len - cur_len), (0, 0)) + return np.pad(feat, pad_width, "constant", constant_values=0) + + feat_res = [pad_feat(feat, feat.shape[0]) for feat in feats] + feats = np.array(feat_res).astype(np.float32) + return feats + + def infer(self, + feats: np.ndarray, + feats_len: np.ndarray, + language: np.ndarray, + textnorm: np.ndarray,) -> Tuple[np.ndarray, np.ndarray]: + outputs = self.ort_infer([feats, feats_len, language, textnorm]) + return outputs
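
The diff above adds an ONNX export path (export.py, export_meta.py, utils/export_utils.py) and a small standalone runtime (utils/frontend.py, utils/infer_utils.py, utils/model_bin.py). As a quick orientation, the following sketch (not part of the diff) drives the exported graph with onnxruntime directly instead of going through the SenseVoiceSmallONNX wrapper. It assumes export.py has already written model.onnx next to config.yaml, am.mvn and the BPE model; the model and wav paths are placeholders, and the language/textnorm ids (0 and 15) are simply the values used in export.py's demo.

#!/usr/bin/env python3
import os
import numpy as np
import onnxruntime as ort
import librosa

from utils.frontend import WavFrontend
from utils.infer_utils import read_yaml
from funasr.tokenizer.sentencepiece_tokenizer import SentencepiecesTokenizer

model_dir = "/path/to/exported/SenseVoiceSmall"  # placeholder
wav_path = "/path/to/example.wav"                # placeholder

# Build the same frontend SenseVoiceSmallONNX uses: fbank + LFR + CMVN.
config = read_yaml(os.path.join(model_dir, "config.yaml"))
config["frontend_conf"]["cmvn_file"] = os.path.join(model_dir, "am.mvn")
frontend = WavFrontend(**config["frontend_conf"])

waveform, _ = librosa.load(wav_path, sr=frontend.opts.frame_opts.samp_freq)
feat, _ = frontend.fbank(waveform)
feat, feat_len = frontend.lfr_cmvn(feat)  # (frames, 560) LFR features, matching the export's dummy input

# The exported graph's contract comes from export_meta.py:
#   inputs : speech (B, T, 560), speech_lengths (B,), language (B,), textnorm (B,)
#   outputs: ctc_logits (B, T', vocab), encoder_out_lens (B,)
sess = ort.InferenceSession(os.path.join(model_dir, "model.onnx"),
                            providers=["CPUExecutionProvider"])
ctc_logits, encoder_out_lens = sess.run(
    ["ctc_logits", "encoder_out_lens"],
    {
        "speech": feat[None, ...].astype(np.float32),
        "speech_lengths": np.array([feat_len], dtype=np.int32),
        "language": np.array([0], dtype=np.int32),
        "textnorm": np.array([15], dtype=np.int32),
    },
)

# Greedy CTC decode, as in SenseVoiceSmallONNX.__call__: argmax per frame,
# collapse repeats, drop the blank id (0), then detokenize.
ids = ctc_logits[0, : int(encoder_out_lens[0])].argmax(axis=-1)
ids = ids[np.insert(np.diff(ids) != 0, 0, True)]  # numpy equivalent of unique_consecutive
ids = [int(i) for i in ids if i != 0]

tokenizer = SentencepiecesTokenizer(
    bpemodel=os.path.join(model_dir, "chn_jpn_yue_eng_ko_spectok.bpe.model")
)
print(tokenizer.tokens2text(ids))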
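
One behavioural detail worth noting: export_forward now returns self.ctc.ctc_lo(encoder_out), the raw CTC output projection, instead of self.ctc.log_softmax(encoder_out). In FunASR's CTC module log_softmax applies a log-softmax on top of that same ctc_lo projection, so for the greedy decoding done in SenseVoiceSmallONNX.__call__ the change is equivalent: log-softmax is a monotonic per-frame transform and does not move the argmax. Only consumers that need normalised log-probabilities (for scoring or rescoring) would have to apply log_softmax to the exported ctc_logits themselves. A minimal sketch of that assumption:

import torch

logits = torch.randn(1, 30, 512)  # (batch, frames, vocab); sizes are illustrative
assert torch.equal(
    logits.argmax(dim=-1),
    torch.log_softmax(logits, dim=-1).argmax(dim=-1),
)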