FunASR (mirror of https://github.com/modelscope/FunASR)

commit 946000a29a (parent bda3527dbb): support max_end_sil
@@ -22,7 +22,7 @@ if __name__ == '__main__':
     sample_offset = 0
 
     step = 160 * 10
-    param_dict = {'in_cache': dict()}
+    param_dict = {'in_cache': dict(), 'max_end_sil': 800}
     for sample_offset in range(0, speech_length, min(step, speech_length - sample_offset)):
         if sample_offset + step >= speech_length - 1:
             step = speech_length - sample_offset
@@ -22,7 +22,7 @@ if __name__ == '__main__':
     sample_offset = 0
 
     step = 80 * 10
-    param_dict = {'in_cache': dict()}
+    param_dict = {'in_cache': dict(), 'max_end_sil': 800}
    for sample_offset in range(0, speech_length, min(step, speech_length - sample_offset)):
         if sample_offset + step >= speech_length - 1:
             step = speech_length - sample_offset
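The two demo hunks above differ only in chunk size: step = 160 * 10 is 1600 samples per call (100 ms at a 16 kHz sample rate) and step = 80 * 10 is 800 samples (100 ms at 8 kHz), so both demos presumably feed the online VAD 100 ms of audio at a time. A minimal runnable sketch of that streaming loop, with run_vad_chunk as a hypothetical stand-in for the actual pipeline invocation (the stand-in is not part of this commit):

import numpy as np

def run_vad_chunk(chunk, param_dict, is_final):
    # Hypothetical stand-in: the real demo hands the chunk plus param_dict
    # ('in_cache' for streaming state, 'max_end_sil' in ms) to the VAD pipeline.
    return []

speech = np.zeros(16000 * 3, dtype=np.float32)  # dummy 3 s of 16 kHz audio
speech_length = speech.shape[0]

sample_offset = 0
step = 160 * 10  # 1600 samples = 100 ms at 16 kHz
param_dict = {'in_cache': dict(), 'max_end_sil': 800}
for sample_offset in range(0, speech_length, min(step, speech_length - sample_offset)):
    if sample_offset + step >= speech_length - 1:
        step = speech_length - sample_offset  # shrink the last chunk to fit
    is_final = sample_offset + step >= speech_length
    segments = run_vad_chunk(speech[sample_offset:sample_offset + step], param_dict, is_final)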
@@ -1,6 +1,5 @@
 import argparse
 import logging
-import os
 import sys
 import json
 from pathlib import Path
@@ -30,7 +29,8 @@ from funasr.models.frontend.wav_frontend import WavFrontendOnline
 from funasr.models.frontend.wav_frontend import WavFrontend
 from funasr.bin.vad_inference import Speech2VadSegment
 
-
+header_colors = '\033[95m'
+end_colors = '\033[0m'
 
 
 class Speech2VadSegmentOnline(Speech2VadSegment):
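The two module-level constants added here are plain ANSI escape codes: '\033[95m' switches terminal output to bright magenta and '\033[0m' resets formatting, so the online demo output can be highlighted. For example:

header_colors = '\033[95m'
end_colors = '\033[0m'

# Prints the message in bright magenta, then restores the default color.
print(f"{header_colors}vad inference done{end_colors}")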
@@ -55,7 +55,7 @@ class Speech2VadSegmentOnline(Speech2VadSegment):
     @torch.no_grad()
     def __call__(
             self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None,
-            in_cache: Dict[str, torch.Tensor] = dict(), is_final: bool = False
+            in_cache: Dict[str, torch.Tensor] = dict(), is_final: bool = False, max_end_sil: int = 800
     ) -> Tuple[torch.Tensor, List[List[int]], torch.Tensor]:
         """Inference
 
@@ -86,7 +86,8 @@ class Speech2VadSegmentOnline(Speech2VadSegment):
             "feats": feats,
             "waveform": waveforms,
             "in_cache": in_cache,
-            "is_final": is_final
+            "is_final": is_final,
+            "max_end_sil": max_end_sil
         }
         # a. To device
         batch = to_device(batch, device=self.device)
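Adding the 'max_end_sil' entry to the batch dict is what actually delivers the value to the new keyword parameter, since inference_modelscope later calls speech2vadsegment(**batch) and Python maps each dict key onto the matching parameter name. A minimal illustration of that contract (toy names, not the real signature):

def vad_call(speech, in_cache=None, is_final=False, max_end_sil=800):
    # Toy stand-in for Speech2VadSegmentOnline.__call__.
    return max_end_sil

batch = {'speech': [0.0], 'in_cache': {}, 'is_final': True, 'max_end_sil': 500}
assert vad_call(**batch) == 500  # the batch entry overrides the 800 default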
@@ -217,6 +218,7 @@ def inference_modelscope(
         vad_results = []
         batch_in_cache = param_dict['in_cache'] if param_dict is not None else dict()
         is_final = param_dict['is_final'] if param_dict is not None else False
+        max_end_sil = param_dict['max_end_sil'] if param_dict is not None else 800
         for keys, batch in loader:
             assert isinstance(batch, dict), type(batch)
             assert all(isinstance(s, str) for s in keys), keys
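These fallbacks apply only when param_dict is None as a whole; a param_dict that is present but missing one of the keys makes the subscript raise KeyError instead of using the default. A per-key fallback via dict.get, shown purely as a sketch of the alternative and not what this commit does:

def read_params(param_dict=None):
    # Per-key defaults; tolerates None or a partially filled param_dict.
    param_dict = param_dict or {}
    return (param_dict.get('in_cache', dict()),
            param_dict.get('is_final', False),
            param_dict.get('max_end_sil', 800))

print(read_params({'in_cache': {}}))  # ({}, False, 800)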
@@ -224,6 +226,7 @@ def inference_modelscope(
             assert len(keys) == _bs, f"{len(keys)} != {_bs}"
             batch['in_cache'] = batch_in_cache
             batch['is_final'] = is_final
+            batch['max_end_sil'] = max_end_sil
 
             # do vad segment
             _, results, param_dict['in_cache'] = speech2vadsegment(**batch)
funasr/models/e2e_vad.py (3 changes; Executable file → Normal file)
@@ -473,8 +473,9 @@ class E2EVadModel(nn.Module):
         return segments, in_cache
 
     def forward_online(self, feats: torch.Tensor, waveform: torch.tensor, in_cache: Dict[str, torch.Tensor] = dict(),
-                       is_final: bool = False
+                       is_final: bool = False, max_end_sil: int = 800
                        ) -> Tuple[List[List[List[int]]], Dict[str, torch.Tensor]]:
+        self.max_end_sil_frame_cnt_thresh = max_end_sil - self.vad_opts.speech_to_sil_time_thres
         self.waveform = waveform # compute decibel for each frame
         self.ComputeDecibel()
         self.ComputeScores(feats, in_cache)
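The added line in forward_online converts the caller's millisecond budget into the detector's end-of-silence threshold by subtracting the configured speech-to-silence transition time. A worked example, assuming speech_to_sil_time_thres is 150 (an illustrative value in ms; the real default lives in the model's VAD options, not in this hunk):

max_end_sil = 800               # from the demo's param_dict, in ms
speech_to_sil_time_thres = 150  # assumed option value for illustration, in ms
max_end_sil_frame_cnt_thresh = max_end_sil - speech_to_sil_time_thres
print(max_end_sil_frame_cnt_thresh)  # 650: remaining trailing-silence budget
                                     # before an open segment is closed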