support max_end_sil

凌匀 2023-03-24 11:11:41 +08:00
parent bda3527dbb
commit 946000a29a
4 changed files with 11 additions and 7 deletions
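
In short: the commit threads a new max_end_sil argument (judging by its name and default, the maximum end-of-utterance silence in milliseconds; default 800) from the demo scripts' param_dict through Speech2VadSegmentOnline.__call__ and inference_modelscope down to E2EVadModel.forward_online, where it sets max_end_sil_frame_cnt_thresh.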


@@ -22,7 +22,7 @@ if __name__ == '__main__':
     sample_offset = 0
     step = 160 * 10
-    param_dict = {'in_cache': dict()}
+    param_dict = {'in_cache': dict(), 'max_end_sil': 800}
     for sample_offset in range(0, speech_length, min(step, speech_length - sample_offset)):
         if sample_offset + step >= speech_length - 1:
             step = speech_length - sample_offset
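For context, a runnable sketch of the updated streaming demo loop; soundfile, the example.wav path, the is_final flush, and the commented-out inference_pipeline call are assumptions standing in for the repo's actual demo harness (step = 160 * 10 samples is 100 ms at 16 kHz; the second demo's 80 * 10 is the same 100 ms at 8 kHz):

```python
# Hedged sketch of the streaming VAD demo (assumed harness, not the exact repo script).
# Chunks the waveform into 100 ms steps and passes 'max_end_sil' via param_dict.
import soundfile

speech, sample_rate = soundfile.read('example.wav')  # hypothetical input file
speech_length = speech.shape[0]

sample_offset = 0
step = 160 * 10  # 1600 samples = 100 ms at 16 kHz
param_dict = {'in_cache': dict(), 'max_end_sil': 800}  # trailing-silence budget in ms

for sample_offset in range(0, speech_length, min(step, speech_length - sample_offset)):
    if sample_offset + step >= speech_length - 1:
        step = speech_length - sample_offset
        param_dict['is_final'] = True  # assumed: flush the VAD state on the last chunk
    chunk = speech[sample_offset: sample_offset + step]
    # segments = inference_pipeline(audio_in=chunk, param_dict=param_dict)  # assumed API
```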


@@ -22,7 +22,7 @@ if __name__ == '__main__':
     sample_offset = 0
     step = 80 * 10
-    param_dict = {'in_cache': dict()}
+    param_dict = {'in_cache': dict(), 'max_end_sil': 800}
    for sample_offset in range(0, speech_length, min(step, speech_length - sample_offset)):
         if sample_offset + step >= speech_length - 1:
             step = speech_length - sample_offset


@@ -1,6 +1,5 @@
 import argparse
 import logging
-import os
 import sys
 import json
 from pathlib import Path
@@ -30,7 +29,8 @@ from funasr.models.frontend.wav_frontend import WavFrontendOnline
 from funasr.models.frontend.wav_frontend import WavFrontend
 from funasr.bin.vad_inference import Speech2VadSegment
+header_colors = '\033[95m'
+end_colors = '\033[0m'

 class Speech2VadSegmentOnline(Speech2VadSegment):
@@ -55,7 +55,7 @@ class Speech2VadSegmentOnline(Speech2VadSegment):
     @torch.no_grad()
     def __call__(
             self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None,
-            in_cache: Dict[str, torch.Tensor] = dict(), is_final: bool = False
+            in_cache: Dict[str, torch.Tensor] = dict(), is_final: bool = False, max_end_sil: int = 800
     ) -> Tuple[torch.Tensor, List[List[int]], torch.Tensor]:
         """Inference
@@ -86,7 +86,8 @@ class Speech2VadSegmentOnline(Speech2VadSegment):
             "feats": feats,
             "waveform": waveforms,
             "in_cache": in_cache,
-            "is_final": is_final
+            "is_final": is_final,
+            "max_end_sil": max_end_sil
         }
         # a. To device
         batch = to_device(batch, device=self.device)
@@ -217,6 +218,7 @@ def inference_modelscope(
         vad_results = []
         batch_in_cache = param_dict['in_cache'] if param_dict is not None else dict()
         is_final = param_dict['is_final'] if param_dict is not None else False
+        max_end_sil = param_dict['max_end_sil'] if param_dict is not None else 800
         for keys, batch in loader:
             assert isinstance(batch, dict), type(batch)
             assert all(isinstance(s, str) for s in keys), keys
@@ -224,6 +226,7 @@ def inference_modelscope(
             assert len(keys) == _bs, f"{len(keys)} != {_bs}"
             batch['in_cache'] = batch_in_cache
             batch['is_final'] = is_final
+            batch['max_end_sil'] = max_end_sil
             # do vad segment
             _, results, param_dict['in_cache'] = speech2vadsegment(**batch)
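One caveat worth noting: as committed, the lookup param_dict['max_end_sil'] raises KeyError whenever a caller passes a param_dict without that key, since only param_dict is None falls back to 800. A more forgiving variant (hypothetical alternative, not the committed code) would be:

```python
# Hypothetical defensive variant; the commit itself indexes the key directly.
max_end_sil = param_dict.get('max_end_sil', 800) if param_dict is not None else 800
```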

funasr/models/e2e_vad.py (mode changed: Executable file → Normal file)

@@ -473,8 +473,9 @@ class E2EVadModel(nn.Module):
         return segments, in_cache

     def forward_online(self, feats: torch.Tensor, waveform: torch.tensor, in_cache: Dict[str, torch.Tensor] = dict(),
-                       is_final: bool = False
+                       is_final: bool = False, max_end_sil: int = 800
                        ) -> Tuple[List[List[List[int]]], Dict[str, torch.Tensor]]:
+        self.max_end_sil_frame_cnt_thresh = max_end_sil - self.vad_opts.speech_to_sil_time_thres
         self.waveform = waveform  # compute decibel for each frame
         self.ComputeDecibel()
         self.ComputeScores(feats, in_cache)
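
The only behavioral change in forward_online is the added first line: the end-of-segment counter threshold becomes the caller's max_end_sil minus vad_opts.speech_to_sil_time_thres, both apparently in milliseconds. A worked example, with speech_to_sil_time_thres = 150 chosen purely for illustration:

```python
# Illustrative arithmetic only; 150 is an assumed vad_opts.speech_to_sil_time_thres.
max_end_sil = 800               # caller-supplied trailing-silence budget, in ms
speech_to_sil_time_thres = 150  # assumed speech -> silence hangover, in ms

max_end_sil_frame_cnt_thresh = max_end_sil - speech_to_sil_time_thres
assert max_end_sil_frame_cnt_thresh == 650  # ms of counted silence before a segment closes
```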