funasr1.0 streaming

This commit is contained in:
游雁 2024-01-12 22:48:30 +08:00
parent c3c78fc5e7
commit c0b186b5b6
15 changed files with 245 additions and 1389 deletions

View File

@ -8,7 +8,7 @@ from funasr import AutoModel
model = AutoModel(model="damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
model_revision="v2.0.0",
vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
vad_model_revision="v2.0.0",
vad_model_revision="v2.0.1",
punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
punc_model_revision="v2.0.0",
spk_model="/Users/shixian/code/modelscope_models/speech_campplus_sv_zh-cn_16k-common",
@ -21,7 +21,7 @@ print(res)
model = AutoModel(model="damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
model_revision="v2.0.0",
vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
vad_model_revision="v2.0.0",
vad_model_revision="v2.0.1",
punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
punc_model_revision="v2.0.0",
spk_model="/Users/shixian/code/modelscope_models/speech_campplus_sv_zh-cn_16k-common",

View File

@ -1,11 +0,0 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
from funasr import AutoModel
model = AutoModel(model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch", model_revision="v2.0.0")
res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav")
print(res)

View File

@ -1,11 +0,0 @@
model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"
model_revision="v2.0.0"
python funasr/bin/inference.py \
+model=${model} \
+model_revision=${model_revision} \
+input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav" \
+output_dir="./outputs/debug" \
+device="cpu" \

View File

@ -7,11 +7,9 @@ from funasr import AutoModel
wav_file = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav"
chunk_size = 60000 # ms
model = AutoModel(model="/Users/zhifu/Downloads/modelscope_models/speech_fsmn_vad_zh-cn-16k-common-streaming", model_revision="v2.0.0")
model = AutoModel(model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch", model_revision="v2.0.1")
res = model(input=wav_file,
chunk_size=chunk_size,
)
res = model(input=wav_file, chunk_size=chunk_size, )
print(res)
@ -22,7 +20,7 @@ import os
wav_file = os.path.join(model.model_path, "example/vad_example.wav")
speech, sample_rate = soundfile.read(wav_file)
chunk_stride = int(chunk_size * 16000 / 1000)
chunk_stride = int(chunk_size * sample_rate / 1000)
cache = {}
@ -35,4 +33,5 @@ for i in range(total_chunk_num):
is_final=is_final,
chunk_size=chunk_size,
)
print(res)
if len(res[0]["value"]):
print(res)

View File

@ -1,7 +1,7 @@
model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"
model_revision="v2.0.0"
model_revision="v2.0.1"
python funasr/bin/inference.py \
+model=${model} \

View File

@ -8,7 +8,7 @@ from funasr import AutoModel
model = AutoModel(model="damo/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
model_revision="v2.0.0",
vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
vad_model_revision="v2.0.0",
vad_model_revision="v2.0.1",
punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
punc_model_revision="v2.0.0",
spk_model="damo/speech_campplus_sv_zh-cn_16k-common",

View File

@ -2,7 +2,7 @@
model="damo/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
model_revision="v2.0.0"
vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"
vad_model_revision="v2.0.0"
vad_model_revision="v2.0.1"
punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
punc_model_revision="v2.0.0"
spk_model="damo/speech_campplus_sv_zh-cn_16k-common"

View File

@ -8,7 +8,7 @@ from funasr import AutoModel
model = AutoModel(model="damo/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
model_revision="v2.0.0",
vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
vad_model_revision="v2.0.0",
vad_model_revision="v2.0.1",
punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
punc_model_revision="v2.0.0",
)

View File

@ -2,7 +2,7 @@
model="damo/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
model_revision="v2.0.0"
vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch"
vad_model_revision="v2.0.0"
vad_model_revision="v2.0.1"
punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch"
punc_model_revision="v2.0.0"

View File

@ -402,8 +402,7 @@ class WavFrontendOnline(nn.Module):
self, input: torch.Tensor, input_lengths: torch.Tensor, cache: dict = {}, **kwargs
):
is_final = kwargs.get("is_final", False)
reset = kwargs.get("reset", False)
if len(cache) == 0 or reset:
if len(cache) == 0:
self.init_cache(cache)
batch_size = input.shape[0]

View File

@ -1,303 +0,0 @@
from typing import Tuple, Dict
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from funasr.register import tables
class LinearTransform(nn.Module):
def __init__(self, input_dim, output_dim):
super(LinearTransform, self).__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.linear = nn.Linear(input_dim, output_dim, bias=False)
def forward(self, input):
output = self.linear(input)
return output
class AffineTransform(nn.Module):
def __init__(self, input_dim, output_dim):
super(AffineTransform, self).__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.linear = nn.Linear(input_dim, output_dim)
def forward(self, input):
output = self.linear(input)
return output
class RectifiedLinear(nn.Module):
def __init__(self, input_dim, output_dim):
super(RectifiedLinear, self).__init__()
self.dim = input_dim
self.relu = nn.ReLU()
self.dropout = nn.Dropout(0.1)
def forward(self, input):
out = self.relu(input)
return out
class FSMNBlock(nn.Module):
def __init__(
self,
input_dim: int,
output_dim: int,
lorder=None,
rorder=None,
lstride=1,
rstride=1,
):
super(FSMNBlock, self).__init__()
self.dim = input_dim
if lorder is None:
return
self.lorder = lorder
self.rorder = rorder
self.lstride = lstride
self.rstride = rstride
self.conv_left = nn.Conv2d(
self.dim, self.dim, [lorder, 1], dilation=[lstride, 1], groups=self.dim, bias=False)
if self.rorder > 0:
self.conv_right = nn.Conv2d(
self.dim, self.dim, [rorder, 1], dilation=[rstride, 1], groups=self.dim, bias=False)
else:
self.conv_right = None
def forward(self, input: torch.Tensor, cache: torch.Tensor):
x = torch.unsqueeze(input, 1)
x_per = x.permute(0, 3, 2, 1) # B D T C
cache = cache.to(x_per.device)
y_left = torch.cat((cache, x_per), dim=2)
cache = y_left[:, :, -(self.lorder - 1) * self.lstride:, :]
y_left = self.conv_left(y_left)
out = x_per + y_left
if self.conv_right is not None:
# maybe need to check
y_right = F.pad(x_per, [0, 0, 0, self.rorder * self.rstride])
y_right = y_right[:, :, self.rstride:, :]
y_right = self.conv_right(y_right)
out += y_right
out_per = out.permute(0, 3, 2, 1)
output = out_per.squeeze(1)
return output, cache
class BasicBlock(nn.Sequential):
def __init__(self,
linear_dim: int,
proj_dim: int,
lorder: int,
rorder: int,
lstride: int,
rstride: int,
stack_layer: int
):
super(BasicBlock, self).__init__()
self.lorder = lorder
self.rorder = rorder
self.lstride = lstride
self.rstride = rstride
self.stack_layer = stack_layer
self.linear = LinearTransform(linear_dim, proj_dim)
self.fsmn_block = FSMNBlock(proj_dim, proj_dim, lorder, rorder, lstride, rstride)
self.affine = AffineTransform(proj_dim, linear_dim)
self.relu = RectifiedLinear(linear_dim, linear_dim)
def forward(self, input: torch.Tensor, cache: Dict[str, torch.Tensor]):
x1 = self.linear(input) # B T D
cache_layer_name = 'cache_layer_{}'.format(self.stack_layer)
if cache_layer_name not in cache:
cache[cache_layer_name] = torch.zeros(x1.shape[0], x1.shape[-1], (self.lorder - 1) * self.lstride, 1)
x2, cache[cache_layer_name] = self.fsmn_block(x1, cache[cache_layer_name])
x3 = self.affine(x2)
x4 = self.relu(x3)
return x4
class FsmnStack(nn.Sequential):
def __init__(self, *args):
super(FsmnStack, self).__init__(*args)
def forward(self, input: torch.Tensor, cache: Dict[str, torch.Tensor]):
x = input
for module in self._modules.values():
x = module(x, cache)
return x
'''
FSMN net for keyword spotting
input_dim: input dimension
linear_dim: fsmn input dimensionll
proj_dim: fsmn projection dimension
lorder: fsmn left order
rorder: fsmn right order
num_syn: output dimension
fsmn_layers: no. of sequential fsmn layers
'''
@tables.register("encoder_classes", "FSMN")
class FSMN(nn.Module):
def __init__(
self,
input_dim: int,
input_affine_dim: int,
fsmn_layers: int,
linear_dim: int,
proj_dim: int,
lorder: int,
rorder: int,
lstride: int,
rstride: int,
output_affine_dim: int,
output_dim: int
):
super(FSMN, self).__init__()
self.input_dim = input_dim
self.input_affine_dim = input_affine_dim
self.fsmn_layers = fsmn_layers
self.linear_dim = linear_dim
self.proj_dim = proj_dim
self.output_affine_dim = output_affine_dim
self.output_dim = output_dim
self.in_linear1 = AffineTransform(input_dim, input_affine_dim)
self.in_linear2 = AffineTransform(input_affine_dim, linear_dim)
self.relu = RectifiedLinear(linear_dim, linear_dim)
self.fsmn = FsmnStack(*[BasicBlock(linear_dim, proj_dim, lorder, rorder, lstride, rstride, i) for i in
range(fsmn_layers)])
self.out_linear1 = AffineTransform(linear_dim, output_affine_dim)
self.out_linear2 = AffineTransform(output_affine_dim, output_dim)
self.softmax = nn.Softmax(dim=-1)
def fuse_modules(self):
pass
def forward(
self,
input: torch.Tensor,
cache: Dict[str, torch.Tensor]
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
"""
Args:
input (torch.Tensor): Input tensor (B, T, D)
cache: when cache is not None, the forward is in streaming. The type of cache is a dict, egs,
{'cache_layer_1': torch.Tensor(B, T1, D)}, T1 is equal to self.lorder. It is {} for the 1st frame
"""
x1 = self.in_linear1(input)
x2 = self.in_linear2(x1)
x3 = self.relu(x2)
x4 = self.fsmn(x3, cache) # self.cache will update automatically in self.fsmn
x5 = self.out_linear1(x4)
x6 = self.out_linear2(x5)
x7 = self.softmax(x6)
return x7
'''
one deep fsmn layer
dimproj: projection dimension, input and output dimension of memory blocks
dimlinear: dimension of mapping layer
lorder: left order
rorder: right order
lstride: left stride
rstride: right stride
'''
@tables.register("encoder_classes", "DFSMN")
class DFSMN(nn.Module):
def __init__(self, dimproj=64, dimlinear=128, lorder=20, rorder=1, lstride=1, rstride=1):
super(DFSMN, self).__init__()
self.lorder = lorder
self.rorder = rorder
self.lstride = lstride
self.rstride = rstride
self.expand = AffineTransform(dimproj, dimlinear)
self.shrink = LinearTransform(dimlinear, dimproj)
self.conv_left = nn.Conv2d(
dimproj, dimproj, [lorder, 1], dilation=[lstride, 1], groups=dimproj, bias=False)
if rorder > 0:
self.conv_right = nn.Conv2d(
dimproj, dimproj, [rorder, 1], dilation=[rstride, 1], groups=dimproj, bias=False)
else:
self.conv_right = None
def forward(self, input):
f1 = F.relu(self.expand(input))
p1 = self.shrink(f1)
x = torch.unsqueeze(p1, 1)
x_per = x.permute(0, 3, 2, 1)
y_left = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride, 0])
if self.conv_right is not None:
y_right = F.pad(x_per, [0, 0, 0, (self.rorder) * self.rstride])
y_right = y_right[:, :, self.rstride:, :]
out = x_per + self.conv_left(y_left) + self.conv_right(y_right)
else:
out = x_per + self.conv_left(y_left)
out1 = out.permute(0, 3, 2, 1)
output = input + out1.squeeze(1)
return output
'''
build stacked dfsmn layers
'''
def buildDFSMNRepeats(linear_dim=128, proj_dim=64, lorder=20, rorder=1, fsmn_layers=6):
repeats = [
nn.Sequential(
DFSMN(proj_dim, linear_dim, lorder, rorder, 1, 1))
for i in range(fsmn_layers)
]
return nn.Sequential(*repeats)
if __name__ == '__main__':
fsmn = FSMN(400, 140, 4, 250, 128, 10, 2, 1, 1, 140, 2599)
print(fsmn)
num_params = sum(p.numel() for p in fsmn.parameters())
print('the number of model params: {}'.format(num_params))
x = torch.zeros(128, 200, 400) # batch-size * time * dim
y, _ = fsmn(x) # batch-size * time * dim
print('input shape: {}'.format(x.shape))
print('output shape: {}'.format(y.shape))
print(fsmn.to_kaldi_net())

View File

@ -1,740 +0,0 @@
from enum import Enum
from typing import List, Tuple, Dict, Any
import logging
import os
import json
import torch
from torch import nn
import math
from typing import Optional
import time
from funasr.register import tables
from funasr.utils.load_utils import load_audio_text_image_video,extract_fbank
from funasr.utils.datadir_writer import DatadirWriter
from torch.nn.utils.rnn import pad_sequence
from funasr.train_utils.device_funcs import to_device
class VadStateMachine(Enum):
kVadInStateStartPointNotDetected = 1
kVadInStateInSpeechSegment = 2
kVadInStateEndPointDetected = 3
class FrameState(Enum):
kFrameStateInvalid = -1
kFrameStateSpeech = 1
kFrameStateSil = 0
# final voice/unvoice state per frame
class AudioChangeState(Enum):
kChangeStateSpeech2Speech = 0
kChangeStateSpeech2Sil = 1
kChangeStateSil2Sil = 2
kChangeStateSil2Speech = 3
kChangeStateNoBegin = 4
kChangeStateInvalid = 5
class VadDetectMode(Enum):
kVadSingleUtteranceDetectMode = 0
kVadMutipleUtteranceDetectMode = 1
class VADXOptions:
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
Deep-FSMN for Large Vocabulary Continuous Speech Recognition
https://arxiv.org/abs/1803.05030
"""
def __init__(
self,
sample_rate: int = 16000,
detect_mode: int = VadDetectMode.kVadMutipleUtteranceDetectMode.value,
snr_mode: int = 0,
max_end_silence_time: int = 800,
max_start_silence_time: int = 3000,
do_start_point_detection: bool = True,
do_end_point_detection: bool = True,
window_size_ms: int = 200,
sil_to_speech_time_thres: int = 150,
speech_to_sil_time_thres: int = 150,
speech_2_noise_ratio: float = 1.0,
do_extend: int = 1,
lookback_time_start_point: int = 200,
lookahead_time_end_point: int = 100,
max_single_segment_time: int = 60000,
nn_eval_block_size: int = 8,
dcd_block_size: int = 4,
snr_thres: int = -100.0,
noise_frame_num_used_for_snr: int = 100,
decibel_thres: int = -100.0,
speech_noise_thres: float = 0.6,
fe_prior_thres: float = 1e-4,
silence_pdf_num: int = 1,
sil_pdf_ids: List[int] = [0],
speech_noise_thresh_low: float = -0.1,
speech_noise_thresh_high: float = 0.3,
output_frame_probs: bool = False,
frame_in_ms: int = 10,
frame_length_ms: int = 25,
**kwargs,
):
self.sample_rate = sample_rate
self.detect_mode = detect_mode
self.snr_mode = snr_mode
self.max_end_silence_time = max_end_silence_time
self.max_start_silence_time = max_start_silence_time
self.do_start_point_detection = do_start_point_detection
self.do_end_point_detection = do_end_point_detection
self.window_size_ms = window_size_ms
self.sil_to_speech_time_thres = sil_to_speech_time_thres
self.speech_to_sil_time_thres = speech_to_sil_time_thres
self.speech_2_noise_ratio = speech_2_noise_ratio
self.do_extend = do_extend
self.lookback_time_start_point = lookback_time_start_point
self.lookahead_time_end_point = lookahead_time_end_point
self.max_single_segment_time = max_single_segment_time
self.nn_eval_block_size = nn_eval_block_size
self.dcd_block_size = dcd_block_size
self.snr_thres = snr_thres
self.noise_frame_num_used_for_snr = noise_frame_num_used_for_snr
self.decibel_thres = decibel_thres
self.speech_noise_thres = speech_noise_thres
self.fe_prior_thres = fe_prior_thres
self.silence_pdf_num = silence_pdf_num
self.sil_pdf_ids = sil_pdf_ids
self.speech_noise_thresh_low = speech_noise_thresh_low
self.speech_noise_thresh_high = speech_noise_thresh_high
self.output_frame_probs = output_frame_probs
self.frame_in_ms = frame_in_ms
self.frame_length_ms = frame_length_ms
class E2EVadSpeechBufWithDoa(object):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
Deep-FSMN for Large Vocabulary Continuous Speech Recognition
https://arxiv.org/abs/1803.05030
"""
def __init__(self):
self.start_ms = 0
self.end_ms = 0
self.buffer = []
self.contain_seg_start_point = False
self.contain_seg_end_point = False
self.doa = 0
def Reset(self):
self.start_ms = 0
self.end_ms = 0
self.buffer = []
self.contain_seg_start_point = False
self.contain_seg_end_point = False
self.doa = 0
class E2EVadFrameProb(object):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
Deep-FSMN for Large Vocabulary Continuous Speech Recognition
https://arxiv.org/abs/1803.05030
"""
def __init__(self):
self.noise_prob = 0.0
self.speech_prob = 0.0
self.score = 0.0
self.frame_id = 0
self.frm_state = 0
class WindowDetector(object):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
Deep-FSMN for Large Vocabulary Continuous Speech Recognition
https://arxiv.org/abs/1803.05030
"""
def __init__(self, window_size_ms: int, sil_to_speech_time: int,
speech_to_sil_time: int, frame_size_ms: int):
self.window_size_ms = window_size_ms
self.sil_to_speech_time = sil_to_speech_time
self.speech_to_sil_time = speech_to_sil_time
self.frame_size_ms = frame_size_ms
self.win_size_frame = int(window_size_ms / frame_size_ms)
self.win_sum = 0
self.win_state = [0] * self.win_size_frame # 初始化窗
self.cur_win_pos = 0
self.pre_frame_state = FrameState.kFrameStateSil
self.cur_frame_state = FrameState.kFrameStateSil
self.sil_to_speech_frmcnt_thres = int(sil_to_speech_time / frame_size_ms)
self.speech_to_sil_frmcnt_thres = int(speech_to_sil_time / frame_size_ms)
self.voice_last_frame_count = 0
self.noise_last_frame_count = 0
self.hydre_frame_count = 0
def Reset(self) -> None:
self.cur_win_pos = 0
self.win_sum = 0
self.win_state = [0] * self.win_size_frame
self.pre_frame_state = FrameState.kFrameStateSil
self.cur_frame_state = FrameState.kFrameStateSil
self.voice_last_frame_count = 0
self.noise_last_frame_count = 0
self.hydre_frame_count = 0
def GetWinSize(self) -> int:
return int(self.win_size_frame)
def DetectOneFrame(self, frameState: FrameState, frame_count: int) -> AudioChangeState:
cur_frame_state = FrameState.kFrameStateSil
if frameState == FrameState.kFrameStateSpeech:
cur_frame_state = 1
elif frameState == FrameState.kFrameStateSil:
cur_frame_state = 0
else:
return AudioChangeState.kChangeStateInvalid
self.win_sum -= self.win_state[self.cur_win_pos]
self.win_sum += cur_frame_state
self.win_state[self.cur_win_pos] = cur_frame_state
self.cur_win_pos = (self.cur_win_pos + 1) % self.win_size_frame
if self.pre_frame_state == FrameState.kFrameStateSil and self.win_sum >= self.sil_to_speech_frmcnt_thres:
self.pre_frame_state = FrameState.kFrameStateSpeech
return AudioChangeState.kChangeStateSil2Speech
if self.pre_frame_state == FrameState.kFrameStateSpeech and self.win_sum <= self.speech_to_sil_frmcnt_thres:
self.pre_frame_state = FrameState.kFrameStateSil
return AudioChangeState.kChangeStateSpeech2Sil
if self.pre_frame_state == FrameState.kFrameStateSil:
return AudioChangeState.kChangeStateSil2Sil
if self.pre_frame_state == FrameState.kFrameStateSpeech:
return AudioChangeState.kChangeStateSpeech2Speech
return AudioChangeState.kChangeStateInvalid
def FrameSizeMs(self) -> int:
return int(self.frame_size_ms)
@tables.register("model_classes", "FsmnVAD")
class FsmnVAD(nn.Module):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
Deep-FSMN for Large Vocabulary Continuous Speech Recognition
https://arxiv.org/abs/1803.05030
"""
def __init__(self,
encoder: str = None,
encoder_conf: Optional[Dict] = None,
vad_post_args: Dict[str, Any] = None,
**kwargs,
):
super().__init__()
self.vad_opts = VADXOptions(**kwargs)
self.windows_detector = WindowDetector(self.vad_opts.window_size_ms,
self.vad_opts.sil_to_speech_time_thres,
self.vad_opts.speech_to_sil_time_thres,
self.vad_opts.frame_in_ms)
encoder_class = tables.encoder_classes.get(encoder.lower())
encoder = encoder_class(**encoder_conf)
self.encoder = encoder
# init variables
self.data_buf_start_frame = 0
self.frm_cnt = 0
self.latest_confirmed_speech_frame = 0
self.lastest_confirmed_silence_frame = -1
self.continous_silence_frame_count = 0
self.vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected
self.confirmed_start_frame = -1
self.confirmed_end_frame = -1
self.number_end_time_detected = 0
self.sil_frame = 0
self.sil_pdf_ids = self.vad_opts.sil_pdf_ids
self.noise_average_decibel = -100.0
self.pre_end_silence_detected = False
self.next_seg = True
self.output_data_buf = []
self.output_data_buf_offset = 0
self.frame_probs = []
self.max_end_sil_frame_cnt_thresh = self.vad_opts.max_end_silence_time - self.vad_opts.speech_to_sil_time_thres
self.speech_noise_thres = self.vad_opts.speech_noise_thres
self.scores = None
self.max_time_out = False
self.decibel = []
self.data_buf = None
self.data_buf_all = None
self.waveform = None
self.last_drop_frames = 0
def AllResetDetection(self):
self.data_buf_start_frame = 0
self.frm_cnt = 0
self.latest_confirmed_speech_frame = 0
self.lastest_confirmed_silence_frame = -1
self.continous_silence_frame_count = 0
self.vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected
self.confirmed_start_frame = -1
self.confirmed_end_frame = -1
self.number_end_time_detected = 0
self.sil_frame = 0
self.sil_pdf_ids = self.vad_opts.sil_pdf_ids
self.noise_average_decibel = -100.0
self.pre_end_silence_detected = False
self.next_seg = True
self.output_data_buf = []
self.output_data_buf_offset = 0
self.frame_probs = []
self.max_end_sil_frame_cnt_thresh = self.vad_opts.max_end_silence_time - self.vad_opts.speech_to_sil_time_thres
self.speech_noise_thres = self.vad_opts.speech_noise_thres
self.scores = None
self.max_time_out = False
self.decibel = []
self.data_buf = None
self.data_buf_all = None
self.waveform = None
self.last_drop_frames = 0
self.windows_detector.Reset()
def ResetDetection(self):
self.continous_silence_frame_count = 0
self.latest_confirmed_speech_frame = 0
self.lastest_confirmed_silence_frame = -1
self.confirmed_start_frame = -1
self.confirmed_end_frame = -1
self.vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected
self.windows_detector.Reset()
self.sil_frame = 0
self.frame_probs = []
if self.output_data_buf:
assert self.output_data_buf[-1].contain_seg_end_point == True
drop_frames = int(self.output_data_buf[-1].end_ms / self.vad_opts.frame_in_ms)
real_drop_frames = drop_frames - self.last_drop_frames
self.last_drop_frames = drop_frames
self.data_buf_all = self.data_buf_all[real_drop_frames * int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
self.decibel = self.decibel[real_drop_frames:]
self.scores = self.scores[:, real_drop_frames:, :]
def ComputeDecibel(self) -> None:
frame_sample_length = int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000)
frame_shift_length = int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)
if self.data_buf_all is None:
self.data_buf_all = self.waveform[0] # self.data_buf is pointed to self.waveform[0]
self.data_buf = self.data_buf_all
else:
self.data_buf_all = torch.cat((self.data_buf_all, self.waveform[0]))
for offset in range(0, self.waveform.shape[1] - frame_sample_length + 1, frame_shift_length):
self.decibel.append(
10 * math.log10((self.waveform[0][offset: offset + frame_sample_length]).square().sum() + \
0.000001))
def ComputeScores(self, feats: torch.Tensor, cache: Dict[str, torch.Tensor]) -> None:
scores = self.encoder(feats, cache).to('cpu') # return B * T * D
assert scores.shape[1] == feats.shape[1], "The shape between feats and scores does not match"
self.vad_opts.nn_eval_block_size = scores.shape[1]
self.frm_cnt += scores.shape[1] # count total frames
if self.scores is None:
self.scores = scores # the first calculation
else:
self.scores = torch.cat((self.scores, scores), dim=1)
def PopDataBufTillFrame(self, frame_idx: int) -> None: # need check again
while self.data_buf_start_frame < frame_idx:
if len(self.data_buf) >= int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):
self.data_buf_start_frame += 1
self.data_buf = self.data_buf_all[(self.data_buf_start_frame - self.last_drop_frames) * int(
self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
def PopDataToOutputBuf(self, start_frm: int, frm_cnt: int, first_frm_is_start_point: bool,
last_frm_is_end_point: bool, end_point_is_sent_end: bool) -> None:
self.PopDataBufTillFrame(start_frm)
expected_sample_number = int(frm_cnt * self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000)
if last_frm_is_end_point:
extra_sample = max(0, int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000 - \
self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000))
expected_sample_number += int(extra_sample)
if end_point_is_sent_end:
expected_sample_number = max(expected_sample_number, len(self.data_buf))
if len(self.data_buf) < expected_sample_number:
print('error in calling pop data_buf\n')
if len(self.output_data_buf) == 0 or first_frm_is_start_point:
self.output_data_buf.append(E2EVadSpeechBufWithDoa())
self.output_data_buf[-1].Reset()
self.output_data_buf[-1].start_ms = start_frm * self.vad_opts.frame_in_ms
self.output_data_buf[-1].end_ms = self.output_data_buf[-1].start_ms
self.output_data_buf[-1].doa = 0
cur_seg = self.output_data_buf[-1]
if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms:
print('warning\n')
out_pos = len(cur_seg.buffer) # cur_seg.buff现在没做任何操作
data_to_pop = 0
if end_point_is_sent_end:
data_to_pop = expected_sample_number
else:
data_to_pop = int(frm_cnt * self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)
if data_to_pop > len(self.data_buf):
print('VAD data_to_pop is bigger than self.data_buf.size()!!!\n')
data_to_pop = len(self.data_buf)
expected_sample_number = len(self.data_buf)
cur_seg.doa = 0
for sample_cpy_out in range(0, data_to_pop):
# cur_seg.buffer[out_pos ++] = data_buf_.back();
out_pos += 1
for sample_cpy_out in range(data_to_pop, expected_sample_number):
# cur_seg.buffer[out_pos++] = data_buf_.back()
out_pos += 1
if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms:
print('Something wrong with the VAD algorithm\n')
self.data_buf_start_frame += frm_cnt
cur_seg.end_ms = (start_frm + frm_cnt) * self.vad_opts.frame_in_ms
if first_frm_is_start_point:
cur_seg.contain_seg_start_point = True
if last_frm_is_end_point:
cur_seg.contain_seg_end_point = True
def OnSilenceDetected(self, valid_frame: int):
self.lastest_confirmed_silence_frame = valid_frame
if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
self.PopDataBufTillFrame(valid_frame)
# silence_detected_callback_
# pass
def OnVoiceDetected(self, valid_frame: int) -> None:
self.latest_confirmed_speech_frame = valid_frame
self.PopDataToOutputBuf(valid_frame, 1, False, False, False)
def OnVoiceStart(self, start_frame: int, fake_result: bool = False) -> None:
if self.vad_opts.do_start_point_detection:
pass
if self.confirmed_start_frame != -1:
print('not reset vad properly\n')
else:
self.confirmed_start_frame = start_frame
if not fake_result and self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
self.PopDataToOutputBuf(self.confirmed_start_frame, 1, True, False, False)
def OnVoiceEnd(self, end_frame: int, fake_result: bool, is_last_frame: bool) -> None:
for t in range(self.latest_confirmed_speech_frame + 1, end_frame):
self.OnVoiceDetected(t)
if self.vad_opts.do_end_point_detection:
pass
if self.confirmed_end_frame != -1:
print('not reset vad properly\n')
else:
self.confirmed_end_frame = end_frame
if not fake_result:
self.sil_frame = 0
self.PopDataToOutputBuf(self.confirmed_end_frame, 1, False, True, is_last_frame)
self.number_end_time_detected += 1
def MaybeOnVoiceEndIfLastFrame(self, is_final_frame: bool, cur_frm_idx: int) -> None:
if is_final_frame:
self.OnVoiceEnd(cur_frm_idx, False, True)
self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
def GetLatency(self) -> int:
return int(self.LatencyFrmNumAtStartPoint() * self.vad_opts.frame_in_ms)
def LatencyFrmNumAtStartPoint(self) -> int:
vad_latency = self.windows_detector.GetWinSize()
if self.vad_opts.do_extend:
vad_latency += int(self.vad_opts.lookback_time_start_point / self.vad_opts.frame_in_ms)
return vad_latency
def GetFrameState(self, t: int):
frame_state = FrameState.kFrameStateInvalid
cur_decibel = self.decibel[t]
cur_snr = cur_decibel - self.noise_average_decibel
# for each frame, calc log posterior probability of each state
if cur_decibel < self.vad_opts.decibel_thres:
frame_state = FrameState.kFrameStateSil
self.DetectOneFrame(frame_state, t, False)
return frame_state
sum_score = 0.0
noise_prob = 0.0
assert len(self.sil_pdf_ids) == self.vad_opts.silence_pdf_num
if len(self.sil_pdf_ids) > 0:
assert len(self.scores) == 1 # 只支持batch_size = 1的测试
sil_pdf_scores = [self.scores[0][t][sil_pdf_id] for sil_pdf_id in self.sil_pdf_ids]
sum_score = sum(sil_pdf_scores)
noise_prob = math.log(sum_score) * self.vad_opts.speech_2_noise_ratio
total_score = 1.0
sum_score = total_score - sum_score
speech_prob = math.log(sum_score)
if self.vad_opts.output_frame_probs:
frame_prob = E2EVadFrameProb()
frame_prob.noise_prob = noise_prob
frame_prob.speech_prob = speech_prob
frame_prob.score = sum_score
frame_prob.frame_id = t
self.frame_probs.append(frame_prob)
if math.exp(speech_prob) >= math.exp(noise_prob) + self.speech_noise_thres:
if cur_snr >= self.vad_opts.snr_thres and cur_decibel >= self.vad_opts.decibel_thres:
frame_state = FrameState.kFrameStateSpeech
else:
frame_state = FrameState.kFrameStateSil
else:
frame_state = FrameState.kFrameStateSil
if self.noise_average_decibel < -99.9:
self.noise_average_decibel = cur_decibel
else:
self.noise_average_decibel = (cur_decibel + self.noise_average_decibel * (
self.vad_opts.noise_frame_num_used_for_snr
- 1)) / self.vad_opts.noise_frame_num_used_for_snr
return frame_state
def forward(self, feats: torch.Tensor, waveform: torch.tensor, cache: Dict[str, torch.Tensor] = dict(),
is_final: bool = False
):
if not cache:
self.AllResetDetection()
self.waveform = waveform # compute decibel for each frame
self.ComputeDecibel()
self.ComputeScores(feats, cache)
if not is_final:
self.DetectCommonFrames()
else:
self.DetectLastFrames()
segments = []
for batch_num in range(0, feats.shape[0]): # only support batch_size = 1 now
segment_batch = []
if len(self.output_data_buf) > 0:
for i in range(self.output_data_buf_offset, len(self.output_data_buf)):
if not is_final and (not self.output_data_buf[i].contain_seg_start_point or not self.output_data_buf[
i].contain_seg_end_point):
continue
segment = [self.output_data_buf[i].start_ms, self.output_data_buf[i].end_ms]
segment_batch.append(segment)
self.output_data_buf_offset += 1 # need update this parameter
if segment_batch:
segments.append(segment_batch)
if is_final:
# reset class variables and clear the dict for the next query
self.AllResetDetection()
return segments, cache
def generate(self,
data_in,
data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
**kwargs,
):
meta_data = {}
audio_sample_list = [data_in]
if isinstance(data_in, torch.Tensor): # fbank
speech, speech_lengths = data_in, data_lengths
if len(speech.shape) < 3:
speech = speech[None, :, :]
if speech_lengths is None:
speech_lengths = speech.shape[1]
else:
# extract fbank feats
time1 = time.perf_counter()
audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000))
time2 = time.perf_counter()
meta_data["load_data"] = f"{time2 - time1:0.3f}"
speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
frontend=frontend)
time3 = time.perf_counter()
meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
meta_data[
"batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
speech = speech.to(device=kwargs["device"])
speech_lengths = speech_lengths.to(device=kwargs["device"])
# b. Forward Encoder streaming
t_offset = 0
feats = speech
feats_len = speech_lengths.max().item()
waveform = pad_sequence(audio_sample_list, batch_first=True).to(device=kwargs["device"]) # data: [batch, N]
cache = kwargs.get("cache", {})
batch_size = kwargs.get("batch_size", 1)
step = min(feats_len, 6000)
segments = [[]] * batch_size
for t_offset in range(0, feats_len, min(step, feats_len - t_offset)):
if t_offset + step >= feats_len - 1:
step = feats_len - t_offset
is_final = True
else:
is_final = False
batch = {
"feats": feats[:, t_offset:t_offset + step, :],
"waveform": waveform[:, t_offset * 160:min(waveform.shape[-1], (t_offset + step - 1) * 160 + 400)],
"is_final": is_final,
"cache": cache
}
batch = to_device(batch, device=kwargs["device"])
segments_part, cache = self.forward(**batch)
if segments_part:
for batch_num in range(0, batch_size):
segments[batch_num] += segments_part[batch_num]
ibest_writer = None
if ibest_writer is None and kwargs.get("output_dir") is not None:
writer = DatadirWriter(kwargs.get("output_dir"))
ibest_writer = writer[f"{1}best_recog"]
results = []
for i in range(batch_size):
if ibest_writer is not None:
ibest_writer["text"][key[i]] = segments[i]
result_i = {"key": key[i], "value": segments[i]}
results.append(result_i)
if "MODELSCOPE_ENVIRONMENT" in os.environ and os.environ["MODELSCOPE_ENVIRONMENT"] == "eas":
results[i] = json.dumps(results[i])
return results, meta_data
def DetectCommonFrames(self) -> int:
if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
return 0
for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
frame_state = FrameState.kFrameStateInvalid
frame_state = self.GetFrameState(self.frm_cnt - 1 - i - self.last_drop_frames)
self.DetectOneFrame(frame_state, self.frm_cnt - 1 - i, False)
return 0
def DetectLastFrames(self) -> int:
if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
return 0
for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
frame_state = FrameState.kFrameStateInvalid
frame_state = self.GetFrameState(self.frm_cnt - 1 - i - self.last_drop_frames)
if i != 0:
self.DetectOneFrame(frame_state, self.frm_cnt - 1 - i, False)
else:
self.DetectOneFrame(frame_state, self.frm_cnt - 1, True)
return 0
def DetectOneFrame(self, cur_frm_state: FrameState, cur_frm_idx: int, is_final_frame: bool) -> None:
tmp_cur_frm_state = FrameState.kFrameStateInvalid
if cur_frm_state == FrameState.kFrameStateSpeech:
if math.fabs(1.0) > self.vad_opts.fe_prior_thres:
tmp_cur_frm_state = FrameState.kFrameStateSpeech
else:
tmp_cur_frm_state = FrameState.kFrameStateSil
elif cur_frm_state == FrameState.kFrameStateSil:
tmp_cur_frm_state = FrameState.kFrameStateSil
state_change = self.windows_detector.DetectOneFrame(tmp_cur_frm_state, cur_frm_idx)
frm_shift_in_ms = self.vad_opts.frame_in_ms
if AudioChangeState.kChangeStateSil2Speech == state_change:
silence_frame_count = self.continous_silence_frame_count
self.continous_silence_frame_count = 0
self.pre_end_silence_detected = False
start_frame = 0
if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
start_frame = max(self.data_buf_start_frame, cur_frm_idx - self.LatencyFrmNumAtStartPoint())
self.OnVoiceStart(start_frame)
self.vad_state_machine = VadStateMachine.kVadInStateInSpeechSegment
for t in range(start_frame + 1, cur_frm_idx + 1):
self.OnVoiceDetected(t)
elif self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
for t in range(self.latest_confirmed_speech_frame + 1, cur_frm_idx):
self.OnVoiceDetected(t)
if cur_frm_idx - self.confirmed_start_frame + 1 > \
self.vad_opts.max_single_segment_time / frm_shift_in_ms:
self.OnVoiceEnd(cur_frm_idx, False, False)
self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
elif not is_final_frame:
self.OnVoiceDetected(cur_frm_idx)
else:
self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx)
else:
pass
elif AudioChangeState.kChangeStateSpeech2Sil == state_change:
self.continous_silence_frame_count = 0
if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
pass
elif self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
if cur_frm_idx - self.confirmed_start_frame + 1 > \
self.vad_opts.max_single_segment_time / frm_shift_in_ms:
self.OnVoiceEnd(cur_frm_idx, False, False)
self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
elif not is_final_frame:
self.OnVoiceDetected(cur_frm_idx)
else:
self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx)
else:
pass
elif AudioChangeState.kChangeStateSpeech2Speech == state_change:
self.continous_silence_frame_count = 0
if self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
if cur_frm_idx - self.confirmed_start_frame + 1 > \
self.vad_opts.max_single_segment_time / frm_shift_in_ms:
self.max_time_out = True
self.OnVoiceEnd(cur_frm_idx, False, False)
self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
elif not is_final_frame:
self.OnVoiceDetected(cur_frm_idx)
else:
self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx)
else:
pass
elif AudioChangeState.kChangeStateSil2Sil == state_change:
self.continous_silence_frame_count += 1
if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
# silence timeout, return zero length decision
if ((self.vad_opts.detect_mode == VadDetectMode.kVadSingleUtteranceDetectMode.value) and (
self.continous_silence_frame_count * frm_shift_in_ms > self.vad_opts.max_start_silence_time)) \
or (is_final_frame and self.number_end_time_detected == 0):
for t in range(self.lastest_confirmed_silence_frame + 1, cur_frm_idx):
self.OnSilenceDetected(t)
self.OnVoiceStart(0, True)
self.OnVoiceEnd(0, True, False);
self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
else:
if cur_frm_idx >= self.LatencyFrmNumAtStartPoint():
self.OnSilenceDetected(cur_frm_idx - self.LatencyFrmNumAtStartPoint())
elif self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
if self.continous_silence_frame_count * frm_shift_in_ms >= self.max_end_sil_frame_cnt_thresh:
lookback_frame = int(self.max_end_sil_frame_cnt_thresh / frm_shift_in_ms)
if self.vad_opts.do_extend:
lookback_frame -= int(self.vad_opts.lookahead_time_end_point / frm_shift_in_ms)
lookback_frame -= 1
lookback_frame = max(0, lookback_frame)
self.OnVoiceEnd(cur_frm_idx - lookback_frame, False, False)
self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
elif cur_frm_idx - self.confirmed_start_frame + 1 > \
self.vad_opts.max_single_segment_time / frm_shift_in_ms:
self.OnVoiceEnd(cur_frm_idx, False, False)
self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
elif self.vad_opts.do_extend and not is_final_frame:
if self.continous_silence_frame_count <= int(
self.vad_opts.lookahead_time_end_point / frm_shift_in_ms):
self.OnVoiceDetected(cur_frm_idx)
else:
self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx)
else:
pass
if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected and \
self.vad_opts.detect_mode == VadDetectMode.kVadMutipleUtteranceDetectMode.value:
self.ResetDetection()

View File

@ -1,62 +0,0 @@
# This is an example that demonstrates how to configure a model file.
# You can modify the configuration according to your own requirements.
# to print the register_table:
# from funasr.register import tables
# tables.print()
# network architecture
model: FsmnVAD
model_conf:
sample_rate: 16000
detect_mode: 1
snr_mode: 0
max_end_silence_time: 800
max_start_silence_time: 3000
do_start_point_detection: True
do_end_point_detection: True
window_size_ms: 200
sil_to_speech_time_thres: 150
speech_to_sil_time_thres: 150
speech_2_noise_ratio: 1.0
do_extend: 1
lookback_time_start_point: 200
lookahead_time_end_point: 100
max_single_segment_time: 60000
snr_thres: -100.0
noise_frame_num_used_for_snr: 100
decibel_thres: -100.0
speech_noise_thres: 0.6
fe_prior_thres: 0.0001
silence_pdf_num: 1
sil_pdf_ids: [0]
speech_noise_thresh_low: -0.1
speech_noise_thresh_high: 0.3
output_frame_probs: False
frame_in_ms: 10
frame_length_ms: 25
encoder: FSMN
encoder_conf:
input_dim: 400
input_affine_dim: 140
fsmn_layers: 4
linear_dim: 250
proj_dim: 128
lorder: 20
rorder: 0
lstride: 1
rstride: 0
output_affine_dim: 140
output_dim: 248
frontend: WavFrontend
frontend_conf:
fs: 16000
window: hamming
n_mels: 80
frame_length: 25
frame_shift: 10
dither: 0.0
lfr_m: 5
lfr_n: 1

View File

@ -11,7 +11,8 @@ import time
from funasr.register import tables
from funasr.utils.load_utils import load_audio_text_image_video,extract_fbank
from funasr.utils.datadir_writer import DatadirWriter
from torch.nn.utils.rnn import pad_sequence
from dataclasses import dataclass
class VadStateMachine(Enum):
kVadInStateStartPointNotDetected = 1
@ -39,7 +40,6 @@ class VadDetectMode(Enum):
kVadSingleUtteranceDetectMode = 0
kVadMutipleUtteranceDetectMode = 1
class VADXOptions:
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
@ -153,8 +153,10 @@ class WindowDetector(object):
Deep-FSMN for Large Vocabulary Continuous Speech Recognition
https://arxiv.org/abs/1803.05030
"""
def __init__(self, window_size_ms: int, sil_to_speech_time: int,
speech_to_sil_time: int, frame_size_ms: int):
def __init__(self, window_size_ms: int,
sil_to_speech_time: int,
speech_to_sil_time: int,
frame_size_ms: int):
self.window_size_ms = window_size_ms
self.sil_to_speech_time = sil_to_speech_time
self.speech_to_sil_time = speech_to_sil_time
@ -187,7 +189,7 @@ class WindowDetector(object):
def GetWinSize(self) -> int:
return int(self.win_size_frame)
def DetectOneFrame(self, frameState: FrameState, frame_count: int) -> AudioChangeState:
def DetectOneFrame(self, frameState: FrameState, frame_count: int, cache: dict={}) -> AudioChangeState:
cur_frame_state = FrameState.kFrameStateSil
if frameState == FrameState.kFrameStateSpeech:
cur_frame_state = 1
@ -218,6 +220,38 @@ class WindowDetector(object):
return int(self.frame_size_ms)
@dataclass
class StatsItem:
# init variables
data_buf_start_frame = 0
frm_cnt = 0
latest_confirmed_speech_frame = 0
lastest_confirmed_silence_frame = -1
continous_silence_frame_count = 0
vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected
confirmed_start_frame = -1
confirmed_end_frame = -1
number_end_time_detected = 0
sil_frame = 0
sil_pdf_ids: list
noise_average_decibel = -100.0
pre_end_silence_detected = False
next_seg = True # unused
output_data_buf = []
output_data_buf_offset = 0
frame_probs = [] # unused
max_end_sil_frame_cnt_thresh: int
speech_noise_thres: float
scores = None
max_time_out = False #unused
decibel = []
data_buf = None
data_buf_all = None
waveform = None
last_drop_frames = 0
@tables.register("model_classes", "FsmnVADStreaming")
class FsmnVADStreaming(nn.Module):
"""
@ -233,143 +267,82 @@ class FsmnVADStreaming(nn.Module):
):
super().__init__()
self.vad_opts = VADXOptions(**kwargs)
self.windows_detector = WindowDetector(self.vad_opts.window_size_ms,
self.vad_opts.sil_to_speech_time_thres,
self.vad_opts.speech_to_sil_time_thres,
self.vad_opts.frame_in_ms)
encoder_class = tables.encoder_classes.get(encoder.lower())
encoder = encoder_class(**encoder_conf)
self.encoder = encoder
# init variables
self.data_buf_start_frame = 0
self.frm_cnt = 0
self.latest_confirmed_speech_frame = 0
self.lastest_confirmed_silence_frame = -1
self.continous_silence_frame_count = 0
self.vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected
self.confirmed_start_frame = -1
self.confirmed_end_frame = -1
self.number_end_time_detected = 0
self.sil_frame = 0
self.sil_pdf_ids = self.vad_opts.sil_pdf_ids
self.noise_average_decibel = -100.0
self.pre_end_silence_detected = False
self.next_seg = True
self.output_data_buf = []
self.output_data_buf_offset = 0
self.frame_probs = []
self.max_end_sil_frame_cnt_thresh = self.vad_opts.max_end_silence_time - self.vad_opts.speech_to_sil_time_thres
self.speech_noise_thres = self.vad_opts.speech_noise_thres
self.scores = None
self.max_time_out = False
self.decibel = []
self.data_buf = None
self.data_buf_all = None
self.waveform = None
self.last_drop_frames = 0
def AllResetDetection(self):
self.data_buf_start_frame = 0
self.frm_cnt = 0
self.latest_confirmed_speech_frame = 0
self.lastest_confirmed_silence_frame = -1
self.continous_silence_frame_count = 0
self.vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected
self.confirmed_start_frame = -1
self.confirmed_end_frame = -1
self.number_end_time_detected = 0
self.sil_frame = 0
self.sil_pdf_ids = self.vad_opts.sil_pdf_ids
self.noise_average_decibel = -100.0
self.pre_end_silence_detected = False
self.next_seg = True
def ResetDetection(self, cache: dict = {}):
cache["stats"].continous_silence_frame_count = 0
cache["stats"].latest_confirmed_speech_frame = 0
cache["stats"].lastest_confirmed_silence_frame = -1
cache["stats"].confirmed_start_frame = -1
cache["stats"].confirmed_end_frame = -1
cache["stats"].vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected
cache["windows_detector"].Reset()
cache["stats"].sil_frame = 0
cache["stats"].frame_probs = []
self.output_data_buf = []
self.output_data_buf_offset = 0
self.frame_probs = []
self.max_end_sil_frame_cnt_thresh = self.vad_opts.max_end_silence_time - self.vad_opts.speech_to_sil_time_thres
self.speech_noise_thres = self.vad_opts.speech_noise_thres
self.scores = None
self.max_time_out = False
self.decibel = []
self.data_buf = None
self.data_buf_all = None
self.waveform = None
self.last_drop_frames = 0
self.windows_detector.Reset()
if cache["stats"].output_data_buf:
assert cache["stats"].output_data_buf[-1].contain_seg_end_point == True
drop_frames = int(cache["stats"].output_data_buf[-1].end_ms / self.vad_opts.frame_in_ms)
real_drop_frames = drop_frames - cache["stats"].last_drop_frames
cache["stats"].last_drop_frames = drop_frames
cache["stats"].data_buf_all = cache["stats"].data_buf_all[real_drop_frames * int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
cache["stats"].decibel = cache["stats"].decibel[real_drop_frames:]
cache["stats"].scores = cache["stats"].scores[:, real_drop_frames:, :]
def ResetDetection(self):
self.continous_silence_frame_count = 0
self.latest_confirmed_speech_frame = 0
self.lastest_confirmed_silence_frame = -1
self.confirmed_start_frame = -1
self.confirmed_end_frame = -1
self.vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected
self.windows_detector.Reset()
self.sil_frame = 0
self.frame_probs = []
if self.output_data_buf:
assert self.output_data_buf[-1].contain_seg_end_point == True
drop_frames = int(self.output_data_buf[-1].end_ms / self.vad_opts.frame_in_ms)
real_drop_frames = drop_frames - self.last_drop_frames
self.last_drop_frames = drop_frames
self.data_buf_all = self.data_buf_all[real_drop_frames * int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
self.decibel = self.decibel[real_drop_frames:]
self.scores = self.scores[:, real_drop_frames:, :]
def ComputeDecibel(self) -> None:
def ComputeDecibel(self, cache: dict = {}) -> None:
frame_sample_length = int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000)
frame_shift_length = int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)
if self.data_buf_all is None:
self.data_buf_all = self.waveform[0] # self.data_buf is pointed to self.waveform[0]
self.data_buf = self.data_buf_all
if cache["stats"].data_buf_all is None:
cache["stats"].data_buf_all = cache["stats"].waveform[0] # cache["stats"].data_buf is pointed to cache["stats"].waveform[0]
cache["stats"].data_buf = cache["stats"].data_buf_all
else:
self.data_buf_all = torch.cat((self.data_buf_all, self.waveform[0]))
for offset in range(0, self.waveform.shape[1] - frame_sample_length + 1, frame_shift_length):
self.decibel.append(
10 * math.log10((self.waveform[0][offset: offset + frame_sample_length]).square().sum() + \
cache["stats"].data_buf_all = torch.cat((cache["stats"].data_buf_all, cache["stats"].waveform[0]))
for offset in range(0, cache["stats"].waveform.shape[1] - frame_sample_length + 1, frame_shift_length):
cache["stats"].decibel.append(
10 * math.log10((cache["stats"].waveform[0][offset: offset + frame_sample_length]).square().sum() + \
0.000001))
def ComputeScores(self, feats: torch.Tensor, cache: Dict[str, torch.Tensor]) -> None:
scores = self.encoder(feats, cache).to('cpu') # return B * T * D
def ComputeScores(self, feats: torch.Tensor, cache: dict = {}) -> None:
scores = self.encoder(feats, cache=cache["encoder"]).to('cpu') # return B * T * D
assert scores.shape[1] == feats.shape[1], "The shape between feats and scores does not match"
self.vad_opts.nn_eval_block_size = scores.shape[1]
self.frm_cnt += scores.shape[1] # count total frames
if self.scores is None:
self.scores = scores # the first calculation
cache["stats"].frm_cnt += scores.shape[1] # count total frames
if cache["stats"].scores is None:
cache["stats"].scores = scores # the first calculation
else:
self.scores = torch.cat((self.scores, scores), dim=1)
cache["stats"].scores = torch.cat((cache["stats"].scores, scores), dim=1)
def PopDataBufTillFrame(self, frame_idx: int) -> None: # need check again
while self.data_buf_start_frame < frame_idx:
if len(self.data_buf) >= int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):
self.data_buf_start_frame += 1
self.data_buf = self.data_buf_all[(self.data_buf_start_frame - self.last_drop_frames) * int(
def PopDataBufTillFrame(self, frame_idx: int, cache: dict={}) -> None: # need check again
while cache["stats"].data_buf_start_frame < frame_idx:
if len(cache["stats"].data_buf) >= int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):
cache["stats"].data_buf_start_frame += 1
cache["stats"].data_buf = cache["stats"].data_buf_all[(cache["stats"].data_buf_start_frame - cache["stats"].last_drop_frames) * int(
self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
def PopDataToOutputBuf(self, start_frm: int, frm_cnt: int, first_frm_is_start_point: bool,
last_frm_is_end_point: bool, end_point_is_sent_end: bool) -> None:
self.PopDataBufTillFrame(start_frm)
last_frm_is_end_point: bool, end_point_is_sent_end: bool, cache: dict={}) -> None:
self.PopDataBufTillFrame(start_frm, cache=cache)
expected_sample_number = int(frm_cnt * self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000)
if last_frm_is_end_point:
extra_sample = max(0, int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000 - \
self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000))
expected_sample_number += int(extra_sample)
if end_point_is_sent_end:
expected_sample_number = max(expected_sample_number, len(self.data_buf))
if len(self.data_buf) < expected_sample_number:
expected_sample_number = max(expected_sample_number, len(cache["stats"].data_buf))
if len(cache["stats"].data_buf) < expected_sample_number:
print('error in calling pop data_buf\n')
if len(self.output_data_buf) == 0 or first_frm_is_start_point:
self.output_data_buf.append(E2EVadSpeechBufWithDoa())
self.output_data_buf[-1].Reset()
self.output_data_buf[-1].start_ms = start_frm * self.vad_opts.frame_in_ms
self.output_data_buf[-1].end_ms = self.output_data_buf[-1].start_ms
self.output_data_buf[-1].doa = 0
cur_seg = self.output_data_buf[-1]
if len(cache["stats"].output_data_buf) == 0 or first_frm_is_start_point:
cache["stats"].output_data_buf.append(E2EVadSpeechBufWithDoa())
cache["stats"].output_data_buf[-1].Reset()
cache["stats"].output_data_buf[-1].start_ms = start_frm * self.vad_opts.frame_in_ms
cache["stats"].output_data_buf[-1].end_ms = cache["stats"].output_data_buf[-1].start_ms
cache["stats"].output_data_buf[-1].doa = 0
cur_seg = cache["stats"].output_data_buf[-1]
if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms:
print('warning\n')
out_pos = len(cur_seg.buffer) # cur_seg.buff现在没做任何操作
@ -378,10 +351,10 @@ class FsmnVADStreaming(nn.Module):
data_to_pop = expected_sample_number
else:
data_to_pop = int(frm_cnt * self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)
if data_to_pop > len(self.data_buf):
print('VAD data_to_pop is bigger than self.data_buf.size()!!!\n')
data_to_pop = len(self.data_buf)
expected_sample_number = len(self.data_buf)
if data_to_pop > len(cache["stats"].data_buf):
print('VAD data_to_pop is bigger than cache["stats"].data_buf.size()!!!\n')
data_to_pop = len(cache["stats"].data_buf)
expected_sample_number = len(cache["stats"].data_buf)
cur_seg.doa = 0
for sample_cpy_out in range(0, data_to_pop):
@ -392,79 +365,79 @@ class FsmnVADStreaming(nn.Module):
out_pos += 1
if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms:
print('Something wrong with the VAD algorithm\n')
self.data_buf_start_frame += frm_cnt
cache["stats"].data_buf_start_frame += frm_cnt
cur_seg.end_ms = (start_frm + frm_cnt) * self.vad_opts.frame_in_ms
if first_frm_is_start_point:
cur_seg.contain_seg_start_point = True
if last_frm_is_end_point:
cur_seg.contain_seg_end_point = True
def OnSilenceDetected(self, valid_frame: int):
self.lastest_confirmed_silence_frame = valid_frame
if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
self.PopDataBufTillFrame(valid_frame)
def OnSilenceDetected(self, valid_frame: int, cache: dict = {}):
cache["stats"].lastest_confirmed_silence_frame = valid_frame
if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
self.PopDataBufTillFrame(valid_frame, cache=cache)
# silence_detected_callback_
# pass
def OnVoiceDetected(self, valid_frame: int) -> None:
self.latest_confirmed_speech_frame = valid_frame
self.PopDataToOutputBuf(valid_frame, 1, False, False, False)
def OnVoiceDetected(self, valid_frame: int, cache:dict={}) -> None:
cache["stats"].latest_confirmed_speech_frame = valid_frame
self.PopDataToOutputBuf(valid_frame, 1, False, False, False, cache=cache)
def OnVoiceStart(self, start_frame: int, fake_result: bool = False) -> None:
def OnVoiceStart(self, start_frame: int, fake_result: bool = False, cache:dict={}) -> None:
if self.vad_opts.do_start_point_detection:
pass
if self.confirmed_start_frame != -1:
if cache["stats"].confirmed_start_frame != -1:
print('not reset vad properly\n')
else:
self.confirmed_start_frame = start_frame
cache["stats"].confirmed_start_frame = start_frame
if not fake_result and self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
self.PopDataToOutputBuf(self.confirmed_start_frame, 1, True, False, False)
if not fake_result and cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
self.PopDataToOutputBuf(cache["stats"].confirmed_start_frame, 1, True, False, False, cache=cache)
def OnVoiceEnd(self, end_frame: int, fake_result: bool, is_last_frame: bool) -> None:
for t in range(self.latest_confirmed_speech_frame + 1, end_frame):
self.OnVoiceDetected(t)
def OnVoiceEnd(self, end_frame: int, fake_result: bool, is_last_frame: bool, cache:dict={}) -> None:
for t in range(cache["stats"].latest_confirmed_speech_frame + 1, end_frame):
self.OnVoiceDetected(t, cache=cache)
if self.vad_opts.do_end_point_detection:
pass
if self.confirmed_end_frame != -1:
if cache["stats"].confirmed_end_frame != -1:
print('not reset vad properly\n')
else:
self.confirmed_end_frame = end_frame
cache["stats"].confirmed_end_frame = end_frame
if not fake_result:
self.sil_frame = 0
self.PopDataToOutputBuf(self.confirmed_end_frame, 1, False, True, is_last_frame)
self.number_end_time_detected += 1
cache["stats"].sil_frame = 0
self.PopDataToOutputBuf(cache["stats"].confirmed_end_frame, 1, False, True, is_last_frame, cache=cache)
cache["stats"].number_end_time_detected += 1
def MaybeOnVoiceEndIfLastFrame(self, is_final_frame: bool, cur_frm_idx: int) -> None:
def MaybeOnVoiceEndIfLastFrame(self, is_final_frame: bool, cur_frm_idx: int, cache: dict = {}) -> None:
if is_final_frame:
self.OnVoiceEnd(cur_frm_idx, False, True)
self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
self.OnVoiceEnd(cur_frm_idx, False, True, cache=cache)
cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
def GetLatency(self) -> int:
return int(self.LatencyFrmNumAtStartPoint() * self.vad_opts.frame_in_ms)
def GetLatency(self, cache: dict = {}) -> int:
return int(self.LatencyFrmNumAtStartPoint(cache=cache) * self.vad_opts.frame_in_ms)
def LatencyFrmNumAtStartPoint(self) -> int:
vad_latency = self.windows_detector.GetWinSize()
def LatencyFrmNumAtStartPoint(self, cache: dict = {}) -> int:
vad_latency = cache["windows_detector"].GetWinSize()
if self.vad_opts.do_extend:
vad_latency += int(self.vad_opts.lookback_time_start_point / self.vad_opts.frame_in_ms)
return vad_latency
def GetFrameState(self, t: int):
def GetFrameState(self, t: int, cache: dict = {}):
frame_state = FrameState.kFrameStateInvalid
cur_decibel = self.decibel[t]
cur_snr = cur_decibel - self.noise_average_decibel
cur_decibel = cache["stats"].decibel[t]
cur_snr = cur_decibel - cache["stats"].noise_average_decibel
# for each frame, calc log posterior probability of each state
if cur_decibel < self.vad_opts.decibel_thres:
frame_state = FrameState.kFrameStateSil
self.DetectOneFrame(frame_state, t, False)
self.DetectOneFrame(frame_state, t, False, cache=cache)
return frame_state
sum_score = 0.0
noise_prob = 0.0
assert len(self.sil_pdf_ids) == self.vad_opts.silence_pdf_num
if len(self.sil_pdf_ids) > 0:
assert len(self.scores) == 1 # 只支持batch_size = 1的测试
sil_pdf_scores = [self.scores[0][t][sil_pdf_id] for sil_pdf_id in self.sil_pdf_ids]
assert len(cache["stats"].sil_pdf_ids) == self.vad_opts.silence_pdf_num
if len(cache["stats"].sil_pdf_ids) > 0:
assert len(cache["stats"].scores) == 1 # 只支持batch_size = 1的测试
sil_pdf_scores = [cache["stats"].scores[0][t][sil_pdf_id] for sil_pdf_id in cache["stats"].sil_pdf_ids]
sum_score = sum(sil_pdf_scores)
noise_prob = math.log(sum_score) * self.vad_opts.speech_2_noise_ratio
total_score = 1.0
@ -476,58 +449,69 @@ class FsmnVADStreaming(nn.Module):
frame_prob.speech_prob = speech_prob
frame_prob.score = sum_score
frame_prob.frame_id = t
self.frame_probs.append(frame_prob)
if math.exp(speech_prob) >= math.exp(noise_prob) + self.speech_noise_thres:
cache["stats"].frame_probs.append(frame_prob)
if math.exp(speech_prob) >= math.exp(noise_prob) + cache["stats"].speech_noise_thres:
if cur_snr >= self.vad_opts.snr_thres and cur_decibel >= self.vad_opts.decibel_thres:
frame_state = FrameState.kFrameStateSpeech
else:
frame_state = FrameState.kFrameStateSil
else:
frame_state = FrameState.kFrameStateSil
if self.noise_average_decibel < -99.9:
self.noise_average_decibel = cur_decibel
if cache["stats"].noise_average_decibel < -99.9:
cache["stats"].noise_average_decibel = cur_decibel
else:
self.noise_average_decibel = (cur_decibel + self.noise_average_decibel * (
cache["stats"].noise_average_decibel = (cur_decibel + cache["stats"].noise_average_decibel * (
self.vad_opts.noise_frame_num_used_for_snr
- 1)) / self.vad_opts.noise_frame_num_used_for_snr
return frame_state
def forward(self, feats: torch.Tensor, waveform: torch.tensor, cache: Dict[str, torch.Tensor] = dict(),
def forward(self, feats: torch.Tensor, waveform: torch.tensor, cache: dict = {},
is_final: bool = False
):
if len(cache) == 0:
self.AllResetDetection()
self.waveform = waveform # compute decibel for each frame
self.ComputeDecibel()
self.ComputeScores(feats, cache)
# if len(cache) == 0:
# self.AllResetDetection()
# self.waveform = waveform # compute decibel for each frame
cache["stats"].waveform = waveform
self.ComputeDecibel(cache=cache)
self.ComputeScores(feats, cache=cache)
if not is_final:
self.DetectCommonFrames()
self.DetectCommonFrames(cache=cache)
else:
self.DetectLastFrames()
self.DetectLastFrames(cache=cache)
segments = []
for batch_num in range(0, feats.shape[0]): # only support batch_size = 1 now
segment_batch = []
if len(self.output_data_buf) > 0:
for i in range(self.output_data_buf_offset, len(self.output_data_buf)):
if not is_final and (not self.output_data_buf[i].contain_seg_start_point or not self.output_data_buf[
if len(cache["stats"].output_data_buf) > 0:
for i in range(cache["stats"].output_data_buf_offset, len(cache["stats"].output_data_buf)):
if not is_final and (not cache["stats"].output_data_buf[i].contain_seg_start_point or not cache["stats"].output_data_buf[
i].contain_seg_end_point):
continue
segment = [self.output_data_buf[i].start_ms, self.output_data_buf[i].end_ms]
segment = [cache["stats"].output_data_buf[i].start_ms, cache["stats"].output_data_buf[i].end_ms]
segment_batch.append(segment)
self.output_data_buf_offset += 1 # need update this parameter
cache["stats"].output_data_buf_offset += 1 # need update this parameter
if segment_batch:
segments.append(segment_batch)
if is_final:
# reset class variables and clear the dict for the next query
self.AllResetDetection()
# if is_final:
# # reset class variables and clear the dict for the next query
# self.AllResetDetection()
return segments
def init_cache(self, cache: dict = {}, **kwargs):
cache["frontend"] = {}
cache["prev_samples"] = torch.empty(0)
cache["encoder"] = {}
windows_detector = WindowDetector(self.vad_opts.window_size_ms,
self.vad_opts.sil_to_speech_time_thres,
self.vad_opts.speech_to_sil_time_thres,
self.vad_opts.frame_in_ms)
stats = StatsItem(sil_pdf_ids=self.vad_opts.sil_pdf_ids,
max_end_sil_frame_cnt_thresh=self.vad_opts.max_end_silence_time - self.vad_opts.speech_to_sil_time_thres,
speech_noise_thres=self.vad_opts.speech_noise_thres,
)
cache["windows_detector"] = windows_detector
cache["stats"] = stats
return cache
def generate(self,
@ -544,7 +528,7 @@ class FsmnVADStreaming(nn.Module):
self.init_cache(cache, **kwargs)
meta_data = {}
chunk_size = kwargs.get("chunk_size", 50) # 50ms
chunk_size = kwargs.get("chunk_size", 60000) # 50ms
chunk_stride_samples = int(chunk_size * frontend.fs / 1000)
time1 = time.perf_counter()
@ -585,10 +569,11 @@ class FsmnVADStreaming(nn.Module):
"feats": speech,
"waveform": cache["frontend"]["waveforms"],
"is_final": kwargs["is_final"],
"cache": cache["encoder"]
"cache": cache
}
segments_i = self.forward(**batch)
segments.extend(segments_i)
if len(segments_i) > 0:
segments.extend(*segments_i)
cache["prev_samples"] = audio_sample[:-m]
@ -614,30 +599,30 @@ class FsmnVADStreaming(nn.Module):
return results, meta_data
def DetectCommonFrames(self) -> int:
if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
def DetectCommonFrames(self, cache: dict = {}) -> int:
if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
return 0
for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
frame_state = FrameState.kFrameStateInvalid
frame_state = self.GetFrameState(self.frm_cnt - 1 - i - self.last_drop_frames)
self.DetectOneFrame(frame_state, self.frm_cnt - 1 - i, False)
frame_state = self.GetFrameState(cache["stats"].frm_cnt - 1 - i - cache["stats"].last_drop_frames, cache=cache)
self.DetectOneFrame(frame_state, cache["stats"].frm_cnt - 1 - i, False, cache=cache)
return 0
def DetectLastFrames(self) -> int:
if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
def DetectLastFrames(self, cache: dict = {}) -> int:
if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
return 0
for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
frame_state = FrameState.kFrameStateInvalid
frame_state = self.GetFrameState(self.frm_cnt - 1 - i - self.last_drop_frames)
frame_state = self.GetFrameState(cache["stats"].frm_cnt - 1 - i - cache["stats"].last_drop_frames, cache=cache)
if i != 0:
self.DetectOneFrame(frame_state, self.frm_cnt - 1 - i, False)
self.DetectOneFrame(frame_state, cache["stats"].frm_cnt - 1 - i, False, cache=cache)
else:
self.DetectOneFrame(frame_state, self.frm_cnt - 1, True)
self.DetectOneFrame(frame_state, cache["stats"].frm_cnt - 1, True, cache=cache)
return 0
def DetectOneFrame(self, cur_frm_state: FrameState, cur_frm_idx: int, is_final_frame: bool) -> None:
def DetectOneFrame(self, cur_frm_state: FrameState, cur_frm_idx: int, is_final_frame: bool, cache: dict = {}) -> None:
tmp_cur_frm_state = FrameState.kFrameStateInvalid
if cur_frm_state == FrameState.kFrameStateSpeech:
if math.fabs(1.0) > self.vad_opts.fe_prior_thres:
@ -646,101 +631,101 @@ class FsmnVADStreaming(nn.Module):
tmp_cur_frm_state = FrameState.kFrameStateSil
elif cur_frm_state == FrameState.kFrameStateSil:
tmp_cur_frm_state = FrameState.kFrameStateSil
state_change = self.windows_detector.DetectOneFrame(tmp_cur_frm_state, cur_frm_idx)
state_change = cache["windows_detector"].DetectOneFrame(tmp_cur_frm_state, cur_frm_idx, cache=cache)
frm_shift_in_ms = self.vad_opts.frame_in_ms
if AudioChangeState.kChangeStateSil2Speech == state_change:
silence_frame_count = self.continous_silence_frame_count
self.continous_silence_frame_count = 0
self.pre_end_silence_detected = False
silence_frame_count = cache["stats"].continous_silence_frame_count
cache["stats"].continous_silence_frame_count = 0
cache["stats"].pre_end_silence_detected = False
start_frame = 0
if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
start_frame = max(self.data_buf_start_frame, cur_frm_idx - self.LatencyFrmNumAtStartPoint())
self.OnVoiceStart(start_frame)
self.vad_state_machine = VadStateMachine.kVadInStateInSpeechSegment
if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
start_frame = max(cache["stats"].data_buf_start_frame, cur_frm_idx - self.LatencyFrmNumAtStartPoint(cache=cache))
self.OnVoiceStart(start_frame, cache=cache)
cache["stats"].vad_state_machine = VadStateMachine.kVadInStateInSpeechSegment
for t in range(start_frame + 1, cur_frm_idx + 1):
self.OnVoiceDetected(t)
elif self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
for t in range(self.latest_confirmed_speech_frame + 1, cur_frm_idx):
self.OnVoiceDetected(t)
if cur_frm_idx - self.confirmed_start_frame + 1 > \
self.OnVoiceDetected(t, cache=cache)
elif cache["stats"].vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
for t in range(cache["stats"].latest_confirmed_speech_frame + 1, cur_frm_idx):
self.OnVoiceDetected(t, cache=cache)
if cur_frm_idx - cache["stats"].confirmed_start_frame + 1 > \
self.vad_opts.max_single_segment_time / frm_shift_in_ms:
self.OnVoiceEnd(cur_frm_idx, False, False)
self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
self.OnVoiceEnd(cur_frm_idx, False, False, cache=cache)
cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
elif not is_final_frame:
self.OnVoiceDetected(cur_frm_idx)
self.OnVoiceDetected(cur_frm_idx, cache=cache)
else:
self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx)
self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx, cache=cache)
else:
pass
elif AudioChangeState.kChangeStateSpeech2Sil == state_change:
self.continous_silence_frame_count = 0
if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
cache["stats"].continous_silence_frame_count = 0
if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
pass
elif self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
if cur_frm_idx - self.confirmed_start_frame + 1 > \
elif cache["stats"].vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
if cur_frm_idx - cache["stats"].confirmed_start_frame + 1 > \
self.vad_opts.max_single_segment_time / frm_shift_in_ms:
self.OnVoiceEnd(cur_frm_idx, False, False)
self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
self.OnVoiceEnd(cur_frm_idx, False, False, cache=cache)
cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
elif not is_final_frame:
self.OnVoiceDetected(cur_frm_idx)
self.OnVoiceDetected(cur_frm_idx, cache=cache)
else:
self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx)
self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx, cache=cache)
else:
pass
elif AudioChangeState.kChangeStateSpeech2Speech == state_change:
self.continous_silence_frame_count = 0
if self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
if cur_frm_idx - self.confirmed_start_frame + 1 > \
cache["stats"].continous_silence_frame_count = 0
if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
if cur_frm_idx - cache["stats"].confirmed_start_frame + 1 > \
self.vad_opts.max_single_segment_time / frm_shift_in_ms:
self.max_time_out = True
self.OnVoiceEnd(cur_frm_idx, False, False)
self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
cache["stats"].max_time_out = True
self.OnVoiceEnd(cur_frm_idx, False, False, cache=cache)
cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
elif not is_final_frame:
self.OnVoiceDetected(cur_frm_idx)
self.OnVoiceDetected(cur_frm_idx, cache=cache)
else:
self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx)
self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx, cache=cache)
else:
pass
elif AudioChangeState.kChangeStateSil2Sil == state_change:
self.continous_silence_frame_count += 1
if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
cache["stats"].continous_silence_frame_count += 1
if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
# silence timeout, return zero length decision
if ((self.vad_opts.detect_mode == VadDetectMode.kVadSingleUtteranceDetectMode.value) and (
self.continous_silence_frame_count * frm_shift_in_ms > self.vad_opts.max_start_silence_time)) \
or (is_final_frame and self.number_end_time_detected == 0):
for t in range(self.lastest_confirmed_silence_frame + 1, cur_frm_idx):
self.OnSilenceDetected(t)
self.OnVoiceStart(0, True)
self.OnVoiceEnd(0, True, False);
self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
cache["stats"].continous_silence_frame_count * frm_shift_in_ms > self.vad_opts.max_start_silence_time)) \
or (is_final_frame and cache["stats"].number_end_time_detected == 0):
for t in range(cache["stats"].lastest_confirmed_silence_frame + 1, cur_frm_idx):
self.OnSilenceDetected(t, cache=cache)
self.OnVoiceStart(0, True, cache=cache)
self.OnVoiceEnd(0, True, False, cache=cache)
cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
else:
if cur_frm_idx >= self.LatencyFrmNumAtStartPoint():
self.OnSilenceDetected(cur_frm_idx - self.LatencyFrmNumAtStartPoint())
elif self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
if self.continous_silence_frame_count * frm_shift_in_ms >= self.max_end_sil_frame_cnt_thresh:
lookback_frame = int(self.max_end_sil_frame_cnt_thresh / frm_shift_in_ms)
if cur_frm_idx >= self.LatencyFrmNumAtStartPoint(cache=cache):
self.OnSilenceDetected(cur_frm_idx - self.LatencyFrmNumAtStartPoint(cache=cache), cache=cache)
elif cache["stats"].vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
if cache["stats"].continous_silence_frame_count * frm_shift_in_ms >= cache["stats"].max_end_sil_frame_cnt_thresh:
lookback_frame = int(cache["stats"].max_end_sil_frame_cnt_thresh / frm_shift_in_ms)
if self.vad_opts.do_extend:
lookback_frame -= int(self.vad_opts.lookahead_time_end_point / frm_shift_in_ms)
lookback_frame -= 1
lookback_frame = max(0, lookback_frame)
self.OnVoiceEnd(cur_frm_idx - lookback_frame, False, False)
self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
elif cur_frm_idx - self.confirmed_start_frame + 1 > \
self.OnVoiceEnd(cur_frm_idx - lookback_frame, False, False, cache=cache)
cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
elif cur_frm_idx - cache["stats"].confirmed_start_frame + 1 > \
self.vad_opts.max_single_segment_time / frm_shift_in_ms:
self.OnVoiceEnd(cur_frm_idx, False, False)
self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
self.OnVoiceEnd(cur_frm_idx, False, False, cache=cache)
cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
elif self.vad_opts.do_extend and not is_final_frame:
if self.continous_silence_frame_count <= int(
if cache["stats"].continous_silence_frame_count <= int(
self.vad_opts.lookahead_time_end_point / frm_shift_in_ms):
self.OnVoiceDetected(cur_frm_idx)
self.OnVoiceDetected(cur_frm_idx, cache=cache)
else:
self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx)
self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx, cache=cache)
else:
pass
if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected and \
if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateEndPointDetected and \
self.vad_opts.detect_mode == VadDetectMode.kVadMutipleUtteranceDetectMode.value:
self.ResetDetection()
self.ResetDetection(cache=cache)