mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
Fsmn_vad支持多线程并发调用 (#2519)
* 修复WavFrontend.fbank多线程调用下共享fbank_fn导致的bug * Fsmn_vad支持多线程并发调用 --------- Co-authored-by: wangmengdi06 <wangmengdi06@58.com>
This commit is contained in:
parent
ae013cf597
commit
fe588bc508
@ -52,12 +52,12 @@ class WavFrontend:
|
||||
|
||||
def fbank(self, waveform: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
||||
waveform = waveform * (1 << 15)
|
||||
self.fbank_fn = knf.OnlineFbank(self.opts)
|
||||
self.fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist())
|
||||
frames = self.fbank_fn.num_frames_ready
|
||||
fbank_fn = knf.OnlineFbank(self.opts)
|
||||
fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist())
|
||||
frames = fbank_fn.num_frames_ready
|
||||
mat = np.empty([frames, self.opts.mel_opts.num_bins])
|
||||
for i in range(frames):
|
||||
mat[i, :] = self.fbank_fn.get_frame(i)
|
||||
mat[i, :] = fbank_fn.get_frame(i)
|
||||
feat = mat.astype(np.float32)
|
||||
feat_len = np.array(mat.shape[0]).astype(np.int32)
|
||||
return feat, feat_len
|
||||
|
||||
@ -69,7 +69,7 @@ class Fsmn_vad:
|
||||
model_file, device_id, intra_op_num_threads=intra_op_num_threads
|
||||
)
|
||||
self.batch_size = batch_size
|
||||
self.vad_scorer = E2EVadModel(config["model_conf"])
|
||||
self.vad_scorer_config = config["model_conf"]
|
||||
self.max_end_sil = (
|
||||
max_end_sil if max_end_sil is not None else config["model_conf"]["max_end_silence_time"]
|
||||
)
|
||||
@ -90,10 +90,9 @@ class Fsmn_vad:
|
||||
waveform_list = self.load_data(audio_in, self.frontend.opts.frame_opts.samp_freq)
|
||||
waveform_nums = len(waveform_list)
|
||||
is_final = kwargs.get("kwargs", False)
|
||||
|
||||
segments = [[]] * self.batch_size
|
||||
for beg_idx in range(0, waveform_nums, self.batch_size):
|
||||
|
||||
vad_scorer = E2EVadModel(self.vad_scorer_config)
|
||||
end_idx = min(waveform_nums, beg_idx + self.batch_size)
|
||||
waveform = waveform_list[beg_idx:end_idx]
|
||||
feats, feats_len = self.extract_feat(waveform)
|
||||
@ -122,7 +121,7 @@ class Fsmn_vad:
|
||||
inputs.extend(in_cache)
|
||||
scores, out_caches = self.infer(inputs)
|
||||
in_cache = out_caches
|
||||
segments_part = self.vad_scorer(
|
||||
segments_part = vad_scorer(
|
||||
scores,
|
||||
waveform_package,
|
||||
is_final=is_final,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user