From c0b186b5b6e950472920964932ba3de546e06dbf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?= Date: Fri, 12 Jan 2024 22:48:30 +0800 Subject: [PATCH] funasr1.0 streaming --- .../bicif_paraformer/demo.py | 4 +- .../fsmn_vad/demo.py | 11 - .../fsmn_vad/infer.sh | 11 - .../fsmn_vad_streaming/demo.py | 11 +- .../fsmn_vad_streaming/infer.sh | 2 +- .../paraformer-zh-spk/demo.py | 2 +- .../paraformer-zh-spk/infer.sh | 2 +- .../seaco_paraformer/demo.py | 2 +- .../seaco_paraformer/infer.sh | 2 +- funasr/frontends/wav_frontend.py | 3 +- funasr/models/fsmn_vad/__init__.py | 0 funasr/models/fsmn_vad/encoder.py | 303 ------- funasr/models/fsmn_vad/model.py | 740 ------------------ funasr/models/fsmn_vad/template.yaml | 62 -- funasr/models/fsmn_vad_streaming/model.py | 479 ++++++------ 15 files changed, 245 insertions(+), 1389 deletions(-) delete mode 100644 examples/industrial_data_pretraining/fsmn_vad/demo.py delete mode 100644 examples/industrial_data_pretraining/fsmn_vad/infer.sh delete mode 100644 funasr/models/fsmn_vad/__init__.py delete mode 100755 funasr/models/fsmn_vad/encoder.py delete mode 100644 funasr/models/fsmn_vad/model.py delete mode 100644 funasr/models/fsmn_vad/template.yaml diff --git a/examples/industrial_data_pretraining/bicif_paraformer/demo.py b/examples/industrial_data_pretraining/bicif_paraformer/demo.py index 16eed3702..84b0e80b8 100644 --- a/examples/industrial_data_pretraining/bicif_paraformer/demo.py +++ b/examples/industrial_data_pretraining/bicif_paraformer/demo.py @@ -8,7 +8,7 @@ from funasr import AutoModel model = AutoModel(model="damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch", model_revision="v2.0.0", vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch", - vad_model_revision="v2.0.0", + vad_model_revision="v2.0.1", punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch", punc_model_revision="v2.0.0", spk_model="/Users/shixian/code/modelscope_models/speech_campplus_sv_zh-cn_16k-common", @@ -21,7 +21,7 @@ print(res) model = AutoModel(model="damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch", model_revision="v2.0.0", vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch", - vad_model_revision="v2.0.0", + vad_model_revision="v2.0.1", punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch", punc_model_revision="v2.0.0", spk_model="/Users/shixian/code/modelscope_models/speech_campplus_sv_zh-cn_16k-common", diff --git a/examples/industrial_data_pretraining/fsmn_vad/demo.py b/examples/industrial_data_pretraining/fsmn_vad/demo.py deleted file mode 100644 index 2a157ee23..000000000 --- a/examples/industrial_data_pretraining/fsmn_vad/demo.py +++ /dev/null @@ -1,11 +0,0 @@ -#!/usr/bin/env python3 -# -*- encoding: utf-8 -*- -# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. 
-# MIT License (https://opensource.org/licenses/MIT) - -from funasr import AutoModel - -model = AutoModel(model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch", model_revision="v2.0.0") - -res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav") -print(res) \ No newline at end of file diff --git a/examples/industrial_data_pretraining/fsmn_vad/infer.sh b/examples/industrial_data_pretraining/fsmn_vad/infer.sh deleted file mode 100644 index dedd14abb..000000000 --- a/examples/industrial_data_pretraining/fsmn_vad/infer.sh +++ /dev/null @@ -1,11 +0,0 @@ - - -model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch" -model_revision="v2.0.0" - -python funasr/bin/inference.py \ -+model=${model} \ -+model_revision=${model_revision} \ -+input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav" \ -+output_dir="./outputs/debug" \ -+device="cpu" \ diff --git a/examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py b/examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py index 6831cbae7..4e3cb7092 100644 --- a/examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py +++ b/examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py @@ -7,11 +7,9 @@ from funasr import AutoModel wav_file = "https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav" chunk_size = 60000 # ms -model = AutoModel(model="/Users/zhifu/Downloads/modelscope_models/speech_fsmn_vad_zh-cn-16k-common-streaming", model_revision="v2.0.0") +model = AutoModel(model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch", model_revision="v2.0.1") -res = model(input=wav_file, - chunk_size=chunk_size, - ) +res = model(input=wav_file, chunk_size=chunk_size, ) print(res) @@ -22,7 +20,7 @@ import os wav_file = os.path.join(model.model_path, "example/vad_example.wav") speech, sample_rate = soundfile.read(wav_file) -chunk_stride = int(chunk_size * 16000 / 1000) +chunk_stride = int(chunk_size * sample_rate / 1000) cache = {} @@ -35,4 +33,5 @@ for i in range(total_chunk_num): is_final=is_final, chunk_size=chunk_size, ) - print(res) + if len(res[0]["value"]): + print(res) diff --git a/examples/industrial_data_pretraining/fsmn_vad_streaming/infer.sh b/examples/industrial_data_pretraining/fsmn_vad_streaming/infer.sh index dedd14abb..08ef8bd7d 100644 --- a/examples/industrial_data_pretraining/fsmn_vad_streaming/infer.sh +++ b/examples/industrial_data_pretraining/fsmn_vad_streaming/infer.sh @@ -1,7 +1,7 @@ model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch" -model_revision="v2.0.0" +model_revision="v2.0.1" python funasr/bin/inference.py \ +model=${model} \ diff --git a/examples/industrial_data_pretraining/paraformer-zh-spk/demo.py b/examples/industrial_data_pretraining/paraformer-zh-spk/demo.py index 123ec414b..774d757e1 100644 --- a/examples/industrial_data_pretraining/paraformer-zh-spk/demo.py +++ b/examples/industrial_data_pretraining/paraformer-zh-spk/demo.py @@ -8,7 +8,7 @@ from funasr import AutoModel model = AutoModel(model="damo/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", model_revision="v2.0.0", vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch", - vad_model_revision="v2.0.0", + vad_model_revision="v2.0.1", punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch", punc_model_revision="v2.0.0", spk_model="damo/speech_campplus_sv_zh-cn_16k-common", diff --git a/examples/industrial_data_pretraining/paraformer-zh-spk/infer.sh 
b/examples/industrial_data_pretraining/paraformer-zh-spk/infer.sh index c2325a336..a45740194 100644 --- a/examples/industrial_data_pretraining/paraformer-zh-spk/infer.sh +++ b/examples/industrial_data_pretraining/paraformer-zh-spk/infer.sh @@ -2,7 +2,7 @@ model="damo/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" model_revision="v2.0.0" vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch" -vad_model_revision="v2.0.0" +vad_model_revision="v2.0.1" punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch" punc_model_revision="v2.0.0" spk_model="damo/speech_campplus_sv_zh-cn_16k-common" diff --git a/examples/industrial_data_pretraining/seaco_paraformer/demo.py b/examples/industrial_data_pretraining/seaco_paraformer/demo.py index 84be0d8ff..63f155eb2 100644 --- a/examples/industrial_data_pretraining/seaco_paraformer/demo.py +++ b/examples/industrial_data_pretraining/seaco_paraformer/demo.py @@ -8,7 +8,7 @@ from funasr import AutoModel model = AutoModel(model="damo/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", model_revision="v2.0.0", vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch", - vad_model_revision="v2.0.0", + vad_model_revision="v2.0.1", punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch", punc_model_revision="v2.0.0", ) diff --git a/examples/industrial_data_pretraining/seaco_paraformer/infer.sh b/examples/industrial_data_pretraining/seaco_paraformer/infer.sh index e92d59873..26eeee1d3 100644 --- a/examples/industrial_data_pretraining/seaco_paraformer/infer.sh +++ b/examples/industrial_data_pretraining/seaco_paraformer/infer.sh @@ -2,7 +2,7 @@ model="damo/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" model_revision="v2.0.0" vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch" -vad_model_revision="v2.0.0" +vad_model_revision="v2.0.1" punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch" punc_model_revision="v2.0.0" diff --git a/funasr/frontends/wav_frontend.py b/funasr/frontends/wav_frontend.py index f4100859e..9c896f118 100644 --- a/funasr/frontends/wav_frontend.py +++ b/funasr/frontends/wav_frontend.py @@ -402,8 +402,7 @@ class WavFrontendOnline(nn.Module): self, input: torch.Tensor, input_lengths: torch.Tensor, cache: dict = {}, **kwargs ): is_final = kwargs.get("is_final", False) - reset = kwargs.get("reset", False) - if len(cache) == 0 or reset: + if len(cache) == 0: self.init_cache(cache) batch_size = input.shape[0] diff --git a/funasr/models/fsmn_vad/__init__.py b/funasr/models/fsmn_vad/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/funasr/models/fsmn_vad/encoder.py b/funasr/models/fsmn_vad/encoder.py deleted file mode 100755 index a0a379da5..000000000 --- a/funasr/models/fsmn_vad/encoder.py +++ /dev/null @@ -1,303 +0,0 @@ -from typing import Tuple, Dict -import copy - -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F - -from funasr.register import tables - -class LinearTransform(nn.Module): - - def __init__(self, input_dim, output_dim): - super(LinearTransform, self).__init__() - self.input_dim = input_dim - self.output_dim = output_dim - self.linear = nn.Linear(input_dim, output_dim, bias=False) - - def forward(self, input): - output = self.linear(input) - - return output - - -class AffineTransform(nn.Module): - - def __init__(self, input_dim, output_dim): - super(AffineTransform, self).__init__() - self.input_dim = input_dim - self.output_dim = output_dim - 
self.linear = nn.Linear(input_dim, output_dim) - - def forward(self, input): - output = self.linear(input) - - return output - - -class RectifiedLinear(nn.Module): - - def __init__(self, input_dim, output_dim): - super(RectifiedLinear, self).__init__() - self.dim = input_dim - self.relu = nn.ReLU() - self.dropout = nn.Dropout(0.1) - - def forward(self, input): - out = self.relu(input) - return out - - -class FSMNBlock(nn.Module): - - def __init__( - self, - input_dim: int, - output_dim: int, - lorder=None, - rorder=None, - lstride=1, - rstride=1, - ): - super(FSMNBlock, self).__init__() - - self.dim = input_dim - - if lorder is None: - return - - self.lorder = lorder - self.rorder = rorder - self.lstride = lstride - self.rstride = rstride - - self.conv_left = nn.Conv2d( - self.dim, self.dim, [lorder, 1], dilation=[lstride, 1], groups=self.dim, bias=False) - - if self.rorder > 0: - self.conv_right = nn.Conv2d( - self.dim, self.dim, [rorder, 1], dilation=[rstride, 1], groups=self.dim, bias=False) - else: - self.conv_right = None - - def forward(self, input: torch.Tensor, cache: torch.Tensor): - x = torch.unsqueeze(input, 1) - x_per = x.permute(0, 3, 2, 1) # B D T C - - cache = cache.to(x_per.device) - y_left = torch.cat((cache, x_per), dim=2) - cache = y_left[:, :, -(self.lorder - 1) * self.lstride:, :] - y_left = self.conv_left(y_left) - out = x_per + y_left - - if self.conv_right is not None: - # maybe need to check - y_right = F.pad(x_per, [0, 0, 0, self.rorder * self.rstride]) - y_right = y_right[:, :, self.rstride:, :] - y_right = self.conv_right(y_right) - out += y_right - - out_per = out.permute(0, 3, 2, 1) - output = out_per.squeeze(1) - - return output, cache - - -class BasicBlock(nn.Sequential): - def __init__(self, - linear_dim: int, - proj_dim: int, - lorder: int, - rorder: int, - lstride: int, - rstride: int, - stack_layer: int - ): - super(BasicBlock, self).__init__() - self.lorder = lorder - self.rorder = rorder - self.lstride = lstride - self.rstride = rstride - self.stack_layer = stack_layer - self.linear = LinearTransform(linear_dim, proj_dim) - self.fsmn_block = FSMNBlock(proj_dim, proj_dim, lorder, rorder, lstride, rstride) - self.affine = AffineTransform(proj_dim, linear_dim) - self.relu = RectifiedLinear(linear_dim, linear_dim) - - def forward(self, input: torch.Tensor, cache: Dict[str, torch.Tensor]): - x1 = self.linear(input) # B T D - cache_layer_name = 'cache_layer_{}'.format(self.stack_layer) - if cache_layer_name not in cache: - cache[cache_layer_name] = torch.zeros(x1.shape[0], x1.shape[-1], (self.lorder - 1) * self.lstride, 1) - x2, cache[cache_layer_name] = self.fsmn_block(x1, cache[cache_layer_name]) - x3 = self.affine(x2) - x4 = self.relu(x3) - return x4 - - -class FsmnStack(nn.Sequential): - def __init__(self, *args): - super(FsmnStack, self).__init__(*args) - - def forward(self, input: torch.Tensor, cache: Dict[str, torch.Tensor]): - x = input - for module in self._modules.values(): - x = module(x, cache) - return x - - -''' -FSMN net for keyword spotting -input_dim: input dimension -linear_dim: fsmn input dimensionll -proj_dim: fsmn projection dimension -lorder: fsmn left order -rorder: fsmn right order -num_syn: output dimension -fsmn_layers: no. 
of sequential fsmn layers -''' - -@tables.register("encoder_classes", "FSMN") -class FSMN(nn.Module): - def __init__( - self, - input_dim: int, - input_affine_dim: int, - fsmn_layers: int, - linear_dim: int, - proj_dim: int, - lorder: int, - rorder: int, - lstride: int, - rstride: int, - output_affine_dim: int, - output_dim: int - ): - super(FSMN, self).__init__() - - self.input_dim = input_dim - self.input_affine_dim = input_affine_dim - self.fsmn_layers = fsmn_layers - self.linear_dim = linear_dim - self.proj_dim = proj_dim - self.output_affine_dim = output_affine_dim - self.output_dim = output_dim - - self.in_linear1 = AffineTransform(input_dim, input_affine_dim) - self.in_linear2 = AffineTransform(input_affine_dim, linear_dim) - self.relu = RectifiedLinear(linear_dim, linear_dim) - self.fsmn = FsmnStack(*[BasicBlock(linear_dim, proj_dim, lorder, rorder, lstride, rstride, i) for i in - range(fsmn_layers)]) - self.out_linear1 = AffineTransform(linear_dim, output_affine_dim) - self.out_linear2 = AffineTransform(output_affine_dim, output_dim) - self.softmax = nn.Softmax(dim=-1) - - def fuse_modules(self): - pass - - def forward( - self, - input: torch.Tensor, - cache: Dict[str, torch.Tensor] - ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: - """ - Args: - input (torch.Tensor): Input tensor (B, T, D) - cache: when cache is not None, the forward is in streaming. The type of cache is a dict, egs, - {'cache_layer_1': torch.Tensor(B, T1, D)}, T1 is equal to self.lorder. It is {} for the 1st frame - """ - - x1 = self.in_linear1(input) - x2 = self.in_linear2(x1) - x3 = self.relu(x2) - x4 = self.fsmn(x3, cache) # self.cache will update automatically in self.fsmn - x5 = self.out_linear1(x4) - x6 = self.out_linear2(x5) - x7 = self.softmax(x6) - - return x7 - - -''' -one deep fsmn layer -dimproj: projection dimension, input and output dimension of memory blocks -dimlinear: dimension of mapping layer -lorder: left order -rorder: right order -lstride: left stride -rstride: right stride -''' - -@tables.register("encoder_classes", "DFSMN") -class DFSMN(nn.Module): - - def __init__(self, dimproj=64, dimlinear=128, lorder=20, rorder=1, lstride=1, rstride=1): - super(DFSMN, self).__init__() - - self.lorder = lorder - self.rorder = rorder - self.lstride = lstride - self.rstride = rstride - - self.expand = AffineTransform(dimproj, dimlinear) - self.shrink = LinearTransform(dimlinear, dimproj) - - self.conv_left = nn.Conv2d( - dimproj, dimproj, [lorder, 1], dilation=[lstride, 1], groups=dimproj, bias=False) - - if rorder > 0: - self.conv_right = nn.Conv2d( - dimproj, dimproj, [rorder, 1], dilation=[rstride, 1], groups=dimproj, bias=False) - else: - self.conv_right = None - - def forward(self, input): - f1 = F.relu(self.expand(input)) - p1 = self.shrink(f1) - - x = torch.unsqueeze(p1, 1) - x_per = x.permute(0, 3, 2, 1) - - y_left = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride, 0]) - - if self.conv_right is not None: - y_right = F.pad(x_per, [0, 0, 0, (self.rorder) * self.rstride]) - y_right = y_right[:, :, self.rstride:, :] - out = x_per + self.conv_left(y_left) + self.conv_right(y_right) - else: - out = x_per + self.conv_left(y_left) - - out1 = out.permute(0, 3, 2, 1) - output = input + out1.squeeze(1) - - return output - - -''' -build stacked dfsmn layers -''' - - -def buildDFSMNRepeats(linear_dim=128, proj_dim=64, lorder=20, rorder=1, fsmn_layers=6): - repeats = [ - nn.Sequential( - DFSMN(proj_dim, linear_dim, lorder, rorder, 1, 1)) - for i in range(fsmn_layers) - ] - - return 
nn.Sequential(*repeats) - - -if __name__ == '__main__': - fsmn = FSMN(400, 140, 4, 250, 128, 10, 2, 1, 1, 140, 2599) - print(fsmn) - - num_params = sum(p.numel() for p in fsmn.parameters()) - print('the number of model params: {}'.format(num_params)) - x = torch.zeros(128, 200, 400) # batch-size * time * dim - y, _ = fsmn(x) # batch-size * time * dim - print('input shape: {}'.format(x.shape)) - print('output shape: {}'.format(y.shape)) - - print(fsmn.to_kaldi_net()) diff --git a/funasr/models/fsmn_vad/model.py b/funasr/models/fsmn_vad/model.py deleted file mode 100644 index b31e0612e..000000000 --- a/funasr/models/fsmn_vad/model.py +++ /dev/null @@ -1,740 +0,0 @@ -from enum import Enum -from typing import List, Tuple, Dict, Any -import logging -import os -import json -import torch -from torch import nn -import math -from typing import Optional -import time -from funasr.register import tables -from funasr.utils.load_utils import load_audio_text_image_video,extract_fbank -from funasr.utils.datadir_writer import DatadirWriter -from torch.nn.utils.rnn import pad_sequence -from funasr.train_utils.device_funcs import to_device - -class VadStateMachine(Enum): - kVadInStateStartPointNotDetected = 1 - kVadInStateInSpeechSegment = 2 - kVadInStateEndPointDetected = 3 - - -class FrameState(Enum): - kFrameStateInvalid = -1 - kFrameStateSpeech = 1 - kFrameStateSil = 0 - - -# final voice/unvoice state per frame -class AudioChangeState(Enum): - kChangeStateSpeech2Speech = 0 - kChangeStateSpeech2Sil = 1 - kChangeStateSil2Sil = 2 - kChangeStateSil2Speech = 3 - kChangeStateNoBegin = 4 - kChangeStateInvalid = 5 - - -class VadDetectMode(Enum): - kVadSingleUtteranceDetectMode = 0 - kVadMutipleUtteranceDetectMode = 1 - - -class VADXOptions: - """ - Author: Speech Lab of DAMO Academy, Alibaba Group - Deep-FSMN for Large Vocabulary Continuous Speech Recognition - https://arxiv.org/abs/1803.05030 - """ - def __init__( - self, - sample_rate: int = 16000, - detect_mode: int = VadDetectMode.kVadMutipleUtteranceDetectMode.value, - snr_mode: int = 0, - max_end_silence_time: int = 800, - max_start_silence_time: int = 3000, - do_start_point_detection: bool = True, - do_end_point_detection: bool = True, - window_size_ms: int = 200, - sil_to_speech_time_thres: int = 150, - speech_to_sil_time_thres: int = 150, - speech_2_noise_ratio: float = 1.0, - do_extend: int = 1, - lookback_time_start_point: int = 200, - lookahead_time_end_point: int = 100, - max_single_segment_time: int = 60000, - nn_eval_block_size: int = 8, - dcd_block_size: int = 4, - snr_thres: int = -100.0, - noise_frame_num_used_for_snr: int = 100, - decibel_thres: int = -100.0, - speech_noise_thres: float = 0.6, - fe_prior_thres: float = 1e-4, - silence_pdf_num: int = 1, - sil_pdf_ids: List[int] = [0], - speech_noise_thresh_low: float = -0.1, - speech_noise_thresh_high: float = 0.3, - output_frame_probs: bool = False, - frame_in_ms: int = 10, - frame_length_ms: int = 25, - **kwargs, - ): - self.sample_rate = sample_rate - self.detect_mode = detect_mode - self.snr_mode = snr_mode - self.max_end_silence_time = max_end_silence_time - self.max_start_silence_time = max_start_silence_time - self.do_start_point_detection = do_start_point_detection - self.do_end_point_detection = do_end_point_detection - self.window_size_ms = window_size_ms - self.sil_to_speech_time_thres = sil_to_speech_time_thres - self.speech_to_sil_time_thres = speech_to_sil_time_thres - self.speech_2_noise_ratio = speech_2_noise_ratio - self.do_extend = do_extend - self.lookback_time_start_point = 
lookback_time_start_point - self.lookahead_time_end_point = lookahead_time_end_point - self.max_single_segment_time = max_single_segment_time - self.nn_eval_block_size = nn_eval_block_size - self.dcd_block_size = dcd_block_size - self.snr_thres = snr_thres - self.noise_frame_num_used_for_snr = noise_frame_num_used_for_snr - self.decibel_thres = decibel_thres - self.speech_noise_thres = speech_noise_thres - self.fe_prior_thres = fe_prior_thres - self.silence_pdf_num = silence_pdf_num - self.sil_pdf_ids = sil_pdf_ids - self.speech_noise_thresh_low = speech_noise_thresh_low - self.speech_noise_thresh_high = speech_noise_thresh_high - self.output_frame_probs = output_frame_probs - self.frame_in_ms = frame_in_ms - self.frame_length_ms = frame_length_ms - - -class E2EVadSpeechBufWithDoa(object): - """ - Author: Speech Lab of DAMO Academy, Alibaba Group - Deep-FSMN for Large Vocabulary Continuous Speech Recognition - https://arxiv.org/abs/1803.05030 - """ - def __init__(self): - self.start_ms = 0 - self.end_ms = 0 - self.buffer = [] - self.contain_seg_start_point = False - self.contain_seg_end_point = False - self.doa = 0 - - def Reset(self): - self.start_ms = 0 - self.end_ms = 0 - self.buffer = [] - self.contain_seg_start_point = False - self.contain_seg_end_point = False - self.doa = 0 - - -class E2EVadFrameProb(object): - """ - Author: Speech Lab of DAMO Academy, Alibaba Group - Deep-FSMN for Large Vocabulary Continuous Speech Recognition - https://arxiv.org/abs/1803.05030 - """ - def __init__(self): - self.noise_prob = 0.0 - self.speech_prob = 0.0 - self.score = 0.0 - self.frame_id = 0 - self.frm_state = 0 - - -class WindowDetector(object): - """ - Author: Speech Lab of DAMO Academy, Alibaba Group - Deep-FSMN for Large Vocabulary Continuous Speech Recognition - https://arxiv.org/abs/1803.05030 - """ - def __init__(self, window_size_ms: int, sil_to_speech_time: int, - speech_to_sil_time: int, frame_size_ms: int): - self.window_size_ms = window_size_ms - self.sil_to_speech_time = sil_to_speech_time - self.speech_to_sil_time = speech_to_sil_time - self.frame_size_ms = frame_size_ms - - self.win_size_frame = int(window_size_ms / frame_size_ms) - self.win_sum = 0 - self.win_state = [0] * self.win_size_frame # 初始化窗 - - self.cur_win_pos = 0 - self.pre_frame_state = FrameState.kFrameStateSil - self.cur_frame_state = FrameState.kFrameStateSil - self.sil_to_speech_frmcnt_thres = int(sil_to_speech_time / frame_size_ms) - self.speech_to_sil_frmcnt_thres = int(speech_to_sil_time / frame_size_ms) - - self.voice_last_frame_count = 0 - self.noise_last_frame_count = 0 - self.hydre_frame_count = 0 - - def Reset(self) -> None: - self.cur_win_pos = 0 - self.win_sum = 0 - self.win_state = [0] * self.win_size_frame - self.pre_frame_state = FrameState.kFrameStateSil - self.cur_frame_state = FrameState.kFrameStateSil - self.voice_last_frame_count = 0 - self.noise_last_frame_count = 0 - self.hydre_frame_count = 0 - - def GetWinSize(self) -> int: - return int(self.win_size_frame) - - def DetectOneFrame(self, frameState: FrameState, frame_count: int) -> AudioChangeState: - cur_frame_state = FrameState.kFrameStateSil - if frameState == FrameState.kFrameStateSpeech: - cur_frame_state = 1 - elif frameState == FrameState.kFrameStateSil: - cur_frame_state = 0 - else: - return AudioChangeState.kChangeStateInvalid - self.win_sum -= self.win_state[self.cur_win_pos] - self.win_sum += cur_frame_state - self.win_state[self.cur_win_pos] = cur_frame_state - self.cur_win_pos = (self.cur_win_pos + 1) % self.win_size_frame - - if 
self.pre_frame_state == FrameState.kFrameStateSil and self.win_sum >= self.sil_to_speech_frmcnt_thres: - self.pre_frame_state = FrameState.kFrameStateSpeech - return AudioChangeState.kChangeStateSil2Speech - - if self.pre_frame_state == FrameState.kFrameStateSpeech and self.win_sum <= self.speech_to_sil_frmcnt_thres: - self.pre_frame_state = FrameState.kFrameStateSil - return AudioChangeState.kChangeStateSpeech2Sil - - if self.pre_frame_state == FrameState.kFrameStateSil: - return AudioChangeState.kChangeStateSil2Sil - if self.pre_frame_state == FrameState.kFrameStateSpeech: - return AudioChangeState.kChangeStateSpeech2Speech - return AudioChangeState.kChangeStateInvalid - - def FrameSizeMs(self) -> int: - return int(self.frame_size_ms) - - -@tables.register("model_classes", "FsmnVAD") -class FsmnVAD(nn.Module): - """ - Author: Speech Lab of DAMO Academy, Alibaba Group - Deep-FSMN for Large Vocabulary Continuous Speech Recognition - https://arxiv.org/abs/1803.05030 - """ - def __init__(self, - encoder: str = None, - encoder_conf: Optional[Dict] = None, - vad_post_args: Dict[str, Any] = None, - **kwargs, - ): - super().__init__() - self.vad_opts = VADXOptions(**kwargs) - self.windows_detector = WindowDetector(self.vad_opts.window_size_ms, - self.vad_opts.sil_to_speech_time_thres, - self.vad_opts.speech_to_sil_time_thres, - self.vad_opts.frame_in_ms) - - encoder_class = tables.encoder_classes.get(encoder.lower()) - encoder = encoder_class(**encoder_conf) - self.encoder = encoder - # init variables - self.data_buf_start_frame = 0 - self.frm_cnt = 0 - self.latest_confirmed_speech_frame = 0 - self.lastest_confirmed_silence_frame = -1 - self.continous_silence_frame_count = 0 - self.vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected - self.confirmed_start_frame = -1 - self.confirmed_end_frame = -1 - self.number_end_time_detected = 0 - self.sil_frame = 0 - self.sil_pdf_ids = self.vad_opts.sil_pdf_ids - self.noise_average_decibel = -100.0 - self.pre_end_silence_detected = False - self.next_seg = True - - self.output_data_buf = [] - self.output_data_buf_offset = 0 - self.frame_probs = [] - self.max_end_sil_frame_cnt_thresh = self.vad_opts.max_end_silence_time - self.vad_opts.speech_to_sil_time_thres - self.speech_noise_thres = self.vad_opts.speech_noise_thres - self.scores = None - self.max_time_out = False - self.decibel = [] - self.data_buf = None - self.data_buf_all = None - self.waveform = None - self.last_drop_frames = 0 - - def AllResetDetection(self): - self.data_buf_start_frame = 0 - self.frm_cnt = 0 - self.latest_confirmed_speech_frame = 0 - self.lastest_confirmed_silence_frame = -1 - self.continous_silence_frame_count = 0 - self.vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected - self.confirmed_start_frame = -1 - self.confirmed_end_frame = -1 - self.number_end_time_detected = 0 - self.sil_frame = 0 - self.sil_pdf_ids = self.vad_opts.sil_pdf_ids - self.noise_average_decibel = -100.0 - self.pre_end_silence_detected = False - self.next_seg = True - - self.output_data_buf = [] - self.output_data_buf_offset = 0 - self.frame_probs = [] - self.max_end_sil_frame_cnt_thresh = self.vad_opts.max_end_silence_time - self.vad_opts.speech_to_sil_time_thres - self.speech_noise_thres = self.vad_opts.speech_noise_thres - self.scores = None - self.max_time_out = False - self.decibel = [] - self.data_buf = None - self.data_buf_all = None - self.waveform = None - self.last_drop_frames = 0 - self.windows_detector.Reset() - - def ResetDetection(self): - 
self.continous_silence_frame_count = 0 - self.latest_confirmed_speech_frame = 0 - self.lastest_confirmed_silence_frame = -1 - self.confirmed_start_frame = -1 - self.confirmed_end_frame = -1 - self.vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected - self.windows_detector.Reset() - self.sil_frame = 0 - self.frame_probs = [] - - if self.output_data_buf: - assert self.output_data_buf[-1].contain_seg_end_point == True - drop_frames = int(self.output_data_buf[-1].end_ms / self.vad_opts.frame_in_ms) - real_drop_frames = drop_frames - self.last_drop_frames - self.last_drop_frames = drop_frames - self.data_buf_all = self.data_buf_all[real_drop_frames * int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):] - self.decibel = self.decibel[real_drop_frames:] - self.scores = self.scores[:, real_drop_frames:, :] - - def ComputeDecibel(self) -> None: - frame_sample_length = int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000) - frame_shift_length = int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000) - if self.data_buf_all is None: - self.data_buf_all = self.waveform[0] # self.data_buf is pointed to self.waveform[0] - self.data_buf = self.data_buf_all - else: - self.data_buf_all = torch.cat((self.data_buf_all, self.waveform[0])) - for offset in range(0, self.waveform.shape[1] - frame_sample_length + 1, frame_shift_length): - self.decibel.append( - 10 * math.log10((self.waveform[0][offset: offset + frame_sample_length]).square().sum() + \ - 0.000001)) - - def ComputeScores(self, feats: torch.Tensor, cache: Dict[str, torch.Tensor]) -> None: - scores = self.encoder(feats, cache).to('cpu') # return B * T * D - assert scores.shape[1] == feats.shape[1], "The shape between feats and scores does not match" - self.vad_opts.nn_eval_block_size = scores.shape[1] - self.frm_cnt += scores.shape[1] # count total frames - if self.scores is None: - self.scores = scores # the first calculation - else: - self.scores = torch.cat((self.scores, scores), dim=1) - - def PopDataBufTillFrame(self, frame_idx: int) -> None: # need check again - while self.data_buf_start_frame < frame_idx: - if len(self.data_buf) >= int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000): - self.data_buf_start_frame += 1 - self.data_buf = self.data_buf_all[(self.data_buf_start_frame - self.last_drop_frames) * int( - self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):] - - def PopDataToOutputBuf(self, start_frm: int, frm_cnt: int, first_frm_is_start_point: bool, - last_frm_is_end_point: bool, end_point_is_sent_end: bool) -> None: - self.PopDataBufTillFrame(start_frm) - expected_sample_number = int(frm_cnt * self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000) - if last_frm_is_end_point: - extra_sample = max(0, int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000 - \ - self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000)) - expected_sample_number += int(extra_sample) - if end_point_is_sent_end: - expected_sample_number = max(expected_sample_number, len(self.data_buf)) - if len(self.data_buf) < expected_sample_number: - print('error in calling pop data_buf\n') - - if len(self.output_data_buf) == 0 or first_frm_is_start_point: - self.output_data_buf.append(E2EVadSpeechBufWithDoa()) - self.output_data_buf[-1].Reset() - self.output_data_buf[-1].start_ms = start_frm * self.vad_opts.frame_in_ms - self.output_data_buf[-1].end_ms = self.output_data_buf[-1].start_ms - self.output_data_buf[-1].doa = 0 - cur_seg = self.output_data_buf[-1] - if 
cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms: - print('warning\n') - out_pos = len(cur_seg.buffer) # cur_seg.buff现在没做任何操作 - data_to_pop = 0 - if end_point_is_sent_end: - data_to_pop = expected_sample_number - else: - data_to_pop = int(frm_cnt * self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000) - if data_to_pop > len(self.data_buf): - print('VAD data_to_pop is bigger than self.data_buf.size()!!!\n') - data_to_pop = len(self.data_buf) - expected_sample_number = len(self.data_buf) - - cur_seg.doa = 0 - for sample_cpy_out in range(0, data_to_pop): - # cur_seg.buffer[out_pos ++] = data_buf_.back(); - out_pos += 1 - for sample_cpy_out in range(data_to_pop, expected_sample_number): - # cur_seg.buffer[out_pos++] = data_buf_.back() - out_pos += 1 - if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms: - print('Something wrong with the VAD algorithm\n') - self.data_buf_start_frame += frm_cnt - cur_seg.end_ms = (start_frm + frm_cnt) * self.vad_opts.frame_in_ms - if first_frm_is_start_point: - cur_seg.contain_seg_start_point = True - if last_frm_is_end_point: - cur_seg.contain_seg_end_point = True - - def OnSilenceDetected(self, valid_frame: int): - self.lastest_confirmed_silence_frame = valid_frame - if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected: - self.PopDataBufTillFrame(valid_frame) - # silence_detected_callback_ - # pass - - def OnVoiceDetected(self, valid_frame: int) -> None: - self.latest_confirmed_speech_frame = valid_frame - self.PopDataToOutputBuf(valid_frame, 1, False, False, False) - - def OnVoiceStart(self, start_frame: int, fake_result: bool = False) -> None: - if self.vad_opts.do_start_point_detection: - pass - if self.confirmed_start_frame != -1: - print('not reset vad properly\n') - else: - self.confirmed_start_frame = start_frame - - if not fake_result and self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected: - self.PopDataToOutputBuf(self.confirmed_start_frame, 1, True, False, False) - - def OnVoiceEnd(self, end_frame: int, fake_result: bool, is_last_frame: bool) -> None: - for t in range(self.latest_confirmed_speech_frame + 1, end_frame): - self.OnVoiceDetected(t) - if self.vad_opts.do_end_point_detection: - pass - if self.confirmed_end_frame != -1: - print('not reset vad properly\n') - else: - self.confirmed_end_frame = end_frame - if not fake_result: - self.sil_frame = 0 - self.PopDataToOutputBuf(self.confirmed_end_frame, 1, False, True, is_last_frame) - self.number_end_time_detected += 1 - - def MaybeOnVoiceEndIfLastFrame(self, is_final_frame: bool, cur_frm_idx: int) -> None: - if is_final_frame: - self.OnVoiceEnd(cur_frm_idx, False, True) - self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected - - def GetLatency(self) -> int: - return int(self.LatencyFrmNumAtStartPoint() * self.vad_opts.frame_in_ms) - - def LatencyFrmNumAtStartPoint(self) -> int: - vad_latency = self.windows_detector.GetWinSize() - if self.vad_opts.do_extend: - vad_latency += int(self.vad_opts.lookback_time_start_point / self.vad_opts.frame_in_ms) - return vad_latency - - def GetFrameState(self, t: int): - frame_state = FrameState.kFrameStateInvalid - cur_decibel = self.decibel[t] - cur_snr = cur_decibel - self.noise_average_decibel - # for each frame, calc log posterior probability of each state - if cur_decibel < self.vad_opts.decibel_thres: - frame_state = FrameState.kFrameStateSil - self.DetectOneFrame(frame_state, t, False) - return frame_state - - sum_score = 0.0 - noise_prob = 0.0 - assert 
len(self.sil_pdf_ids) == self.vad_opts.silence_pdf_num - if len(self.sil_pdf_ids) > 0: - assert len(self.scores) == 1 # 只支持batch_size = 1的测试 - sil_pdf_scores = [self.scores[0][t][sil_pdf_id] for sil_pdf_id in self.sil_pdf_ids] - sum_score = sum(sil_pdf_scores) - noise_prob = math.log(sum_score) * self.vad_opts.speech_2_noise_ratio - total_score = 1.0 - sum_score = total_score - sum_score - speech_prob = math.log(sum_score) - if self.vad_opts.output_frame_probs: - frame_prob = E2EVadFrameProb() - frame_prob.noise_prob = noise_prob - frame_prob.speech_prob = speech_prob - frame_prob.score = sum_score - frame_prob.frame_id = t - self.frame_probs.append(frame_prob) - if math.exp(speech_prob) >= math.exp(noise_prob) + self.speech_noise_thres: - if cur_snr >= self.vad_opts.snr_thres and cur_decibel >= self.vad_opts.decibel_thres: - frame_state = FrameState.kFrameStateSpeech - else: - frame_state = FrameState.kFrameStateSil - else: - frame_state = FrameState.kFrameStateSil - if self.noise_average_decibel < -99.9: - self.noise_average_decibel = cur_decibel - else: - self.noise_average_decibel = (cur_decibel + self.noise_average_decibel * ( - self.vad_opts.noise_frame_num_used_for_snr - - 1)) / self.vad_opts.noise_frame_num_used_for_snr - - return frame_state - - def forward(self, feats: torch.Tensor, waveform: torch.tensor, cache: Dict[str, torch.Tensor] = dict(), - is_final: bool = False - ): - if not cache: - self.AllResetDetection() - self.waveform = waveform # compute decibel for each frame - self.ComputeDecibel() - self.ComputeScores(feats, cache) - if not is_final: - self.DetectCommonFrames() - else: - self.DetectLastFrames() - segments = [] - for batch_num in range(0, feats.shape[0]): # only support batch_size = 1 now - segment_batch = [] - if len(self.output_data_buf) > 0: - for i in range(self.output_data_buf_offset, len(self.output_data_buf)): - if not is_final and (not self.output_data_buf[i].contain_seg_start_point or not self.output_data_buf[ - i].contain_seg_end_point): - continue - segment = [self.output_data_buf[i].start_ms, self.output_data_buf[i].end_ms] - segment_batch.append(segment) - self.output_data_buf_offset += 1 # need update this parameter - if segment_batch: - segments.append(segment_batch) - if is_final: - # reset class variables and clear the dict for the next query - self.AllResetDetection() - return segments, cache - - def generate(self, - data_in, - data_lengths=None, - key: list = None, - tokenizer=None, - frontend=None, - **kwargs, - ): - - - meta_data = {} - audio_sample_list = [data_in] - if isinstance(data_in, torch.Tensor): # fbank - speech, speech_lengths = data_in, data_lengths - if len(speech.shape) < 3: - speech = speech[None, :, :] - if speech_lengths is None: - speech_lengths = speech.shape[1] - else: - # extract fbank feats - time1 = time.perf_counter() - audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000)) - time2 = time.perf_counter() - meta_data["load_data"] = f"{time2 - time1:0.3f}" - speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"), - frontend=frontend) - time3 = time.perf_counter() - meta_data["extract_feat"] = f"{time3 - time2:0.3f}" - meta_data[ - "batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000 - - speech = speech.to(device=kwargs["device"]) - speech_lengths = speech_lengths.to(device=kwargs["device"]) - - # b. 
Forward Encoder streaming - t_offset = 0 - feats = speech - feats_len = speech_lengths.max().item() - waveform = pad_sequence(audio_sample_list, batch_first=True).to(device=kwargs["device"]) # data: [batch, N] - cache = kwargs.get("cache", {}) - batch_size = kwargs.get("batch_size", 1) - step = min(feats_len, 6000) - segments = [[]] * batch_size - - for t_offset in range(0, feats_len, min(step, feats_len - t_offset)): - if t_offset + step >= feats_len - 1: - step = feats_len - t_offset - is_final = True - else: - is_final = False - batch = { - "feats": feats[:, t_offset:t_offset + step, :], - "waveform": waveform[:, t_offset * 160:min(waveform.shape[-1], (t_offset + step - 1) * 160 + 400)], - "is_final": is_final, - "cache": cache - } - - - batch = to_device(batch, device=kwargs["device"]) - segments_part, cache = self.forward(**batch) - if segments_part: - for batch_num in range(0, batch_size): - segments[batch_num] += segments_part[batch_num] - - ibest_writer = None - if ibest_writer is None and kwargs.get("output_dir") is not None: - writer = DatadirWriter(kwargs.get("output_dir")) - ibest_writer = writer[f"{1}best_recog"] - - results = [] - for i in range(batch_size): - - - if ibest_writer is not None: - ibest_writer["text"][key[i]] = segments[i] - - result_i = {"key": key[i], "value": segments[i]} - results.append(result_i) - - if "MODELSCOPE_ENVIRONMENT" in os.environ and os.environ["MODELSCOPE_ENVIRONMENT"] == "eas": - results[i] = json.dumps(results[i]) - - return results, meta_data - - def DetectCommonFrames(self) -> int: - if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected: - return 0 - for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1): - frame_state = FrameState.kFrameStateInvalid - frame_state = self.GetFrameState(self.frm_cnt - 1 - i - self.last_drop_frames) - self.DetectOneFrame(frame_state, self.frm_cnt - 1 - i, False) - - return 0 - - def DetectLastFrames(self) -> int: - if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected: - return 0 - for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1): - frame_state = FrameState.kFrameStateInvalid - frame_state = self.GetFrameState(self.frm_cnt - 1 - i - self.last_drop_frames) - if i != 0: - self.DetectOneFrame(frame_state, self.frm_cnt - 1 - i, False) - else: - self.DetectOneFrame(frame_state, self.frm_cnt - 1, True) - - return 0 - - def DetectOneFrame(self, cur_frm_state: FrameState, cur_frm_idx: int, is_final_frame: bool) -> None: - tmp_cur_frm_state = FrameState.kFrameStateInvalid - if cur_frm_state == FrameState.kFrameStateSpeech: - if math.fabs(1.0) > self.vad_opts.fe_prior_thres: - tmp_cur_frm_state = FrameState.kFrameStateSpeech - else: - tmp_cur_frm_state = FrameState.kFrameStateSil - elif cur_frm_state == FrameState.kFrameStateSil: - tmp_cur_frm_state = FrameState.kFrameStateSil - state_change = self.windows_detector.DetectOneFrame(tmp_cur_frm_state, cur_frm_idx) - frm_shift_in_ms = self.vad_opts.frame_in_ms - if AudioChangeState.kChangeStateSil2Speech == state_change: - silence_frame_count = self.continous_silence_frame_count - self.continous_silence_frame_count = 0 - self.pre_end_silence_detected = False - start_frame = 0 - if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected: - start_frame = max(self.data_buf_start_frame, cur_frm_idx - self.LatencyFrmNumAtStartPoint()) - self.OnVoiceStart(start_frame) - self.vad_state_machine = VadStateMachine.kVadInStateInSpeechSegment - for t in range(start_frame + 1, cur_frm_idx + 1): - 
self.OnVoiceDetected(t) - elif self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment: - for t in range(self.latest_confirmed_speech_frame + 1, cur_frm_idx): - self.OnVoiceDetected(t) - if cur_frm_idx - self.confirmed_start_frame + 1 > \ - self.vad_opts.max_single_segment_time / frm_shift_in_ms: - self.OnVoiceEnd(cur_frm_idx, False, False) - self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected - elif not is_final_frame: - self.OnVoiceDetected(cur_frm_idx) - else: - self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx) - else: - pass - elif AudioChangeState.kChangeStateSpeech2Sil == state_change: - self.continous_silence_frame_count = 0 - if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected: - pass - elif self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment: - if cur_frm_idx - self.confirmed_start_frame + 1 > \ - self.vad_opts.max_single_segment_time / frm_shift_in_ms: - self.OnVoiceEnd(cur_frm_idx, False, False) - self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected - elif not is_final_frame: - self.OnVoiceDetected(cur_frm_idx) - else: - self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx) - else: - pass - elif AudioChangeState.kChangeStateSpeech2Speech == state_change: - self.continous_silence_frame_count = 0 - if self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment: - if cur_frm_idx - self.confirmed_start_frame + 1 > \ - self.vad_opts.max_single_segment_time / frm_shift_in_ms: - self.max_time_out = True - self.OnVoiceEnd(cur_frm_idx, False, False) - self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected - elif not is_final_frame: - self.OnVoiceDetected(cur_frm_idx) - else: - self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx) - else: - pass - elif AudioChangeState.kChangeStateSil2Sil == state_change: - self.continous_silence_frame_count += 1 - if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected: - # silence timeout, return zero length decision - if ((self.vad_opts.detect_mode == VadDetectMode.kVadSingleUtteranceDetectMode.value) and ( - self.continous_silence_frame_count * frm_shift_in_ms > self.vad_opts.max_start_silence_time)) \ - or (is_final_frame and self.number_end_time_detected == 0): - for t in range(self.lastest_confirmed_silence_frame + 1, cur_frm_idx): - self.OnSilenceDetected(t) - self.OnVoiceStart(0, True) - self.OnVoiceEnd(0, True, False); - self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected - else: - if cur_frm_idx >= self.LatencyFrmNumAtStartPoint(): - self.OnSilenceDetected(cur_frm_idx - self.LatencyFrmNumAtStartPoint()) - elif self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment: - if self.continous_silence_frame_count * frm_shift_in_ms >= self.max_end_sil_frame_cnt_thresh: - lookback_frame = int(self.max_end_sil_frame_cnt_thresh / frm_shift_in_ms) - if self.vad_opts.do_extend: - lookback_frame -= int(self.vad_opts.lookahead_time_end_point / frm_shift_in_ms) - lookback_frame -= 1 - lookback_frame = max(0, lookback_frame) - self.OnVoiceEnd(cur_frm_idx - lookback_frame, False, False) - self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected - elif cur_frm_idx - self.confirmed_start_frame + 1 > \ - self.vad_opts.max_single_segment_time / frm_shift_in_ms: - self.OnVoiceEnd(cur_frm_idx, False, False) - self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected - elif self.vad_opts.do_extend and not is_final_frame: - if self.continous_silence_frame_count 
<= int( - self.vad_opts.lookahead_time_end_point / frm_shift_in_ms): - self.OnVoiceDetected(cur_frm_idx) - else: - self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx) - else: - pass - - if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected and \ - self.vad_opts.detect_mode == VadDetectMode.kVadMutipleUtteranceDetectMode.value: - self.ResetDetection() - - - diff --git a/funasr/models/fsmn_vad/template.yaml b/funasr/models/fsmn_vad/template.yaml deleted file mode 100644 index 90032eb83..000000000 --- a/funasr/models/fsmn_vad/template.yaml +++ /dev/null @@ -1,62 +0,0 @@ -# This is an example that demonstrates how to configure a model file. -# You can modify the configuration according to your own requirements. - -# to print the register_table: -# from funasr.register import tables -# tables.print() - -# network architecture -model: FsmnVAD -model_conf: - sample_rate: 16000 - detect_mode: 1 - snr_mode: 0 - max_end_silence_time: 800 - max_start_silence_time: 3000 - do_start_point_detection: True - do_end_point_detection: True - window_size_ms: 200 - sil_to_speech_time_thres: 150 - speech_to_sil_time_thres: 150 - speech_2_noise_ratio: 1.0 - do_extend: 1 - lookback_time_start_point: 200 - lookahead_time_end_point: 100 - max_single_segment_time: 60000 - snr_thres: -100.0 - noise_frame_num_used_for_snr: 100 - decibel_thres: -100.0 - speech_noise_thres: 0.6 - fe_prior_thres: 0.0001 - silence_pdf_num: 1 - sil_pdf_ids: [0] - speech_noise_thresh_low: -0.1 - speech_noise_thresh_high: 0.3 - output_frame_probs: False - frame_in_ms: 10 - frame_length_ms: 25 - -encoder: FSMN -encoder_conf: - input_dim: 400 - input_affine_dim: 140 - fsmn_layers: 4 - linear_dim: 250 - proj_dim: 128 - lorder: 20 - rorder: 0 - lstride: 1 - rstride: 0 - output_affine_dim: 140 - output_dim: 248 - -frontend: WavFrontend -frontend_conf: - fs: 16000 - window: hamming - n_mels: 80 - frame_length: 25 - frame_shift: 10 - dither: 0.0 - lfr_m: 5 - lfr_n: 1 diff --git a/funasr/models/fsmn_vad_streaming/model.py b/funasr/models/fsmn_vad_streaming/model.py index 9ceacf676..544fab88c 100644 --- a/funasr/models/fsmn_vad_streaming/model.py +++ b/funasr/models/fsmn_vad_streaming/model.py @@ -11,7 +11,8 @@ import time from funasr.register import tables from funasr.utils.load_utils import load_audio_text_image_video,extract_fbank from funasr.utils.datadir_writer import DatadirWriter -from torch.nn.utils.rnn import pad_sequence + +from dataclasses import dataclass class VadStateMachine(Enum): kVadInStateStartPointNotDetected = 1 @@ -39,7 +40,6 @@ class VadDetectMode(Enum): kVadSingleUtteranceDetectMode = 0 kVadMutipleUtteranceDetectMode = 1 - class VADXOptions: """ Author: Speech Lab of DAMO Academy, Alibaba Group @@ -153,8 +153,10 @@ class WindowDetector(object): Deep-FSMN for Large Vocabulary Continuous Speech Recognition https://arxiv.org/abs/1803.05030 """ - def __init__(self, window_size_ms: int, sil_to_speech_time: int, - speech_to_sil_time: int, frame_size_ms: int): + def __init__(self, window_size_ms: int, + sil_to_speech_time: int, + speech_to_sil_time: int, + frame_size_ms: int): self.window_size_ms = window_size_ms self.sil_to_speech_time = sil_to_speech_time self.speech_to_sil_time = speech_to_sil_time @@ -187,7 +189,7 @@ class WindowDetector(object): def GetWinSize(self) -> int: return int(self.win_size_frame) - def DetectOneFrame(self, frameState: FrameState, frame_count: int) -> AudioChangeState: + def DetectOneFrame(self, frameState: FrameState, frame_count: int, cache: dict={}) -> AudioChangeState: 
cur_frame_state = FrameState.kFrameStateSil if frameState == FrameState.kFrameStateSpeech: cur_frame_state = 1 @@ -218,6 +220,38 @@ class WindowDetector(object): return int(self.frame_size_ms) +@dataclass +class StatsItem: + + # init variables + data_buf_start_frame = 0 + frm_cnt = 0 + latest_confirmed_speech_frame = 0 + lastest_confirmed_silence_frame = -1 + continous_silence_frame_count = 0 + vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected + confirmed_start_frame = -1 + confirmed_end_frame = -1 + number_end_time_detected = 0 + sil_frame = 0 + sil_pdf_ids: list + noise_average_decibel = -100.0 + pre_end_silence_detected = False + next_seg = True # unused + + output_data_buf = [] + output_data_buf_offset = 0 + frame_probs = [] # unused + max_end_sil_frame_cnt_thresh: int + speech_noise_thres: float + scores = None + max_time_out = False #unused + decibel = [] + data_buf = None + data_buf_all = None + waveform = None + last_drop_frames = 0 + @tables.register("model_classes", "FsmnVADStreaming") class FsmnVADStreaming(nn.Module): """ @@ -233,143 +267,82 @@ class FsmnVADStreaming(nn.Module): ): super().__init__() self.vad_opts = VADXOptions(**kwargs) - self.windows_detector = WindowDetector(self.vad_opts.window_size_ms, - self.vad_opts.sil_to_speech_time_thres, - self.vad_opts.speech_to_sil_time_thres, - self.vad_opts.frame_in_ms) - + encoder_class = tables.encoder_classes.get(encoder.lower()) encoder = encoder_class(**encoder_conf) self.encoder = encoder - # init variables - self.data_buf_start_frame = 0 - self.frm_cnt = 0 - self.latest_confirmed_speech_frame = 0 - self.lastest_confirmed_silence_frame = -1 - self.continous_silence_frame_count = 0 - self.vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected - self.confirmed_start_frame = -1 - self.confirmed_end_frame = -1 - self.number_end_time_detected = 0 - self.sil_frame = 0 - self.sil_pdf_ids = self.vad_opts.sil_pdf_ids - self.noise_average_decibel = -100.0 - self.pre_end_silence_detected = False - self.next_seg = True - self.output_data_buf = [] - self.output_data_buf_offset = 0 - self.frame_probs = [] - self.max_end_sil_frame_cnt_thresh = self.vad_opts.max_end_silence_time - self.vad_opts.speech_to_sil_time_thres - self.speech_noise_thres = self.vad_opts.speech_noise_thres - self.scores = None - self.max_time_out = False - self.decibel = [] - self.data_buf = None - self.data_buf_all = None - self.waveform = None - self.last_drop_frames = 0 - def AllResetDetection(self): - self.data_buf_start_frame = 0 - self.frm_cnt = 0 - self.latest_confirmed_speech_frame = 0 - self.lastest_confirmed_silence_frame = -1 - self.continous_silence_frame_count = 0 - self.vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected - self.confirmed_start_frame = -1 - self.confirmed_end_frame = -1 - self.number_end_time_detected = 0 - self.sil_frame = 0 - self.sil_pdf_ids = self.vad_opts.sil_pdf_ids - self.noise_average_decibel = -100.0 - self.pre_end_silence_detected = False - self.next_seg = True + def ResetDetection(self, cache: dict = {}): + cache["stats"].continous_silence_frame_count = 0 + cache["stats"].latest_confirmed_speech_frame = 0 + cache["stats"].lastest_confirmed_silence_frame = -1 + cache["stats"].confirmed_start_frame = -1 + cache["stats"].confirmed_end_frame = -1 + cache["stats"].vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected + cache["windows_detector"].Reset() + cache["stats"].sil_frame = 0 + cache["stats"].frame_probs = [] - self.output_data_buf = [] - 
        self.output_data_buf_offset = 0
-        self.frame_probs = []
-        self.max_end_sil_frame_cnt_thresh = self.vad_opts.max_end_silence_time - self.vad_opts.speech_to_sil_time_thres
-        self.speech_noise_thres = self.vad_opts.speech_noise_thres
-        self.scores = None
-        self.max_time_out = False
-        self.decibel = []
-        self.data_buf = None
-        self.data_buf_all = None
-        self.waveform = None
-        self.last_drop_frames = 0
-        self.windows_detector.Reset()
+        if cache["stats"].output_data_buf:
+            assert cache["stats"].output_data_buf[-1].contain_seg_end_point == True
+            drop_frames = int(cache["stats"].output_data_buf[-1].end_ms / self.vad_opts.frame_in_ms)
+            real_drop_frames = drop_frames - cache["stats"].last_drop_frames
+            cache["stats"].last_drop_frames = drop_frames
+            cache["stats"].data_buf_all = cache["stats"].data_buf_all[real_drop_frames * int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
+            cache["stats"].decibel = cache["stats"].decibel[real_drop_frames:]
+            cache["stats"].scores = cache["stats"].scores[:, real_drop_frames:, :]

-    def ResetDetection(self):
-        self.continous_silence_frame_count = 0
-        self.latest_confirmed_speech_frame = 0
-        self.lastest_confirmed_silence_frame = -1
-        self.confirmed_start_frame = -1
-        self.confirmed_end_frame = -1
-        self.vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected
-        self.windows_detector.Reset()
-        self.sil_frame = 0
-        self.frame_probs = []
-
-        if self.output_data_buf:
-            assert self.output_data_buf[-1].contain_seg_end_point == True
-            drop_frames = int(self.output_data_buf[-1].end_ms / self.vad_opts.frame_in_ms)
-            real_drop_frames = drop_frames - self.last_drop_frames
-            self.last_drop_frames = drop_frames
-            self.data_buf_all = self.data_buf_all[real_drop_frames * int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
-            self.decibel = self.decibel[real_drop_frames:]
-            self.scores = self.scores[:, real_drop_frames:, :]
-
-    def ComputeDecibel(self) -> None:
+    def ComputeDecibel(self, cache: dict = {}) -> None:
         frame_sample_length = int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000)
         frame_shift_length = int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)
-        if self.data_buf_all is None:
-            self.data_buf_all = self.waveform[0]  # self.data_buf is pointed to self.waveform[0]
-            self.data_buf = self.data_buf_all
+        if cache["stats"].data_buf_all is None:
+            cache["stats"].data_buf_all = cache["stats"].waveform[0]  # cache["stats"].data_buf is pointed to cache["stats"].waveform[0]
+            cache["stats"].data_buf = cache["stats"].data_buf_all
         else:
-            self.data_buf_all = torch.cat((self.data_buf_all, self.waveform[0]))
-        for offset in range(0, self.waveform.shape[1] - frame_sample_length + 1, frame_shift_length):
-            self.decibel.append(
-                10 * math.log10((self.waveform[0][offset: offset + frame_sample_length]).square().sum() + \
+            cache["stats"].data_buf_all = torch.cat((cache["stats"].data_buf_all, cache["stats"].waveform[0]))
+        for offset in range(0, cache["stats"].waveform.shape[1] - frame_sample_length + 1, frame_shift_length):
+            cache["stats"].decibel.append(
+                10 * math.log10((cache["stats"].waveform[0][offset: offset + frame_sample_length]).square().sum() + \
                                 0.000001))

-    def ComputeScores(self, feats: torch.Tensor, cache: Dict[str, torch.Tensor]) -> None:
-        scores = self.encoder(feats, cache).to('cpu')  # return B * T * D
+    def ComputeScores(self, feats: torch.Tensor, cache: dict = {}) -> None:
+        scores = self.encoder(feats, cache=cache["encoder"]).to('cpu')  # return B * T * D
         assert scores.shape[1] == feats.shape[1], "The shape between feats and scores does not match"
         self.vad_opts.nn_eval_block_size = scores.shape[1]
-        self.frm_cnt += scores.shape[1]  # count total frames
-        if self.scores is None:
-            self.scores = scores  # the first calculation
+        cache["stats"].frm_cnt += scores.shape[1]  # count total frames
+        if cache["stats"].scores is None:
+            cache["stats"].scores = scores  # the first calculation
         else:
-            self.scores = torch.cat((self.scores, scores), dim=1)
+            cache["stats"].scores = torch.cat((cache["stats"].scores, scores), dim=1)

-    def PopDataBufTillFrame(self, frame_idx: int) -> None:  # need check again
-        while self.data_buf_start_frame < frame_idx:
-            if len(self.data_buf) >= int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):
-                self.data_buf_start_frame += 1
-                self.data_buf = self.data_buf_all[(self.data_buf_start_frame - self.last_drop_frames) * int(
+    def PopDataBufTillFrame(self, frame_idx: int, cache: dict={}) -> None:  # need check again
+        while cache["stats"].data_buf_start_frame < frame_idx:
+            if len(cache["stats"].data_buf) >= int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):
+                cache["stats"].data_buf_start_frame += 1
+                cache["stats"].data_buf = cache["stats"].data_buf_all[(cache["stats"].data_buf_start_frame - cache["stats"].last_drop_frames) * int(
                     self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]

     def PopDataToOutputBuf(self, start_frm: int, frm_cnt: int, first_frm_is_start_point: bool,
-                           last_frm_is_end_point: bool, end_point_is_sent_end: bool) -> None:
-        self.PopDataBufTillFrame(start_frm)
+                           last_frm_is_end_point: bool, end_point_is_sent_end: bool, cache: dict={}) -> None:
+        self.PopDataBufTillFrame(start_frm, cache=cache)
         expected_sample_number = int(frm_cnt * self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000)
         if last_frm_is_end_point:
             extra_sample = max(0, int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000 - \
                                       self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000))
             expected_sample_number += int(extra_sample)
         if end_point_is_sent_end:
-            expected_sample_number = max(expected_sample_number, len(self.data_buf))
-        if len(self.data_buf) < expected_sample_number:
+            expected_sample_number = max(expected_sample_number, len(cache["stats"].data_buf))
+        if len(cache["stats"].data_buf) < expected_sample_number:
             print('error in calling pop data_buf\n')

-        if len(self.output_data_buf) == 0 or first_frm_is_start_point:
-            self.output_data_buf.append(E2EVadSpeechBufWithDoa())
-            self.output_data_buf[-1].Reset()
-            self.output_data_buf[-1].start_ms = start_frm * self.vad_opts.frame_in_ms
-            self.output_data_buf[-1].end_ms = self.output_data_buf[-1].start_ms
-            self.output_data_buf[-1].doa = 0
-        cur_seg = self.output_data_buf[-1]
+        if len(cache["stats"].output_data_buf) == 0 or first_frm_is_start_point:
+            cache["stats"].output_data_buf.append(E2EVadSpeechBufWithDoa())
+            cache["stats"].output_data_buf[-1].Reset()
+            cache["stats"].output_data_buf[-1].start_ms = start_frm * self.vad_opts.frame_in_ms
+            cache["stats"].output_data_buf[-1].end_ms = cache["stats"].output_data_buf[-1].start_ms
+            cache["stats"].output_data_buf[-1].doa = 0
+        cur_seg = cache["stats"].output_data_buf[-1]
         if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms:
             print('warning\n')
         out_pos = len(cur_seg.buffer)  # nothing is done to cur_seg.buffer for now
@@ -378,10 +351,10 @@ class FsmnVADStreaming(nn.Module):
             data_to_pop = expected_sample_number
         else:
             data_to_pop = int(frm_cnt * self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)
-        if data_to_pop > len(self.data_buf):
-            print('VAD data_to_pop is bigger than self.data_buf.size()!!!\n')
-            data_to_pop = len(self.data_buf)
-            expected_sample_number = len(self.data_buf)
+        if data_to_pop > len(cache["stats"].data_buf):
+            print('VAD data_to_pop is bigger than cache["stats"].data_buf.size()!!!\n')
+            data_to_pop = len(cache["stats"].data_buf)
+            expected_sample_number = len(cache["stats"].data_buf)

         cur_seg.doa = 0
         for sample_cpy_out in range(0, data_to_pop):
@@ -392,79 +365,79 @@ class FsmnVADStreaming(nn.Module):
             out_pos += 1
         if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms:
             print('Something wrong with the VAD algorithm\n')
-        self.data_buf_start_frame += frm_cnt
+        cache["stats"].data_buf_start_frame += frm_cnt
         cur_seg.end_ms = (start_frm + frm_cnt) * self.vad_opts.frame_in_ms
         if first_frm_is_start_point:
             cur_seg.contain_seg_start_point = True
         if last_frm_is_end_point:
             cur_seg.contain_seg_end_point = True

-    def OnSilenceDetected(self, valid_frame: int):
-        self.lastest_confirmed_silence_frame = valid_frame
-        if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
-            self.PopDataBufTillFrame(valid_frame)
+    def OnSilenceDetected(self, valid_frame: int, cache: dict = {}):
+        cache["stats"].lastest_confirmed_silence_frame = valid_frame
+        if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
+            self.PopDataBufTillFrame(valid_frame, cache=cache)
         # silence_detected_callback_
         # pass

-    def OnVoiceDetected(self, valid_frame: int) -> None:
-        self.latest_confirmed_speech_frame = valid_frame
-        self.PopDataToOutputBuf(valid_frame, 1, False, False, False)
+    def OnVoiceDetected(self, valid_frame: int, cache:dict={}) -> None:
+        cache["stats"].latest_confirmed_speech_frame = valid_frame
+        self.PopDataToOutputBuf(valid_frame, 1, False, False, False, cache=cache)

-    def OnVoiceStart(self, start_frame: int, fake_result: bool = False) -> None:
+    def OnVoiceStart(self, start_frame: int, fake_result: bool = False, cache:dict={}) -> None:
         if self.vad_opts.do_start_point_detection:
             pass
-        if self.confirmed_start_frame != -1:
+        if cache["stats"].confirmed_start_frame != -1:
             print('not reset vad properly\n')
         else:
-            self.confirmed_start_frame = start_frame
+            cache["stats"].confirmed_start_frame = start_frame

-        if not fake_result and self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
-            self.PopDataToOutputBuf(self.confirmed_start_frame, 1, True, False, False)
+        if not fake_result and cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
+            self.PopDataToOutputBuf(cache["stats"].confirmed_start_frame, 1, True, False, False, cache=cache)

-    def OnVoiceEnd(self, end_frame: int, fake_result: bool, is_last_frame: bool) -> None:
-        for t in range(self.latest_confirmed_speech_frame + 1, end_frame):
-            self.OnVoiceDetected(t)
+    def OnVoiceEnd(self, end_frame: int, fake_result: bool, is_last_frame: bool, cache:dict={}) -> None:
+        for t in range(cache["stats"].latest_confirmed_speech_frame + 1, end_frame):
+            self.OnVoiceDetected(t, cache=cache)
         if self.vad_opts.do_end_point_detection:
             pass
-        if self.confirmed_end_frame != -1:
+        if cache["stats"].confirmed_end_frame != -1:
             print('not reset vad properly\n')
         else:
-            self.confirmed_end_frame = end_frame
+            cache["stats"].confirmed_end_frame = end_frame

         if not fake_result:
-            self.sil_frame = 0
-            self.PopDataToOutputBuf(self.confirmed_end_frame, 1, False, True, is_last_frame)
-            self.number_end_time_detected += 1
+            cache["stats"].sil_frame = 0
+            self.PopDataToOutputBuf(cache["stats"].confirmed_end_frame, 1, False, True, is_last_frame, cache=cache)
+            cache["stats"].number_end_time_detected += 1

-    def MaybeOnVoiceEndIfLastFrame(self, is_final_frame: bool, cur_frm_idx: int) -> None:
+    def MaybeOnVoiceEndIfLastFrame(self, is_final_frame: bool, cur_frm_idx: int, cache: dict = {}) -> None:
         if is_final_frame:
-            self.OnVoiceEnd(cur_frm_idx, False, True)
-            self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
+            self.OnVoiceEnd(cur_frm_idx, False, True, cache=cache)
+            cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected

-    def GetLatency(self) -> int:
-        return int(self.LatencyFrmNumAtStartPoint() * self.vad_opts.frame_in_ms)
+    def GetLatency(self, cache: dict = {}) -> int:
+        return int(self.LatencyFrmNumAtStartPoint(cache=cache) * self.vad_opts.frame_in_ms)

-    def LatencyFrmNumAtStartPoint(self) -> int:
-        vad_latency = self.windows_detector.GetWinSize()
+    def LatencyFrmNumAtStartPoint(self, cache: dict = {}) -> int:
+        vad_latency = cache["windows_detector"].GetWinSize()
         if self.vad_opts.do_extend:
             vad_latency += int(self.vad_opts.lookback_time_start_point / self.vad_opts.frame_in_ms)
         return vad_latency

-    def GetFrameState(self, t: int):
+    def GetFrameState(self, t: int, cache: dict = {}):
         frame_state = FrameState.kFrameStateInvalid
-        cur_decibel = self.decibel[t]
-        cur_snr = cur_decibel - self.noise_average_decibel
+        cur_decibel = cache["stats"].decibel[t]
+        cur_snr = cur_decibel - cache["stats"].noise_average_decibel
         # for each frame, calc log posterior probability of each state
         if cur_decibel < self.vad_opts.decibel_thres:
             frame_state = FrameState.kFrameStateSil
-            self.DetectOneFrame(frame_state, t, False)
+            self.DetectOneFrame(frame_state, t, False, cache=cache)
             return frame_state

         sum_score = 0.0
         noise_prob = 0.0
-        assert len(self.sil_pdf_ids) == self.vad_opts.silence_pdf_num
-        if len(self.sil_pdf_ids) > 0:
-            assert len(self.scores) == 1  # only batch_size = 1 is supported for testing
-            sil_pdf_scores = [self.scores[0][t][sil_pdf_id] for sil_pdf_id in self.sil_pdf_ids]
+        assert len(cache["stats"].sil_pdf_ids) == self.vad_opts.silence_pdf_num
+        if len(cache["stats"].sil_pdf_ids) > 0:
+            assert len(cache["stats"].scores) == 1  # only batch_size = 1 is supported for testing
+            sil_pdf_scores = [cache["stats"].scores[0][t][sil_pdf_id] for sil_pdf_id in cache["stats"].sil_pdf_ids]
             sum_score = sum(sil_pdf_scores)
             noise_prob = math.log(sum_score) * self.vad_opts.speech_2_noise_ratio
             total_score = 1.0
@@ -476,58 +449,69 @@ class FsmnVADStreaming(nn.Module):
         frame_prob.speech_prob = speech_prob
         frame_prob.score = sum_score
         frame_prob.frame_id = t
-        self.frame_probs.append(frame_prob)
-        if math.exp(speech_prob) >= math.exp(noise_prob) + self.speech_noise_thres:
+        cache["stats"].frame_probs.append(frame_prob)
+        if math.exp(speech_prob) >= math.exp(noise_prob) + cache["stats"].speech_noise_thres:
             if cur_snr >= self.vad_opts.snr_thres and cur_decibel >= self.vad_opts.decibel_thres:
                 frame_state = FrameState.kFrameStateSpeech
             else:
                 frame_state = FrameState.kFrameStateSil
         else:
             frame_state = FrameState.kFrameStateSil
-        if self.noise_average_decibel < -99.9:
-            self.noise_average_decibel = cur_decibel
+        if cache["stats"].noise_average_decibel < -99.9:
+            cache["stats"].noise_average_decibel = cur_decibel
         else:
-            self.noise_average_decibel = (cur_decibel + self.noise_average_decibel * (
+            cache["stats"].noise_average_decibel = (cur_decibel + cache["stats"].noise_average_decibel * (
                     self.vad_opts.noise_frame_num_used_for_snr - 1)) / self.vad_opts.noise_frame_num_used_for_snr

         return frame_state

-    def forward(self, feats: torch.Tensor, waveform: torch.tensor, cache: Dict[str, torch.Tensor] = dict(),
+    def forward(self, feats: torch.Tensor, waveform: torch.tensor, cache: dict = {},
                 is_final: bool = False
                 ):
-        if len(cache) == 0:
-            self.AllResetDetection()
-        self.waveform = waveform  # compute decibel for each frame
-        self.ComputeDecibel()
-        self.ComputeScores(feats, cache)
+        # if len(cache) == 0:
+        #     self.AllResetDetection()
+        # self.waveform = waveform  # compute decibel for each frame
+        cache["stats"].waveform = waveform
+        self.ComputeDecibel(cache=cache)
+        self.ComputeScores(feats, cache=cache)
         if not is_final:
-            self.DetectCommonFrames()
+            self.DetectCommonFrames(cache=cache)
         else:
-            self.DetectLastFrames()
+            self.DetectLastFrames(cache=cache)
         segments = []
         for batch_num in range(0, feats.shape[0]):  # only support batch_size = 1 now
             segment_batch = []
-            if len(self.output_data_buf) > 0:
-                for i in range(self.output_data_buf_offset, len(self.output_data_buf)):
-                    if not is_final and (not self.output_data_buf[i].contain_seg_start_point or not self.output_data_buf[
                         i].contain_seg_end_point):
                         continue
-                    segment = [self.output_data_buf[i].start_ms, self.output_data_buf[i].end_ms]
                     segment_batch.append(segment)
-                    self.output_data_buf_offset += 1  # need update this parameter
+            if len(cache["stats"].output_data_buf) > 0:
+                for i in range(cache["stats"].output_data_buf_offset, len(cache["stats"].output_data_buf)):
+                    if not is_final and (not cache["stats"].output_data_buf[i].contain_seg_start_point or not cache["stats"].output_data_buf[
                         i].contain_seg_end_point):
                         continue
+                    segment = [cache["stats"].output_data_buf[i].start_ms, cache["stats"].output_data_buf[i].end_ms]
                     segment_batch.append(segment)
+                    cache["stats"].output_data_buf_offset += 1  # need update this parameter
             if segment_batch:
                 segments.append(segment_batch)
-        if is_final:
-            # reset class variables and clear the dict for the next query
-            self.AllResetDetection()
+        # if is_final:
+        #     # reset class variables and clear the dict for the next query
+        #     self.AllResetDetection()
         return segments

     def init_cache(self, cache: dict = {}, **kwargs):
         cache["frontend"] = {}
         cache["prev_samples"] = torch.empty(0)
         cache["encoder"] = {}
-
+        windows_detector = WindowDetector(self.vad_opts.window_size_ms,
+                                          self.vad_opts.sil_to_speech_time_thres,
+                                          self.vad_opts.speech_to_sil_time_thres,
+                                          self.vad_opts.frame_in_ms)
+
+        stats = StatsItem(sil_pdf_ids=self.vad_opts.sil_pdf_ids,
+                          max_end_sil_frame_cnt_thresh=self.vad_opts.max_end_silence_time - self.vad_opts.speech_to_sil_time_thres,
+                          speech_noise_thres=self.vad_opts.speech_noise_thres,
+                          )
+        cache["windows_detector"] = windows_detector
+        cache["stats"] = stats
         return cache

     def generate(self,
@@ -544,7 +528,7 @@ class FsmnVADStreaming(nn.Module):
         self.init_cache(cache, **kwargs)

         meta_data = {}
-        chunk_size = kwargs.get("chunk_size", 50)  # 50ms
+        chunk_size = kwargs.get("chunk_size", 60000)  # ms
         chunk_stride_samples = int(chunk_size * frontend.fs / 1000)

         time1 = time.perf_counter()
@@ -585,10 +569,11 @@ class FsmnVADStreaming(nn.Module):
                 "feats": speech,
                 "waveform": cache["frontend"]["waveforms"],
                 "is_final": kwargs["is_final"],
-                "cache": cache["encoder"]
+                "cache": cache
             }
             segments_i = self.forward(**batch)
-            segments.extend(segments_i)
+            if len(segments_i) > 0:
+                segments.extend(*segments_i)

             cache["prev_samples"] = audio_sample[:-m]
@@ -614,30 +599,30 @@ class FsmnVADStreaming(nn.Module):
         return results, meta_data

-    def DetectCommonFrames(self) -> int:
-        if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
+    def DetectCommonFrames(self, cache: dict = {}) -> int:
+        if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
             return 0
         for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
             frame_state = FrameState.kFrameStateInvalid
-            frame_state = self.GetFrameState(self.frm_cnt - 1 - i - self.last_drop_frames)
-            self.DetectOneFrame(frame_state, self.frm_cnt - 1 - i, False)
+            frame_state = self.GetFrameState(cache["stats"].frm_cnt - 1 - i - cache["stats"].last_drop_frames, cache=cache)
+            self.DetectOneFrame(frame_state, cache["stats"].frm_cnt - 1 - i, False, cache=cache)

         return 0

-    def DetectLastFrames(self) -> int:
-        if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
+    def DetectLastFrames(self, cache: dict = {}) -> int:
+        if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
             return 0
         for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
             frame_state = FrameState.kFrameStateInvalid
-            frame_state = self.GetFrameState(self.frm_cnt - 1 - i - self.last_drop_frames)
+            frame_state = self.GetFrameState(cache["stats"].frm_cnt - 1 - i - cache["stats"].last_drop_frames, cache=cache)
             if i != 0:
-                self.DetectOneFrame(frame_state, self.frm_cnt - 1 - i, False)
+                self.DetectOneFrame(frame_state, cache["stats"].frm_cnt - 1 - i, False, cache=cache)
             else:
-                self.DetectOneFrame(frame_state, self.frm_cnt - 1, True)
+                self.DetectOneFrame(frame_state, cache["stats"].frm_cnt - 1, True, cache=cache)

         return 0

-    def DetectOneFrame(self, cur_frm_state: FrameState, cur_frm_idx: int, is_final_frame: bool) -> None:
+    def DetectOneFrame(self, cur_frm_state: FrameState, cur_frm_idx: int, is_final_frame: bool, cache: dict = {}) -> None:
         tmp_cur_frm_state = FrameState.kFrameStateInvalid
         if cur_frm_state == FrameState.kFrameStateSpeech:
             if math.fabs(1.0) > self.vad_opts.fe_prior_thres:
@@ -646,101 +631,101 @@ class FsmnVADStreaming(nn.Module):
                 tmp_cur_frm_state = FrameState.kFrameStateSil
         elif cur_frm_state == FrameState.kFrameStateSil:
             tmp_cur_frm_state = FrameState.kFrameStateSil
-        state_change = self.windows_detector.DetectOneFrame(tmp_cur_frm_state, cur_frm_idx)
+        state_change = cache["windows_detector"].DetectOneFrame(tmp_cur_frm_state, cur_frm_idx, cache=cache)
         frm_shift_in_ms = self.vad_opts.frame_in_ms

         if AudioChangeState.kChangeStateSil2Speech == state_change:
-            silence_frame_count = self.continous_silence_frame_count
-            self.continous_silence_frame_count = 0
-            self.pre_end_silence_detected = False
+            silence_frame_count = cache["stats"].continous_silence_frame_count
+            cache["stats"].continous_silence_frame_count = 0
+            cache["stats"].pre_end_silence_detected = False
             start_frame = 0
-            if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
-                start_frame = max(self.data_buf_start_frame, cur_frm_idx - self.LatencyFrmNumAtStartPoint())
-                self.OnVoiceStart(start_frame)
-                self.vad_state_machine = VadStateMachine.kVadInStateInSpeechSegment
+            if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
+                start_frame = max(cache["stats"].data_buf_start_frame, cur_frm_idx - self.LatencyFrmNumAtStartPoint(cache=cache))
+                self.OnVoiceStart(start_frame, cache=cache)
+                cache["stats"].vad_state_machine = VadStateMachine.kVadInStateInSpeechSegment
                 for t in range(start_frame + 1, cur_frm_idx + 1):
-                    self.OnVoiceDetected(t)
-            elif self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
-                for t in range(self.latest_confirmed_speech_frame + 1, cur_frm_idx):
-                    self.OnVoiceDetected(t)
-                if cur_frm_idx - self.confirmed_start_frame + 1 > \
+                    self.OnVoiceDetected(t, cache=cache)
+            elif cache["stats"].vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
+                for t in range(cache["stats"].latest_confirmed_speech_frame + 1, cur_frm_idx):
+                    self.OnVoiceDetected(t, cache=cache)
+                if cur_frm_idx - cache["stats"].confirmed_start_frame + 1 > \
                         self.vad_opts.max_single_segment_time / frm_shift_in_ms:
-                    self.OnVoiceEnd(cur_frm_idx, False, False)
-                    self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
+                    self.OnVoiceEnd(cur_frm_idx, False, False, cache=cache)
+                    cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
                 elif not is_final_frame:
-                    self.OnVoiceDetected(cur_frm_idx)
+                    self.OnVoiceDetected(cur_frm_idx, cache=cache)
                 else:
-                    self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx)
+                    self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx, cache=cache)
             else:
                 pass
         elif AudioChangeState.kChangeStateSpeech2Sil == state_change:
-            self.continous_silence_frame_count = 0
-            if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
+            cache["stats"].continous_silence_frame_count = 0
+            if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
                 pass
-            elif self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
-                if cur_frm_idx - self.confirmed_start_frame + 1 > \
+            elif cache["stats"].vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
+                if cur_frm_idx - cache["stats"].confirmed_start_frame + 1 > \
                        self.vad_opts.max_single_segment_time / frm_shift_in_ms:
-                    self.OnVoiceEnd(cur_frm_idx, False, False)
-                    self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
+                    self.OnVoiceEnd(cur_frm_idx, False, False, cache=cache)
+                    cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
                 elif not is_final_frame:
-                    self.OnVoiceDetected(cur_frm_idx)
+                    self.OnVoiceDetected(cur_frm_idx, cache=cache)
                 else:
-                    self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx)
+                    self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx, cache=cache)
             else:
                 pass
        elif AudioChangeState.kChangeStateSpeech2Speech == state_change:
-            self.continous_silence_frame_count = 0
-            if self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
-                if cur_frm_idx - self.confirmed_start_frame + 1 > \
+            cache["stats"].continous_silence_frame_count = 0
+            if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
+                if cur_frm_idx - cache["stats"].confirmed_start_frame + 1 > \
                        self.vad_opts.max_single_segment_time / frm_shift_in_ms:
-                    self.max_time_out = True
-                    self.OnVoiceEnd(cur_frm_idx, False, False)
-                    self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
+                    cache["stats"].max_time_out = True
+                    self.OnVoiceEnd(cur_frm_idx, False, False, cache=cache)
+                    cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
                 elif not is_final_frame:
-                    self.OnVoiceDetected(cur_frm_idx)
+                    self.OnVoiceDetected(cur_frm_idx, cache=cache)
                 else:
-                    self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx)
+                    self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx, cache=cache)
             else:
                 pass
        elif AudioChangeState.kChangeStateSil2Sil == state_change:
-            self.continous_silence_frame_count += 1
-            if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
+            cache["stats"].continous_silence_frame_count += 1
+            if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
                 # silence timeout, return zero length decision
                 if ((self.vad_opts.detect_mode == VadDetectMode.kVadSingleUtteranceDetectMode.value) and (
-                        self.continous_silence_frame_count * frm_shift_in_ms > self.vad_opts.max_start_silence_time)) \
-                        or (is_final_frame and self.number_end_time_detected == 0):
-                    for t in range(self.lastest_confirmed_silence_frame + 1, cur_frm_idx):
-                        self.OnSilenceDetected(t)
-                    self.OnVoiceStart(0, True)
-                    self.OnVoiceEnd(0, True, False);
-                    self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
+                        cache["stats"].continous_silence_frame_count * frm_shift_in_ms > self.vad_opts.max_start_silence_time)) \
+                        or (is_final_frame and cache["stats"].number_end_time_detected == 0):
+                    for t in range(cache["stats"].lastest_confirmed_silence_frame + 1, cur_frm_idx):
+                        self.OnSilenceDetected(t, cache=cache)
+                    self.OnVoiceStart(0, True, cache=cache)
+                    self.OnVoiceEnd(0, True, False, cache=cache)
+                    cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
                 else:
-                    if cur_frm_idx >= self.LatencyFrmNumAtStartPoint():
-                        self.OnSilenceDetected(cur_frm_idx - self.LatencyFrmNumAtStartPoint())
-            elif self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
-                if self.continous_silence_frame_count * frm_shift_in_ms >= self.max_end_sil_frame_cnt_thresh:
-                    lookback_frame = int(self.max_end_sil_frame_cnt_thresh / frm_shift_in_ms)
+                    if cur_frm_idx >= self.LatencyFrmNumAtStartPoint(cache=cache):
+                        self.OnSilenceDetected(cur_frm_idx - self.LatencyFrmNumAtStartPoint(cache=cache), cache=cache)
+            elif cache["stats"].vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
+                if cache["stats"].continous_silence_frame_count * frm_shift_in_ms >= cache["stats"].max_end_sil_frame_cnt_thresh:
+                    lookback_frame = int(cache["stats"].max_end_sil_frame_cnt_thresh / frm_shift_in_ms)
                     if self.vad_opts.do_extend:
                         lookback_frame -= int(self.vad_opts.lookahead_time_end_point / frm_shift_in_ms)
                         lookback_frame -= 1
                         lookback_frame = max(0, lookback_frame)
-                    self.OnVoiceEnd(cur_frm_idx - lookback_frame, False, False)
-                    self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
-                elif cur_frm_idx - self.confirmed_start_frame + 1 > \
+                    self.OnVoiceEnd(cur_frm_idx - lookback_frame, False, False, cache=cache)
+                    cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
+                elif cur_frm_idx - cache["stats"].confirmed_start_frame + 1 > \
                        self.vad_opts.max_single_segment_time / frm_shift_in_ms:
-                    self.OnVoiceEnd(cur_frm_idx, False, False)
-                    self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
+                    self.OnVoiceEnd(cur_frm_idx, False, False, cache=cache)
+                    cache["stats"].vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
                 elif self.vad_opts.do_extend and not is_final_frame:
-                    if self.continous_silence_frame_count <= int(
+                    if cache["stats"].continous_silence_frame_count <= int(
                             self.vad_opts.lookahead_time_end_point / frm_shift_in_ms):
-                        self.OnVoiceDetected(cur_frm_idx)
+                        self.OnVoiceDetected(cur_frm_idx, cache=cache)
                 else:
-                    self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx)
+                    self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx, cache=cache)
             else:
                 pass

-        if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected and \
+        if cache["stats"].vad_state_machine == VadStateMachine.kVadInStateEndPointDetected and \
                 self.vad_opts.detect_mode == VadDetectMode.kVadMutipleUtteranceDetectMode.value:
-            self.ResetDetection()
+            self.ResetDetection(cache=cache)
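
The refactor above moves every piece of per-stream VAD state out of module attributes (self.*) and into a per-session cache dict (cache["stats"], cache["windows_detector"], cache["encoder"]) built by init_cache(), so the FsmnVADStreaming module itself no longer carries stream-specific state between calls. Below is a minimal sketch of that per-stream-cache pattern; the names StreamStats and ChunkProcessor are illustrative assumptions only, not part of the FunASR API:

    from dataclasses import dataclass, field

    @dataclass
    class StreamStats:
        # per-stream mutable state, analogous to cache["stats"] in the patch
        frame_count: int = 0
        segments: list = field(default_factory=list)

    class ChunkProcessor:
        def __init__(self, frame_ms: int = 10):
            self.frame_ms = frame_ms  # shared, read-only configuration only

        def init_cache(self) -> dict:
            # one dict per stream; nothing mutable is kept on self
            return {"stats": StreamStats()}

        def forward(self, chunk_ms: int, cache: dict, is_final: bool = False) -> list:
            stats = cache["stats"]
            stats.frame_count += chunk_ms // self.frame_ms
            if is_final:
                stats.segments.append([0, stats.frame_count * self.frame_ms])
            return stats.segments

    proc = ChunkProcessor()
    cache_a, cache_b = proc.init_cache(), proc.init_cache()  # two independent streams
    proc.forward(600, cache_a)
    proc.forward(200, cache_b)
    print(proc.forward(600, cache_a, is_final=True))  # [[0, 1200]]; stream B is unaffected

The same design choice is why forward() and generate() in the patch now receive the whole cache, and why ResetDetection() trims cache["stats"] instead of resetting attributes on the module.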