Fsmn_vad支持多线程并发调用 (#2519)

* 修复WavFrontend.fbank多线程调用下共享fbank_fn导致的bug

* Fsmn_vad支持多线程并发调用

---------

Co-authored-by: wangmengdi06 <wangmengdi06@58.com>
This commit is contained in:
王梦迪 2025-05-20 16:10:59 +08:00 committed by GitHub
parent ae013cf597
commit fe588bc508
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 7 additions and 8 deletions

View File

@@ -52,12 +52,12 @@ class WavFrontend:
 def fbank(self, waveform: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
 waveform = waveform * (1 << 15)
-self.fbank_fn = knf.OnlineFbank(self.opts)
-self.fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist())
-frames = self.fbank_fn.num_frames_ready
+fbank_fn = knf.OnlineFbank(self.opts)
+fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist())
+frames = fbank_fn.num_frames_ready
 mat = np.empty([frames, self.opts.mel_opts.num_bins])
 for i in range(frames):
-mat[i, :] = self.fbank_fn.get_frame(i)
+mat[i, :] = fbank_fn.get_frame(i)
 feat = mat.astype(np.float32)
 feat_len = np.array(mat.shape[0]).astype(np.int32)
 return feat, feat_len

View File

@@ -69,7 +69,7 @@ class Fsmn_vad:
 model_file, device_id, intra_op_num_threads=intra_op_num_threads
 )
 self.batch_size = batch_size
-self.vad_scorer = E2EVadModel(config["model_conf"])
+self.vad_scorer_config = config["model_conf"]
 self.max_end_sil = (
 max_end_sil if max_end_sil is not None else config["model_conf"]["max_end_silence_time"]
 )
@@ -90,10 +90,9 @@ class Fsmn_vad:
 waveform_list = self.load_data(audio_in, self.frontend.opts.frame_opts.samp_freq)
 waveform_nums = len(waveform_list)
 is_final = kwargs.get("kwargs", False)
 segments = [[]] * self.batch_size
 for beg_idx in range(0, waveform_nums, self.batch_size):
+vad_scorer = E2EVadModel(self.vad_scorer_config)
 end_idx = min(waveform_nums, beg_idx + self.batch_size)
 waveform = waveform_list[beg_idx:end_idx]
 feats, feats_len = self.extract_feat(waveform)
@@ -122,7 +121,7 @@ class Fsmn_vad:
 inputs.extend(in_cache)
 scores, out_caches = self.infer(inputs)
 in_cache = out_caches
-segments_part = self.vad_scorer(
+segments_part = vad_scorer(
 scores,
 waveform_package,
 is_final=is_final,