diff --git a/runtime/python/onnxruntime/funasr_onnx/utils/frontend.py b/runtime/python/onnxruntime/funasr_onnx/utils/frontend.py index 7b38f8b76..f4cc8fff8 100644 --- a/runtime/python/onnxruntime/funasr_onnx/utils/frontend.py +++ b/runtime/python/onnxruntime/funasr_onnx/utils/frontend.py @@ -52,12 +52,12 @@ class WavFrontend: def fbank(self, waveform: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: waveform = waveform * (1 << 15) - self.fbank_fn = knf.OnlineFbank(self.opts) - self.fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist()) - frames = self.fbank_fn.num_frames_ready + fbank_fn = knf.OnlineFbank(self.opts) + fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist()) + frames = fbank_fn.num_frames_ready mat = np.empty([frames, self.opts.mel_opts.num_bins]) for i in range(frames): - mat[i, :] = self.fbank_fn.get_frame(i) + mat[i, :] = fbank_fn.get_frame(i) feat = mat.astype(np.float32) feat_len = np.array(mat.shape[0]).astype(np.int32) return feat, feat_len diff --git a/runtime/python/onnxruntime/funasr_onnx/vad_bin.py b/runtime/python/onnxruntime/funasr_onnx/vad_bin.py index 3f21004bf..af4663a7d 100644 --- a/runtime/python/onnxruntime/funasr_onnx/vad_bin.py +++ b/runtime/python/onnxruntime/funasr_onnx/vad_bin.py @@ -69,7 +69,7 @@ class Fsmn_vad: model_file, device_id, intra_op_num_threads=intra_op_num_threads ) self.batch_size = batch_size - self.vad_scorer = E2EVadModel(config["model_conf"]) + self.vad_scorer_config = config["model_conf"] self.max_end_sil = ( max_end_sil if max_end_sil is not None else config["model_conf"]["max_end_silence_time"] ) @@ -90,10 +90,9 @@ class Fsmn_vad: waveform_list = self.load_data(audio_in, self.frontend.opts.frame_opts.samp_freq) waveform_nums = len(waveform_list) is_final = kwargs.get("kwargs", False) - segments = [[]] * self.batch_size for beg_idx in range(0, waveform_nums, self.batch_size): - + vad_scorer = E2EVadModel(self.vad_scorer_config) end_idx = min(waveform_nums, beg_idx + self.batch_size) waveform = waveform_list[beg_idx:end_idx] feats, feats_len = self.extract_feat(waveform) @@ -122,7 +121,7 @@ class Fsmn_vad: inputs.extend(in_cache) scores, out_caches = self.infer(inputs) in_cache = out_caches - segments_part = self.vad_scorer( + segments_part = vad_scorer( scores, waveform_package, is_final=is_final,