diff --git a/funasr/export/models/CT_Transformer.py b/funasr/export/models/CT_Transformer.py
index ea6ff4f53..932e3afe6 100644
--- a/funasr/export/models/CT_Transformer.py
+++ b/funasr/export/models/CT_Transformer.py
@@ -10,7 +10,7 @@ from funasr.export.models.encoder.sanm_encoder import SANMVadEncoder as SANMVadE
 
 class CT_Transformer(nn.Module):
     """
-    Author: Speech Lab, Alibaba Group, China
+    Author: Speech Lab of DAMO Academy, Alibaba Group
     CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
     https://arxiv.org/pdf/2003.01309.pdf
     """
@@ -81,7 +81,7 @@ class CT_Transformer(nn.Module):
 
 class CT_Transformer_VadRealtime(nn.Module):
     """
-    Author: Speech Lab, Alibaba Group, China
+    Author: Speech Lab of DAMO Academy, Alibaba Group
     CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
     https://arxiv.org/pdf/2003.01309.pdf
     """
diff --git a/funasr/export/models/e2e_asr_paraformer.py b/funasr/export/models/e2e_asr_paraformer.py
index 0db61e0c5..52ad320ac 100644
--- a/funasr/export/models/e2e_asr_paraformer.py
+++ b/funasr/export/models/e2e_asr_paraformer.py
@@ -19,7 +19,7 @@ from funasr.export.models.decoder.transformer_decoder import ParaformerDecoderSA
 
 class Paraformer(nn.Module):
     """
-    Author: Speech Lab, Alibaba Group, China
+    Author: Speech Lab of DAMO Academy, Alibaba Group
     Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
     https://arxiv.org/abs/2206.08317
     """
@@ -112,7 +112,7 @@ class Paraformer(nn.Module):
 
 class BiCifParaformer(nn.Module):
     """
-    Author: Speech Lab, Alibaba Group, China
+    Author: Speech Lab of DAMO Academy, Alibaba Group
     Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
     https://arxiv.org/abs/2206.08317
     """
diff --git a/funasr/models/decoder/contextual_decoder.py b/funasr/models/decoder/contextual_decoder.py
index 3b462e712..78105ab31 100644
--- a/funasr/models/decoder/contextual_decoder.py
+++ b/funasr/models/decoder/contextual_decoder.py
@@ -102,7 +102,7 @@ class ContextualBiasDecoder(nn.Module):
 
 class ContextualParaformerDecoder(ParaformerSANMDecoder):
     """
-    author: Speech Lab, Alibaba Group, China
+    Author: Speech Lab of DAMO Academy, Alibaba Group
     Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
     https://arxiv.org/abs/2006.01713
     """
diff --git a/funasr/models/decoder/sanm_decoder.py b/funasr/models/decoder/sanm_decoder.py
index 463918a0f..18cd343e7 100644
--- a/funasr/models/decoder/sanm_decoder.py
+++ b/funasr/models/decoder/sanm_decoder.py
@@ -151,7 +151,7 @@ class DecoderLayerSANM(nn.Module):
 
 class FsmnDecoderSCAMAOpt(BaseTransformerDecoder):
     """
-    author: Speech Lab, Alibaba Group, China
+    Author: Speech Lab of DAMO Academy, Alibaba Group
     SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
     https://arxiv.org/abs/2006.01713
 
@@ -812,7 +812,7 @@ class FsmnDecoderSCAMAOpt(BaseTransformerDecoder):
 
 class ParaformerSANMDecoder(BaseTransformerDecoder):
     """
-    author: Speech Lab, Alibaba Group, China
+    Author: Speech Lab of DAMO Academy, Alibaba Group
     Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
     https://arxiv.org/abs/2006.01713
     """
diff --git a/funasr/models/decoder/transformer_decoder.py b/funasr/models/decoder/transformer_decoder.py
index 5f1bb2436..aed7f206d 100644
--- a/funasr/models/decoder/transformer_decoder.py
+++ b/funasr/models/decoder/transformer_decoder.py
@@ -405,7 +405,7 @@ class TransformerDecoder(BaseTransformerDecoder):
 
 class ParaformerDecoderSAN(BaseTransformerDecoder):
     """
-    author: Speech Lab, Alibaba Group, China
+    Author: Speech Lab of DAMO Academy, Alibaba Group
     Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
     https://arxiv.org/abs/2006.01713
     """
diff --git a/funasr/models/e2e_asr_paraformer.py b/funasr/models/e2e_asr_paraformer.py
index f1bb2bfc1..5c8560d00 100644
--- a/funasr/models/e2e_asr_paraformer.py
+++ b/funasr/models/e2e_asr_paraformer.py
@@ -44,7 +44,7 @@ else:
 
 class Paraformer(AbsESPnetModel):
     """
-    Author: Speech Lab, Alibaba Group, China
+    Author: Speech Lab of DAMO Academy, Alibaba Group
     Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
     https://arxiv.org/abs/2206.08317
     """
@@ -612,7 +612,7 @@ class Paraformer(AbsESPnetModel):
 
 class ParaformerBert(Paraformer):
     """
-    Author: Speech Lab, Alibaba Group, China
+    Author: Speech Lab of DAMO Academy, Alibaba Group
     Paraformer2: advanced paraformer with LFMMI and bert for non-autoregressive end-to-end speech recognition
 
     """
diff --git a/funasr/models/e2e_tp.py b/funasr/models/e2e_tp.py
index 887439c5e..d1367ab98 100644
--- a/funasr/models/e2e_tp.py
+++ b/funasr/models/e2e_tp.py
@@ -32,7 +32,7 @@ else:
 
 class TimestampPredictor(AbsESPnetModel):
     """
-    Author: Speech Lab, Alibaba Group, China
+    Author: Speech Lab of DAMO Academy, Alibaba Group
     """
 
     def __init__(
diff --git a/funasr/models/e2e_uni_asr.py b/funasr/models/e2e_uni_asr.py
index ac4db329b..ca76244b9 100644
--- a/funasr/models/e2e_uni_asr.py
+++ b/funasr/models/e2e_uni_asr.py
@@ -40,7 +40,7 @@ else:
 
 class UniASR(AbsESPnetModel):
     """
-    Author: Speech Lab, Alibaba Group, China
+    Author: Speech Lab of DAMO Academy, Alibaba Group
     """
 
     def __init__(
diff --git a/funasr/models/e2e_vad.py b/funasr/models/e2e_vad.py
index 440a049bf..50ec47515 100644
--- a/funasr/models/e2e_vad.py
+++ b/funasr/models/e2e_vad.py
@@ -35,6 +35,11 @@ class VadDetectMode(Enum):
 
 
 class VADXOptions:
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+    https://arxiv.org/abs/1803.05030
+    """
     def __init__(
             self,
             sample_rate: int = 16000,
@@ -99,6 +104,11 @@ class VADXOptions:
 
 
 class E2EVadSpeechBufWithDoa(object):
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+    https://arxiv.org/abs/1803.05030
+    """
     def __init__(self):
         self.start_ms = 0
         self.end_ms = 0
@@ -117,6 +127,11 @@ class E2EVadSpeechBufWithDoa(object):
 
 
 class E2EVadFrameProb(object):
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+    https://arxiv.org/abs/1803.05030
+    """
     def __init__(self):
         self.noise_prob = 0.0
         self.speech_prob = 0.0
@@ -126,6 +141,11 @@ class E2EVadFrameProb(object):
 
 
 class WindowDetector(object):
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+    https://arxiv.org/abs/1803.05030
+    """
     def __init__(self, window_size_ms: int, sil_to_speech_time: int,
                  speech_to_sil_time: int, frame_size_ms: int):
         self.window_size_ms = window_size_ms
@@ -192,6 +212,11 @@ class WindowDetector(object):
 
 
 class E2EVadModel(nn.Module):
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+    https://arxiv.org/abs/1803.05030
+    """
     def __init__(self, encoder: FSMN, vad_post_args: Dict[str, Any], frontend=None):
         super(E2EVadModel, self).__init__()
         self.vad_opts = VADXOptions(**vad_post_args)
diff --git a/funasr/models/encoder/opennmt_encoders/conv_encoder.py b/funasr/models/encoder/opennmt_encoders/conv_encoder.py
index a33e0b718..eec854fe4 100644
--- a/funasr/models/encoder/opennmt_encoders/conv_encoder.py
+++ b/funasr/models/encoder/opennmt_encoders/conv_encoder.py
@@ -67,7 +67,7 @@ class EncoderLayer(nn.Module):
 
 class ConvEncoder(AbsEncoder):
     """
-    author: Speech Lab, Alibaba Group, China
+    Author: Speech Lab of DAMO Academy, Alibaba Group
     Convolution encoder in OpenNMT framework
     """
 
diff --git a/funasr/models/encoder/opennmt_encoders/self_attention_encoder.py b/funasr/models/encoder/opennmt_encoders/self_attention_encoder.py
index cf77bce4b..db30f085e 100644
--- a/funasr/models/encoder/opennmt_encoders/self_attention_encoder.py
+++ b/funasr/models/encoder/opennmt_encoders/self_attention_encoder.py
@@ -117,7 +117,7 @@ class EncoderLayer(nn.Module):
 
 class SelfAttentionEncoder(AbsEncoder):
     """
-    author: Speech Lab, Alibaba Group, China
+    Author: Speech Lab of DAMO Academy, Alibaba Group
     Self attention encoder in OpenNMT framework
     """
 
diff --git a/funasr/models/encoder/sanm_encoder.py b/funasr/models/encoder/sanm_encoder.py
index 2a3a35353..7ac912137 100644
--- a/funasr/models/encoder/sanm_encoder.py
+++ b/funasr/models/encoder/sanm_encoder.py
@@ -117,7 +117,7 @@ class EncoderLayerSANM(nn.Module):
 
 class SANMEncoder(AbsEncoder):
     """
-    author: Speech Lab, Alibaba Group, China
+    Author: Speech Lab of DAMO Academy, Alibaba Group
     San-m: Memory equipped self-attention for end-to-end speech recognition
     https://arxiv.org/abs/2006.01713
 
@@ -549,7 +549,7 @@ class SANMEncoder(AbsEncoder):
 
 class SANMEncoderChunkOpt(AbsEncoder):
     """
-    author: Speech Lab, Alibaba Group, China
+    Author: Speech Lab of DAMO Academy, Alibaba Group
     SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
     https://arxiv.org/abs/2006.01713
 
@@ -962,7 +962,7 @@ class SANMEncoderChunkOpt(AbsEncoder):
 
 class SANMVadEncoder(AbsEncoder):
     """
-    author: Speech Lab, Alibaba Group, China
+    Author: Speech Lab of DAMO Academy, Alibaba Group
 
     """
 
diff --git a/funasr/models/target_delay_transformer.py b/funasr/models/target_delay_transformer.py
index 8cd435747..e893c657f 100644
--- a/funasr/models/target_delay_transformer.py
+++ b/funasr/models/target_delay_transformer.py
@@ -14,7 +14,7 @@ from funasr.train.abs_model import AbsPunctuation
 
 class TargetDelayTransformer(AbsPunctuation):
     """
-    Author: Speech Lab, Alibaba Group, China
+    Author: Speech Lab of DAMO Academy, Alibaba Group
     CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
     https://arxiv.org/pdf/2003.01309.pdf
     """
diff --git a/funasr/models/vad_realtime_transformer.py b/funasr/models/vad_realtime_transformer.py
index 381067252..fe298ce83 100644
--- a/funasr/models/vad_realtime_transformer.py
+++ b/funasr/models/vad_realtime_transformer.py
@@ -12,7 +12,7 @@ from funasr.train.abs_model import AbsPunctuation
 
 class VadRealtimeTransformer(AbsPunctuation):
     """
-    Author: Speech Lab, Alibaba Group, China
+    Author: Speech Lab of DAMO Academy, Alibaba Group
     CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
     https://arxiv.org/pdf/2003.01309.pdf
     """
diff --git a/funasr/modules/streaming_utils/chunk_utilis.py b/funasr/modules/streaming_utils/chunk_utilis.py
index ea37c68cc..ed8b31eb2 100644
--- a/funasr/modules/streaming_utils/chunk_utilis.py
+++ b/funasr/modules/streaming_utils/chunk_utilis.py
@@ -11,7 +11,7 @@ from funasr.modules.streaming_utils.utils import sequence_mask
 
 class overlap_chunk():
     """
-    author: Speech Lab, Alibaba Group, China
+    Author: Speech Lab of DAMO Academy, Alibaba Group
     San-m: Memory equipped self-attention for end-to-end speech recognition
     https://arxiv.org/abs/2006.01713
 
diff --git a/funasr/runtime/python/onnxruntime/demo_vad_offline.py b/funasr/runtime/python/onnxruntime/demo_vad_offline.py
index 69ca94543..ea76ec34c 100644
--- a/funasr/runtime/python/onnxruntime/demo_vad_offline.py
+++ b/funasr/runtime/python/onnxruntime/demo_vad_offline.py
@@ -1,5 +1,5 @@
 import soundfile
-from funasr_onnx.vad_bin import Fsmn_vad
+from funasr_onnx import Fsmn_vad
 
 
 model_dir = "/mnt/ailsa.zly/tfbase/espnet_work/FunASR_dev_zly/export/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"
diff --git a/funasr/runtime/python/onnxruntime/demo_vad_online.py b/funasr/runtime/python/onnxruntime/demo_vad_online.py
index 15e62da2b..1ab4d9d73 100644
--- a/funasr/runtime/python/onnxruntime/demo_vad_online.py
+++ b/funasr/runtime/python/onnxruntime/demo_vad_online.py
@@ -1,10 +1,10 @@
 import soundfile
-from funasr_onnx.vad_online_bin import Fsmn_vad
+from funasr_onnx import Fsmn_vad_online
 
 model_dir = "/mnt/ailsa.zly/tfbase/espnet_work/FunASR_dev_zly/export/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"
 wav_path = "/mnt/ailsa.zly/tfbase/espnet_work/FunASR_dev_zly/egs_modelscope/vad/speech_fsmn_vad_zh-cn-16k-common/vad_example_16k.wav"
 
-model = Fsmn_vad(model_dir)
+model = Fsmn_vad_online(model_dir)
 
 
 ##online vad
diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx/__init__.py b/funasr/runtime/python/onnxruntime/funasr_onnx/__init__.py
index 86f0e8e52..7d8d6620f 100644
--- a/funasr/runtime/python/onnxruntime/funasr_onnx/__init__.py
+++ b/funasr/runtime/python/onnxruntime/funasr_onnx/__init__.py
@@ -1,5 +1,6 @@
 # -*- encoding: utf-8 -*-
 from .paraformer_bin import Paraformer
 from .vad_bin import Fsmn_vad
+from .vad_bin import Fsmn_vad_online
 from .punc_bin import CT_Transformer
 from .punc_bin import CT_Transformer_VadRealtime
diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py b/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
index 2f1b3b76e..bbbb9133e 100644
--- a/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
+++ b/funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
@@ -14,7 +14,7 @@ logging = get_logger()
 
 class CT_Transformer():
     """
-    Author: Speech Lab, Alibaba Group, China
+    Author: Speech Lab of DAMO Academy, Alibaba Group
     CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
     https://arxiv.org/pdf/2003.01309.pdf
     """
@@ -125,7 +125,7 @@ class CT_Transformer():
 
 class CT_Transformer_VadRealtime(CT_Transformer):
     """
-    Author: Speech Lab, Alibaba Group, China
+    Author: Speech Lab of DAMO Academy, Alibaba Group
     CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
     https://arxiv.org/pdf/2003.01309.pdf
     """
diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx/vad_bin.py b/funasr/runtime/python/onnxruntime/funasr_onnx/vad_bin.py
index 5ad426687..ab8f0412f 100644
--- a/funasr/runtime/python/onnxruntime/funasr_onnx/vad_bin.py
+++ b/funasr/runtime/python/onnxruntime/funasr_onnx/vad_bin.py
@@ -11,13 +11,18 @@ import numpy as np
 from .utils.utils import (ONNXRuntimeError,
                           OrtInferSession, get_logger,
                           read_yaml)
-from .utils.frontend import WavFrontend
+from .utils.frontend import WavFrontend, WavFrontendOnline
 from .utils.e2e_vad import E2EVadModel
 
 logging = get_logger()
 
 
 class Fsmn_vad():
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+    https://arxiv.org/abs/1803.05030
+    """
     def __init__(self, model_dir: Union[str, Path] = None,
                  batch_size: int = 1,
                  device_id: Union[str, int] = "-1",
@@ -151,4 +156,125 @@ class Fsmn_vad():
 
         outputs = self.ort_infer(feats)
         scores, out_caches = outputs[0], outputs[1:]
         return scores, out_caches
+
+
+class Fsmn_vad_online():
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+    https://arxiv.org/abs/1803.05030
+    """
+    def __init__(self, model_dir: Union[str, Path] = None,
+                 batch_size: int = 1,
+                 device_id: Union[str, int] = "-1",
+                 quantize: bool = False,
+                 intra_op_num_threads: int = 4,
+                 max_end_sil: int = None,
+                 ):
+
+        if not Path(model_dir).exists():
+            raise FileNotFoundError(f'{model_dir} does not exist.')
+
+        model_file = os.path.join(model_dir, 'model.onnx')
+        if quantize:
+            model_file = os.path.join(model_dir, 'model_quant.onnx')
+        config_file = os.path.join(model_dir, 'vad.yaml')
+        cmvn_file = os.path.join(model_dir, 'vad.mvn')
+        config = read_yaml(config_file)
+
+        self.frontend = WavFrontendOnline(
+            cmvn_file=cmvn_file,
+            **config['frontend_conf']
+        )
+        self.ort_infer = OrtInferSession(model_file, device_id, intra_op_num_threads=intra_op_num_threads)
+        self.batch_size = batch_size
+        self.vad_scorer = E2EVadModel(config["vad_post_conf"])
+        self.max_end_sil = max_end_sil if max_end_sil is not None else config["vad_post_conf"]["max_end_silence_time"]
+        self.encoder_conf = config["encoder_conf"]
+
+    def prepare_cache(self, in_cache: list = []):
+        if len(in_cache) > 0:
+            return in_cache
+        fsmn_layers = self.encoder_conf["fsmn_layers"]
+        proj_dim = self.encoder_conf["proj_dim"]
+        lorder = self.encoder_conf["lorder"]
+        for i in range(fsmn_layers):
+            cache = np.zeros((1, proj_dim, lorder - 1, 1)).astype(np.float32)
+            in_cache.append(cache)
+        return in_cache
+
+    def __call__(self, audio_in: np.ndarray, **kwargs) -> List:
+        waveforms = np.expand_dims(audio_in, axis=0)
+
+        param_dict = kwargs.get('param_dict', dict())
+        is_final = param_dict.get('is_final', False)
+        feats, feats_len = self.extract_feat(waveforms, is_final)
+        segments = []
+        if feats.size != 0:
+            in_cache = param_dict.get('in_cache', list())
+            in_cache = self.prepare_cache(in_cache)
+            try:
+                inputs = [feats]
+                inputs.extend(in_cache)
+                scores, out_caches = self.infer(inputs)
+                param_dict['in_cache'] = out_caches
+                waveforms = self.frontend.get_waveforms()
+                segments = self.vad_scorer(scores, waveforms, is_final=is_final, max_end_sil=self.max_end_sil,
+                                           online=True)
+
+            except ONNXRuntimeError:
+                # logging.warning(traceback.format_exc())
+                logging.warning("input wav is silence or noise")
+                segments = []
+        return segments
+
+    def load_data(self,
+                  wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List:
+        def load_wav(path: str) -> np.ndarray:
+            waveform, _ = librosa.load(path, sr=fs)
+            return waveform
+
+        if isinstance(wav_content, np.ndarray):
+            return [wav_content]
+
+        if isinstance(wav_content, str):
+            return [load_wav(wav_content)]
+
+        if isinstance(wav_content, list):
+            return [load_wav(path) for path in wav_content]
+
+        raise TypeError(
+            f'The type of {wav_content} is not in [str, np.ndarray, list]')
+
+    def extract_feat(self,
+                     waveforms: np.ndarray, is_final: bool = False
+                     ) -> Tuple[np.ndarray, np.ndarray]:
+        waveforms_lens = np.zeros(waveforms.shape[0]).astype(np.int32)
+        for idx, waveform in enumerate(waveforms):
+            waveforms_lens[idx] = waveform.shape[-1]
+
+        feats, feats_len = self.frontend.extract_fbank(waveforms, waveforms_lens, is_final)
+        # feats.append(feat)
+        # feats_len.append(feat_len)
+
+        # feats = self.pad_feats(feats, np.max(feats_len))
+        # feats_len = np.array(feats_len).astype(np.int32)
+        return feats.astype(np.float32), feats_len.astype(np.int32)
+
+    @staticmethod
+    def pad_feats(feats: List[np.ndarray], max_feat_len: int) -> np.ndarray:
+        def pad_feat(feat: np.ndarray, cur_len: int) -> np.ndarray:
+            pad_width = ((0, max_feat_len - cur_len), (0, 0))
+            return np.pad(feat, pad_width, 'constant', constant_values=0)
+
+        feat_res = [pad_feat(feat, feat.shape[0]) for feat in feats]
+        feats = np.array(feat_res).astype(np.float32)
+        return feats
+
+    def infer(self, feats: List) -> Tuple[np.ndarray, np.ndarray]:
+
+        outputs = self.ort_infer(feats)
+        scores, out_caches = outputs[0], outputs[1:]
+        return scores, out_caches
+
diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx/vad_online_bin.py b/funasr/runtime/python/onnxruntime/funasr_onnx/vad_online_bin.py
deleted file mode 100644
index 83e9420e6..000000000
--- a/funasr/runtime/python/onnxruntime/funasr_onnx/vad_online_bin.py
+++ /dev/null
@@ -1,134 +0,0 @@
-# -*- encoding: utf-8 -*-
-
-import os.path
-from pathlib import Path
-from typing import List, Union, Tuple
-
-import copy
-import librosa
-import numpy as np
-
-from .utils.utils import (ONNXRuntimeError,
-                          OrtInferSession, get_logger,
-                          read_yaml)
-from .utils.frontend import WavFrontendOnline
-from .utils.e2e_vad import E2EVadModel
-
-logging = get_logger()
-
-
-class Fsmn_vad():
-    def __init__(self, model_dir: Union[str, Path] = None,
-                 batch_size: int = 1,
-                 device_id: Union[str, int] = "-1",
-                 quantize: bool = False,
-                 intra_op_num_threads: int = 4,
-                 max_end_sil: int = None,
-                 ):
-
-        if not Path(model_dir).exists():
-            raise FileNotFoundError(f'{model_dir} does not exist.')
-
-        model_file = os.path.join(model_dir, 'model.onnx')
-        if quantize:
-            model_file = os.path.join(model_dir, 'model_quant.onnx')
-        config_file = os.path.join(model_dir, 'vad.yaml')
-        cmvn_file = os.path.join(model_dir, 'vad.mvn')
-        config = read_yaml(config_file)
-
-        self.frontend = WavFrontendOnline(
-            cmvn_file=cmvn_file,
-            **config['frontend_conf']
-        )
-        self.ort_infer = OrtInferSession(model_file, device_id, intra_op_num_threads=intra_op_num_threads)
-        self.batch_size = batch_size
-        self.vad_scorer = E2EVadModel(config["vad_post_conf"])
-        self.max_end_sil = max_end_sil if max_end_sil is not None else config["vad_post_conf"]["max_end_silence_time"]
-        self.encoder_conf = config["encoder_conf"]
-
-    def prepare_cache(self, in_cache: list = []):
-        if len(in_cache) > 0:
-            return in_cache
-        fsmn_layers = self.encoder_conf["fsmn_layers"]
-        proj_dim = self.encoder_conf["proj_dim"]
-        lorder = self.encoder_conf["lorder"]
-        for i in range(fsmn_layers):
-            cache = np.zeros((1, proj_dim, lorder-1, 1)).astype(np.float32)
-            in_cache.append(cache)
-        return in_cache
-
-
-    def __call__(self, audio_in: np.ndarray, **kwargs) -> List:
-        waveforms = np.expand_dims(audio_in, axis=0)
-
-        param_dict = kwargs.get('param_dict', dict())
-        is_final = param_dict.get('is_final', False)
-        feats, feats_len = self.extract_feat(waveforms, is_final)
-        segments = []
-        if feats.size != 0:
-            in_cache = param_dict.get('in_cache', list())
-            in_cache = self.prepare_cache(in_cache)
-            try:
-                inputs = [feats]
-                inputs.extend(in_cache)
-                scores, out_caches = self.infer(inputs)
-                param_dict['in_cache'] = out_caches
-                waveforms = self.frontend.get_waveforms()
-                segments = self.vad_scorer(scores, waveforms, is_final=is_final, max_end_sil=self.max_end_sil, online=True)
-
-
-            except ONNXRuntimeError:
-                logging.warning(traceback.format_exc())
-                logging.warning("input wav is silence or noise")
-                segments = []
-        return segments
-
-    def load_data(self,
-                  wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List:
-        def load_wav(path: str) -> np.ndarray:
-            waveform, _ = librosa.load(path, sr=fs)
-            return waveform
-
-        if isinstance(wav_content, np.ndarray):
-            return [wav_content]
-
-        if isinstance(wav_content, str):
-            return [load_wav(wav_content)]
-
-        if isinstance(wav_content, list):
-            return [load_wav(path) for path in wav_content]
-
-        raise TypeError(
-            f'The type of {wav_content} is not in [str, np.ndarray, list]')
-
-    def extract_feat(self,
-                     waveforms: np.ndarray, is_final: bool = False
-                     ) -> Tuple[np.ndarray, np.ndarray]:
-        waveforms_lens = np.zeros(waveforms.shape[0]).astype(np.int32)
-        for idx, waveform in enumerate(waveforms):
-            waveforms_lens[idx] = waveform.shape[-1]
-
-        feats, feats_len = self.frontend.extract_fbank(waveforms, waveforms_lens, is_final)
-        # feats.append(feat)
-        # feats_len.append(feat_len)
-
-        # feats = self.pad_feats(feats, np.max(feats_len))
-        # feats_len = np.array(feats_len).astype(np.int32)
-        return feats.astype(np.float32), feats_len.astype(np.int32)
-
-    @staticmethod
-    def pad_feats(feats: List[np.ndarray], max_feat_len: int) -> np.ndarray:
-        def pad_feat(feat: np.ndarray, cur_len: int) -> np.ndarray:
-            pad_width = ((0, max_feat_len - cur_len), (0, 0))
-            return np.pad(feat, pad_width, 'constant', constant_values=0)
-
-        feat_res = [pad_feat(feat, feat.shape[0]) for feat in feats]
-        feats = np.array(feat_res).astype(np.float32)
-        return feats
-
-    def infer(self, feats: List) -> Tuple[np.ndarray, np.ndarray]:
-
-        outputs = self.ort_infer(feats)
-        scores, out_caches = outputs[0], outputs[1:]
-        return scores, out_caches
-
diff --git a/funasr/runtime/python/onnxruntime/setup.py b/funasr/runtime/python/onnxruntime/setup.py
index 1a8ed7b31..1e1c6b174 100644
--- a/funasr/runtime/python/onnxruntime/setup.py
+++ b/funasr/runtime/python/onnxruntime/setup.py
@@ -13,7 +13,7 @@ def get_readme():
 
 
 MODULE_NAME = 'funasr_onnx'
-VERSION_NUM = '0.0.3'
+VERSION_NUM = '0.0.4'
 
 setuptools.setup(
     name=MODULE_NAME,
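Note on the new API: with vad_online_bin.py deleted, the streaming entry point is the Fsmn_vad_online class exported from the package root, as exercised by the updated demo_vad_online.py above. A minimal usage sketch follows; the model directory, wav file, and 1600-sample (100 ms at 16 kHz) chunk size are illustrative placeholders, while the param_dict keys ('in_cache', 'is_final') are exactly what Fsmn_vad_online.__call__ reads and writes:

import soundfile
from funasr_onnx import Fsmn_vad_online

# Illustrative paths; substitute your exported model directory and wav file.
model_dir = "./speech_fsmn_vad_zh-cn-16k-common-pytorch"
wav_path = "./vad_example_16k.wav"

model = Fsmn_vad_online(model_dir)
speech, sample_rate = soundfile.read(wav_path)

step = 1600  # assumed chunk size: 100 ms per chunk at 16 kHz
param_dict = {"in_cache": []}  # FSMN caches persist here across calls
for offset in range(0, len(speech), step):
    # Mark the last chunk so the scorer can flush any trailing segment.
    param_dict["is_final"] = offset + step >= len(speech)
    segments = model(speech[offset:offset + step], param_dict=param_dict)
    if segments:
        print(segments)

Because __call__ stores the updated ONNX output caches back into param_dict['in_cache'], reusing the same param_dict across iterations is what carries the FSMN state from one chunk to the next.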