From cf2f14345aa2c4f168ee51c200b8081c748980b8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B8=B8=E9=9B=81?= Date: Fri, 12 Jan 2024 00:01:25 +0800 Subject: [PATCH] funasr1.0 fsmn-vad streaming --- .../fsmn_vad_streaming/demo.py | 11 + .../fsmn_vad_streaming/infer.sh | 11 + .../paraformer_streaming/demo.py | 2 - funasr/models/fsmn_vad/encoder.py | 18 +- funasr/models/fsmn_vad/model.py | 58 +- funasr/models/fsmn_vad_streaming/__init__.py | 0 funasr/models/fsmn_vad_streaming/encoder.py | 303 +++++++ funasr/models/fsmn_vad_streaming/model.py | 781 ++++++++++++++++++ .../models/fsmn_vad_streaming/template.yaml | 62 ++ funasr/models/paraformer_streaming/model.py | 15 +- .../models/paraformer_streaming/template.yaml | 143 ++++ funasr/utils/load_utils.py | 23 +- 12 files changed, 1356 insertions(+), 71 deletions(-) create mode 100644 examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py create mode 100644 examples/industrial_data_pretraining/fsmn_vad_streaming/infer.sh create mode 100644 funasr/models/fsmn_vad_streaming/__init__.py create mode 100755 funasr/models/fsmn_vad_streaming/encoder.py create mode 100644 funasr/models/fsmn_vad_streaming/model.py create mode 100644 funasr/models/fsmn_vad_streaming/template.yaml create mode 100644 funasr/models/paraformer_streaming/template.yaml diff --git a/examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py b/examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py new file mode 100644 index 000000000..2a157ee23 --- /dev/null +++ b/examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py @@ -0,0 +1,11 @@ +#!/usr/bin/env python3 +# -*- encoding: utf-8 -*- +# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. +# MIT License (https://opensource.org/licenses/MIT) + +from funasr import AutoModel + +model = AutoModel(model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch", model_revision="v2.0.0") + +res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav") +print(res) \ No newline at end of file diff --git a/examples/industrial_data_pretraining/fsmn_vad_streaming/infer.sh b/examples/industrial_data_pretraining/fsmn_vad_streaming/infer.sh new file mode 100644 index 000000000..dedd14abb --- /dev/null +++ b/examples/industrial_data_pretraining/fsmn_vad_streaming/infer.sh @@ -0,0 +1,11 @@ + + +model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch" +model_revision="v2.0.0" + +python funasr/bin/inference.py \ ++model=${model} \ ++model_revision=${model_revision} \ ++input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/vad_example.wav" \ ++output_dir="./outputs/debug" \ ++device="cpu" \ diff --git a/examples/industrial_data_pretraining/paraformer_streaming/demo.py b/examples/industrial_data_pretraining/paraformer_streaming/demo.py index 9923a0445..6d464f25f 100644 --- a/examples/industrial_data_pretraining/paraformer_streaming/demo.py +++ b/examples/industrial_data_pretraining/paraformer_streaming/demo.py @@ -12,8 +12,6 @@ decoder_chunk_look_back = 1 #number of encoder chunks to lookback for decoder cr model = AutoModel(model="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online", model_revison="v2.0.0") cache = {} res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav", - cache=cache, - is_final=True, chunk_size=chunk_size, encoder_chunk_look_back=encoder_chunk_look_back, decoder_chunk_look_back=decoder_chunk_look_back, diff --git 
a/funasr/models/fsmn_vad/encoder.py b/funasr/models/fsmn_vad/encoder.py index 54410acb3..a0a379da5 100755 --- a/funasr/models/fsmn_vad/encoder.py +++ b/funasr/models/fsmn_vad/encoder.py @@ -125,12 +125,12 @@ class BasicBlock(nn.Sequential): self.affine = AffineTransform(proj_dim, linear_dim) self.relu = RectifiedLinear(linear_dim, linear_dim) - def forward(self, input: torch.Tensor, in_cache: Dict[str, torch.Tensor]): + def forward(self, input: torch.Tensor, cache: Dict[str, torch.Tensor]): x1 = self.linear(input) # B T D cache_layer_name = 'cache_layer_{}'.format(self.stack_layer) - if cache_layer_name not in in_cache: - in_cache[cache_layer_name] = torch.zeros(x1.shape[0], x1.shape[-1], (self.lorder - 1) * self.lstride, 1) - x2, in_cache[cache_layer_name] = self.fsmn_block(x1, in_cache[cache_layer_name]) + if cache_layer_name not in cache: + cache[cache_layer_name] = torch.zeros(x1.shape[0], x1.shape[-1], (self.lorder - 1) * self.lstride, 1) + x2, cache[cache_layer_name] = self.fsmn_block(x1, cache[cache_layer_name]) x3 = self.affine(x2) x4 = self.relu(x3) return x4 @@ -140,10 +140,10 @@ class FsmnStack(nn.Sequential): def __init__(self, *args): super(FsmnStack, self).__init__(*args) - def forward(self, input: torch.Tensor, in_cache: Dict[str, torch.Tensor]): + def forward(self, input: torch.Tensor, cache: Dict[str, torch.Tensor]): x = input for module in self._modules.values(): - x = module(x, in_cache) + x = module(x, cache) return x @@ -199,19 +199,19 @@ class FSMN(nn.Module): def forward( self, input: torch.Tensor, - in_cache: Dict[str, torch.Tensor] + cache: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: """ Args: input (torch.Tensor): Input tensor (B, T, D) - in_cache: when in_cache is not None, the forward is in streaming. The type of in_cache is a dict, egs, + cache: when cache is not None, the forward is in streaming. The type of cache is a dict, egs, {'cache_layer_1': torch.Tensor(B, T1, D)}, T1 is equal to self.lorder. 
It is {} for the 1st frame """ x1 = self.in_linear1(input) x2 = self.in_linear2(x1) x3 = self.relu(x2) - x4 = self.fsmn(x3, in_cache) # self.in_cache will update automatically in self.fsmn + x4 = self.fsmn(x3, cache) # self.cache will update automatically in self.fsmn x5 = self.out_linear1(x4) x6 = self.out_linear2(x5) x7 = self.softmax(x6) diff --git a/funasr/models/fsmn_vad/model.py b/funasr/models/fsmn_vad/model.py index f6e0488a8..1ed077394 100644 --- a/funasr/models/fsmn_vad/model.py +++ b/funasr/models/fsmn_vad/model.py @@ -333,8 +333,8 @@ class FsmnVAD(nn.Module): 10 * math.log10((self.waveform[0][offset: offset + frame_sample_length]).square().sum() + \ 0.000001)) - def ComputeScores(self, feats: torch.Tensor, in_cache: Dict[str, torch.Tensor]) -> None: - scores = self.encoder(feats, in_cache).to('cpu') # return B * T * D + def ComputeScores(self, feats: torch.Tensor, cache: Dict[str, torch.Tensor]) -> None: + scores = self.encoder(feats, cache).to('cpu') # return B * T * D assert scores.shape[1] == feats.shape[1], "The shape between feats and scores does not match" self.vad_opts.nn_eval_block_size = scores.shape[1] self.frm_cnt += scores.shape[1] # count total frames @@ -493,14 +493,14 @@ class FsmnVAD(nn.Module): return frame_state - def forward(self, feats: torch.Tensor, waveform: torch.tensor, in_cache: Dict[str, torch.Tensor] = dict(), + def forward(self, feats: torch.Tensor, waveform: torch.tensor, cache: Dict[str, torch.Tensor] = dict(), is_final: bool = False ): - if not in_cache: + if not cache: self.AllResetDetection() self.waveform = waveform # compute decibel for each frame self.ComputeDecibel() - self.ComputeScores(feats, in_cache) + self.ComputeScores(feats, cache) if not is_final: self.DetectCommonFrames() else: @@ -521,7 +521,7 @@ class FsmnVAD(nn.Module): if is_final: # reset class variables and clear the dict for the next query self.AllResetDetection() - return segments, in_cache + return segments, cache def generate(self, data_in, @@ -561,7 +561,7 @@ class FsmnVAD(nn.Module): feats = speech feats_len = speech_lengths.max().item() waveform = pad_sequence(audio_sample_list, batch_first=True).to(device=kwargs["device"]) # data: [batch, N] - in_cache = kwargs.get("in_cache", {}) + cache = kwargs.get("cache", {}) batch_size = kwargs.get("batch_size", 1) step = min(feats_len, 6000) segments = [[]] * batch_size @@ -576,11 +576,11 @@ class FsmnVAD(nn.Module): "feats": feats[:, t_offset:t_offset + step, :], "waveform": waveform[:, t_offset * 160:min(waveform.shape[-1], (t_offset + step - 1) * 160 + 400)], "is_final": is_final, - "in_cache": in_cache + "cache": cache } - segments_part, in_cache = self.forward(**batch) + segments_part, cache = self.forward(**batch) if segments_part: for batch_num in range(0, batch_size): segments[batch_num] += segments_part[batch_num] @@ -604,46 +604,6 @@ class FsmnVAD(nn.Module): return results, meta_data - def forward_online(self, feats: torch.Tensor, waveform: torch.tensor, in_cache: Dict[str, torch.Tensor] = dict(), - is_final: bool = False, max_end_sil: int = 800 - ) -> Tuple[List[List[List[int]]], Dict[str, torch.Tensor]]: - if not in_cache: - self.AllResetDetection() - self.max_end_sil_frame_cnt_thresh = max_end_sil - self.vad_opts.speech_to_sil_time_thres - self.waveform = waveform # compute decibel for each frame - - self.ComputeScores(feats, in_cache) - self.ComputeDecibel() - if not is_final: - self.DetectCommonFrames() - else: - self.DetectLastFrames() - segments = [] - for batch_num in range(0, feats.shape[0]): # only support 
batch_size = 1 now - segment_batch = [] - if len(self.output_data_buf) > 0: - for i in range(self.output_data_buf_offset, len(self.output_data_buf)): - if not self.output_data_buf[i].contain_seg_start_point: - continue - if not self.next_seg and not self.output_data_buf[i].contain_seg_end_point: - continue - start_ms = self.output_data_buf[i].start_ms if self.next_seg else -1 - if self.output_data_buf[i].contain_seg_end_point: - end_ms = self.output_data_buf[i].end_ms - self.next_seg = True - self.output_data_buf_offset += 1 - else: - end_ms = -1 - self.next_seg = False - segment = [start_ms, end_ms] - segment_batch.append(segment) - if segment_batch: - segments.append(segment_batch) - if is_final: - # reset class variables and clear the dict for the next query - self.AllResetDetection() - return segments, in_cache - def DetectCommonFrames(self) -> int: if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected: return 0 diff --git a/funasr/models/fsmn_vad_streaming/__init__.py b/funasr/models/fsmn_vad_streaming/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/funasr/models/fsmn_vad_streaming/encoder.py b/funasr/models/fsmn_vad_streaming/encoder.py new file mode 100755 index 000000000..ae9185222 --- /dev/null +++ b/funasr/models/fsmn_vad_streaming/encoder.py @@ -0,0 +1,303 @@ +from typing import Tuple, Dict +import copy + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from funasr.register import tables + +class LinearTransform(nn.Module): + + def __init__(self, input_dim, output_dim): + super(LinearTransform, self).__init__() + self.input_dim = input_dim + self.output_dim = output_dim + self.linear = nn.Linear(input_dim, output_dim, bias=False) + + def forward(self, input): + output = self.linear(input) + + return output + + +class AffineTransform(nn.Module): + + def __init__(self, input_dim, output_dim): + super(AffineTransform, self).__init__() + self.input_dim = input_dim + self.output_dim = output_dim + self.linear = nn.Linear(input_dim, output_dim) + + def forward(self, input): + output = self.linear(input) + + return output + + +class RectifiedLinear(nn.Module): + + def __init__(self, input_dim, output_dim): + super(RectifiedLinear, self).__init__() + self.dim = input_dim + self.relu = nn.ReLU() + self.dropout = nn.Dropout(0.1) + + def forward(self, input): + out = self.relu(input) + return out + + +class FSMNBlock(nn.Module): + + def __init__( + self, + input_dim: int, + output_dim: int, + lorder=None, + rorder=None, + lstride=1, + rstride=1, + ): + super(FSMNBlock, self).__init__() + + self.dim = input_dim + + if lorder is None: + return + + self.lorder = lorder + self.rorder = rorder + self.lstride = lstride + self.rstride = rstride + + self.conv_left = nn.Conv2d( + self.dim, self.dim, [lorder, 1], dilation=[lstride, 1], groups=self.dim, bias=False) + + if self.rorder > 0: + self.conv_right = nn.Conv2d( + self.dim, self.dim, [rorder, 1], dilation=[rstride, 1], groups=self.dim, bias=False) + else: + self.conv_right = None + + def forward(self, input: torch.Tensor, cache: torch.Tensor): + x = torch.unsqueeze(input, 1) + x_per = x.permute(0, 3, 2, 1) # B D T C + + cache = cache.to(x_per.device) + y_left = torch.cat((cache, x_per), dim=2) + cache = y_left[:, :, -(self.lorder - 1) * self.lstride:, :] + y_left = self.conv_left(y_left) + out = x_per + y_left + + if self.conv_right is not None: + # maybe need to check + y_right = F.pad(x_per, [0, 0, 0, self.rorder * self.rstride]) + y_right = y_right[:, :, 
self.rstride:, :]
+            y_right = self.conv_right(y_right)
+            out += y_right
+
+        out_per = out.permute(0, 3, 2, 1)
+        output = out_per.squeeze(1)
+
+        return output, cache
+
+
+class BasicBlock(nn.Module):
+    def __init__(self,
+                 linear_dim: int,
+                 proj_dim: int,
+                 lorder: int,
+                 rorder: int,
+                 lstride: int,
+                 rstride: int,
+                 stack_layer: int
+                 ):
+        super(BasicBlock, self).__init__()
+        self.lorder = lorder
+        self.rorder = rorder
+        self.lstride = lstride
+        self.rstride = rstride
+        self.stack_layer = stack_layer
+        self.linear = LinearTransform(linear_dim, proj_dim)
+        self.fsmn_block = FSMNBlock(proj_dim, proj_dim, lorder, rorder, lstride, rstride)
+        self.affine = AffineTransform(proj_dim, linear_dim)
+        self.relu = RectifiedLinear(linear_dim, linear_dim)
+
+    def forward(self, input: torch.Tensor, cache: Dict[str, torch.Tensor]):
+        x1 = self.linear(input)  # B T D
+        cache_layer_name = 'cache_layer_{}'.format(self.stack_layer)
+        if cache_layer_name not in cache:
+            cache[cache_layer_name] = torch.zeros(x1.shape[0], x1.shape[-1], (self.lorder - 1) * self.lstride, 1)
+        x2, cache[cache_layer_name] = self.fsmn_block(x1, cache[cache_layer_name])
+        x3 = self.affine(x2)
+        x4 = self.relu(x3)
+        return x4
+
+
+class FsmnStack(nn.Sequential):
+    def __init__(self, *args):
+        super(FsmnStack, self).__init__(*args)
+
+    def forward(self, input: torch.Tensor, cache: Dict[str, torch.Tensor]):
+        x = input
+        for module in self._modules.values():
+            x = module(x, cache)
+        return x
+
+
+'''
+FSMN net for keyword spotting
+input_dim: input dimension
+linear_dim: fsmn input dimension
+proj_dim: fsmn projection dimension
+lorder: fsmn left order
+rorder: fsmn right order
+num_syn: output dimension
+fsmn_layers: no. of sequential fsmn layers
+'''
+
+@tables.register("encoder_classes", "FSMN")
+class FSMN(nn.Module):
+    def __init__(
+            self,
+            input_dim: int,
+            input_affine_dim: int,
+            fsmn_layers: int,
+            linear_dim: int,
+            proj_dim: int,
+            lorder: int,
+            rorder: int,
+            lstride: int,
+            rstride: int,
+            output_affine_dim: int,
+            output_dim: int
+    ):
+        super(FSMN, self).__init__()
+
+        self.input_dim = input_dim
+        self.input_affine_dim = input_affine_dim
+        self.fsmn_layers = fsmn_layers
+        self.linear_dim = linear_dim
+        self.proj_dim = proj_dim
+        self.output_affine_dim = output_affine_dim
+        self.output_dim = output_dim
+
+        self.in_linear1 = AffineTransform(input_dim, input_affine_dim)
+        self.in_linear2 = AffineTransform(input_affine_dim, linear_dim)
+        self.relu = RectifiedLinear(linear_dim, linear_dim)
+        self.fsmn = FsmnStack(*[BasicBlock(linear_dim, proj_dim, lorder, rorder, lstride, rstride, i) for i in
+                                range(fsmn_layers)])
+        self.out_linear1 = AffineTransform(linear_dim, output_affine_dim)
+        self.out_linear2 = AffineTransform(output_affine_dim, output_dim)
+        self.softmax = nn.Softmax(dim=-1)
+
+    def fuse_modules(self):
+        pass
+
+    def forward(
+            self,
+            input: torch.Tensor,
+            cache: Dict[str, torch.Tensor]
+    ) -> torch.Tensor:
+        """
+        Args:
+            input (torch.Tensor): Input tensor (B, T, D)
+            cache: when cache is not None, the forward is in streaming. The type of cache is a dict, e.g.,
+                {'cache_layer_1': torch.Tensor(B, T1, D)}, T1 is equal to self.lorder.
It is {} for the 1st frame
+        """
+
+        x1 = self.in_linear1(input)
+        x2 = self.in_linear2(x1)
+        x3 = self.relu(x2)
+        x4 = self.fsmn(x3, cache)  # the cache dict is updated in place inside self.fsmn
+        x5 = self.out_linear1(x4)
+        x6 = self.out_linear2(x5)
+        x7 = self.softmax(x6)
+
+        return x7
+
+
+'''
+one deep fsmn layer
+dimproj: projection dimension, input and output dimension of memory blocks
+dimlinear: dimension of mapping layer
+lorder: left order
+rorder: right order
+lstride: left stride
+rstride: right stride
+'''
+
+@tables.register("encoder_classes", "DFSMN")
+class DFSMN(nn.Module):
+
+    def __init__(self, dimproj=64, dimlinear=128, lorder=20, rorder=1, lstride=1, rstride=1):
+        super(DFSMN, self).__init__()
+
+        self.lorder = lorder
+        self.rorder = rorder
+        self.lstride = lstride
+        self.rstride = rstride
+
+        self.expand = AffineTransform(dimproj, dimlinear)
+        self.shrink = LinearTransform(dimlinear, dimproj)
+
+        self.conv_left = nn.Conv2d(
+            dimproj, dimproj, [lorder, 1], dilation=[lstride, 1], groups=dimproj, bias=False)
+
+        if rorder > 0:
+            self.conv_right = nn.Conv2d(
+                dimproj, dimproj, [rorder, 1], dilation=[rstride, 1], groups=dimproj, bias=False)
+        else:
+            self.conv_right = None
+
+    def forward(self, input):
+        f1 = F.relu(self.expand(input))
+        p1 = self.shrink(f1)
+
+        x = torch.unsqueeze(p1, 1)
+        x_per = x.permute(0, 3, 2, 1)
+
+        y_left = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride, 0])
+
+        if self.conv_right is not None:
+            y_right = F.pad(x_per, [0, 0, 0, (self.rorder) * self.rstride])
+            y_right = y_right[:, :, self.rstride:, :]
+            out = x_per + self.conv_left(y_left) + self.conv_right(y_right)
+        else:
+            out = x_per + self.conv_left(y_left)
+
+        out1 = out.permute(0, 3, 2, 1)
+        output = input + out1.squeeze(1)
+
+        return output
+
+
+'''
+build stacked dfsmn layers
+'''
+
+
+def buildDFSMNRepeats(linear_dim=128, proj_dim=64, lorder=20, rorder=1, fsmn_layers=6):
+    repeats = [
+        nn.Sequential(
+            DFSMN(proj_dim, linear_dim, lorder, rorder, 1, 1))
+        for i in range(fsmn_layers)
+    ]
+
+    return nn.Sequential(*repeats)
+
+
+if __name__ == '__main__':
+    fsmn = FSMN(400, 140, 4, 250, 128, 10, 2, 1, 1, 140, 2599)
+    print(fsmn)
+
+    num_params = sum(p.numel() for p in fsmn.parameters())
+    print('the number of model params: {}'.format(num_params))
+    x = torch.zeros(128, 200, 400)  # batch-size * time * dim
+    y = fsmn(x, {})  # forward takes a cache dict and returns batch-size * time * dim
+    print('input shape: {}'.format(x.shape))
+    print('output shape: {}'.format(y.shape))
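The per-layer cache dict above is what makes this FSMN usable for streaming: each BasicBlock keeps its last (lorder - 1) * lstride frames of left context under its own cache_layer_{i} key, so feeding an utterance chunk by chunk through one shared dict reproduces the full-utterance output. A minimal sketch of that pattern (the constructor sizes follow the fsmn_vad_streaming template.yaml; the 10-frame chunking is only illustrative):

import torch

fsmn = FSMN(400, 140, 4, 250, 128, 20, 0, 1, 1, 140, 248)  # sizes from template.yaml
fsmn.eval()

feats = torch.randn(1, 100, 400)  # B x T x D LFR-fbank features
cache = {}                        # pass {} on the first chunk; updated in place afterwards
outputs = []
with torch.no_grad():
    for chunk in feats.split(10, dim=1):
        outputs.append(fsmn(chunk, cache))
scores = torch.cat(outputs, dim=1)  # matches fsmn(feats, {}) on the full input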
diff --git a/funasr/models/fsmn_vad_streaming/model.py b/funasr/models/fsmn_vad_streaming/model.py
new file mode 100644
index 000000000..4c7e94309
--- /dev/null
+++ b/funasr/models/fsmn_vad_streaming/model.py
@@ -0,0 +1,781 @@
+from enum import Enum
+from typing import List, Tuple, Dict, Any
+import logging
+import os
+import json
+import torch
+from torch import nn
+import math
+from typing import Optional
+import time
+from funasr.register import tables
+from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
+from funasr.utils.datadir_writer import DatadirWriter
+from torch.nn.utils.rnn import pad_sequence
+
+class VadStateMachine(Enum):
+    kVadInStateStartPointNotDetected = 1
+    kVadInStateInSpeechSegment = 2
+    kVadInStateEndPointDetected = 3
+
+
+class FrameState(Enum):
+    kFrameStateInvalid = -1
+    kFrameStateSpeech = 1
+    kFrameStateSil = 0
+
+
+# final voice/unvoice state per frame
+class AudioChangeState(Enum):
+    kChangeStateSpeech2Speech = 0
+    kChangeStateSpeech2Sil = 1
+    kChangeStateSil2Sil = 2
+    kChangeStateSil2Speech = 3
+    kChangeStateNoBegin = 4
+    kChangeStateInvalid = 5
+
+
+class VadDetectMode(Enum):
+    kVadSingleUtteranceDetectMode = 0
+    kVadMutipleUtteranceDetectMode = 1
+
+
+class VADXOptions:
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+    https://arxiv.org/abs/1803.05030
+    """
+    def __init__(
+            self,
+            sample_rate: int = 16000,
+            detect_mode: int = VadDetectMode.kVadMutipleUtteranceDetectMode.value,
+            snr_mode: int = 0,
+            max_end_silence_time: int = 800,
+            max_start_silence_time: int = 3000,
+            do_start_point_detection: bool = True,
+            do_end_point_detection: bool = True,
+            window_size_ms: int = 200,
+            sil_to_speech_time_thres: int = 150,
+            speech_to_sil_time_thres: int = 150,
+            speech_2_noise_ratio: float = 1.0,
+            do_extend: int = 1,
+            lookback_time_start_point: int = 200,
+            lookahead_time_end_point: int = 100,
+            max_single_segment_time: int = 60000,
+            nn_eval_block_size: int = 8,
+            dcd_block_size: int = 4,
+            snr_thres: float = -100.0,
+            noise_frame_num_used_for_snr: int = 100,
+            decibel_thres: float = -100.0,
+            speech_noise_thres: float = 0.6,
+            fe_prior_thres: float = 1e-4,
+            silence_pdf_num: int = 1,
+            sil_pdf_ids: List[int] = [0],
+            speech_noise_thresh_low: float = -0.1,
+            speech_noise_thresh_high: float = 0.3,
+            output_frame_probs: bool = False,
+            frame_in_ms: int = 10,
+            frame_length_ms: int = 25,
+            **kwargs,
+    ):
+        self.sample_rate = sample_rate
+        self.detect_mode = detect_mode
+        self.snr_mode = snr_mode
+        self.max_end_silence_time = max_end_silence_time
+        self.max_start_silence_time = max_start_silence_time
+        self.do_start_point_detection = do_start_point_detection
+        self.do_end_point_detection = do_end_point_detection
+        self.window_size_ms = window_size_ms
+        self.sil_to_speech_time_thres = sil_to_speech_time_thres
+        self.speech_to_sil_time_thres = speech_to_sil_time_thres
+        self.speech_2_noise_ratio = speech_2_noise_ratio
+        self.do_extend = do_extend
+        self.lookback_time_start_point = lookback_time_start_point
+        self.lookahead_time_end_point = lookahead_time_end_point
+        self.max_single_segment_time = max_single_segment_time
+        self.nn_eval_block_size = nn_eval_block_size
+        self.dcd_block_size = dcd_block_size
+        self.snr_thres = snr_thres
+        self.noise_frame_num_used_for_snr = noise_frame_num_used_for_snr
+        self.decibel_thres = decibel_thres
+        self.speech_noise_thres = speech_noise_thres
+        self.fe_prior_thres = fe_prior_thres
+        self.silence_pdf_num = silence_pdf_num
+        self.sil_pdf_ids = sil_pdf_ids
+        self.speech_noise_thresh_low = speech_noise_thresh_low
+        self.speech_noise_thresh_high = speech_noise_thresh_high
+        self.output_frame_probs = output_frame_probs
+        self.frame_in_ms = frame_in_ms
+        self.frame_length_ms = frame_length_ms
+
+
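With the defaults above, the sliding-window bookkeeping implemented by the classes that follow works out to small frame counts. A worked example of the ms-to-frame conversions (nothing here is extra configuration):

frame_in_ms = 10
win_size_frame = 200 // frame_in_ms        # window_size_ms -> 20-frame sliding window
sil_to_speech_frames = 150 // frame_in_ms  # 15 voiced frames in the window flip sil -> speech
speech_to_sil_frames = 150 // frame_in_ms  # 15 or fewer voiced frames flip speech -> sil
max_end_sil_ms = 800 - 150                 # max_end_silence_time - speech_to_sil_time_thres = 650 ms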
+class E2EVadSpeechBufWithDoa(object):
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+    https://arxiv.org/abs/1803.05030
+    """
+    def __init__(self):
+        self.start_ms = 0
+        self.end_ms = 0
+        self.buffer = []
+        self.contain_seg_start_point = False
+        self.contain_seg_end_point = False
+        self.doa = 0
+
+    def Reset(self):
+        self.start_ms = 0
+        self.end_ms = 0
+        self.buffer = []
+        self.contain_seg_start_point = False
+        self.contain_seg_end_point = False
+        self.doa = 0
+
+
+class E2EVadFrameProb(object):
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+    https://arxiv.org/abs/1803.05030
+    """
+    def __init__(self):
+        self.noise_prob = 0.0
+        self.speech_prob = 0.0
+        self.score = 0.0
+        self.frame_id = 0
+        self.frm_state = 0
+
+
+class WindowDetector(object):
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+    https://arxiv.org/abs/1803.05030
+    """
+    def __init__(self, window_size_ms: int, sil_to_speech_time: int,
+                 speech_to_sil_time: int, frame_size_ms: int):
+        self.window_size_ms = window_size_ms
+        self.sil_to_speech_time = sil_to_speech_time
+        self.speech_to_sil_time = speech_to_sil_time
+        self.frame_size_ms = frame_size_ms
+
+        self.win_size_frame = int(window_size_ms / frame_size_ms)
+        self.win_sum = 0
+        self.win_state = [0] * self.win_size_frame  # initialize the sliding window
+
+        self.cur_win_pos = 0
+        self.pre_frame_state = FrameState.kFrameStateSil
+        self.cur_frame_state = FrameState.kFrameStateSil
+        self.sil_to_speech_frmcnt_thres = int(sil_to_speech_time / frame_size_ms)
+        self.speech_to_sil_frmcnt_thres = int(speech_to_sil_time / frame_size_ms)
+
+        self.voice_last_frame_count = 0
+        self.noise_last_frame_count = 0
+        self.hydre_frame_count = 0
+
+    def Reset(self) -> None:
+        self.cur_win_pos = 0
+        self.win_sum = 0
+        self.win_state = [0] * self.win_size_frame
+        self.pre_frame_state = FrameState.kFrameStateSil
+        self.cur_frame_state = FrameState.kFrameStateSil
+        self.voice_last_frame_count = 0
+        self.noise_last_frame_count = 0
+        self.hydre_frame_count = 0
+
+    def GetWinSize(self) -> int:
+        return int(self.win_size_frame)
+
+    def DetectOneFrame(self, frameState: FrameState, frame_count: int) -> AudioChangeState:
+        cur_frame_state = FrameState.kFrameStateSil
+        if frameState == FrameState.kFrameStateSpeech:
+            cur_frame_state = 1
+        elif frameState == FrameState.kFrameStateSil:
+            cur_frame_state = 0
+        else:
+            return AudioChangeState.kChangeStateInvalid
+        self.win_sum -= self.win_state[self.cur_win_pos]
+        self.win_sum += cur_frame_state
+        self.win_state[self.cur_win_pos] = cur_frame_state
+        self.cur_win_pos = (self.cur_win_pos + 1) % self.win_size_frame
+
+        if self.pre_frame_state == FrameState.kFrameStateSil and self.win_sum >= self.sil_to_speech_frmcnt_thres:
+            self.pre_frame_state = FrameState.kFrameStateSpeech
+            return AudioChangeState.kChangeStateSil2Speech
+
+        if self.pre_frame_state == FrameState.kFrameStateSpeech and self.win_sum <= self.speech_to_sil_frmcnt_thres:
+            self.pre_frame_state = FrameState.kFrameStateSil
+            return AudioChangeState.kChangeStateSpeech2Sil
+
+        if self.pre_frame_state == FrameState.kFrameStateSil:
+            return AudioChangeState.kChangeStateSil2Sil
+        if self.pre_frame_state == FrameState.kFrameStateSpeech:
+            return AudioChangeState.kChangeStateSpeech2Speech
+        return AudioChangeState.kChangeStateInvalid
+
+    def FrameSizeMs(self) -> int:
+        return int(self.frame_size_ms)
+
+
+@tables.register("model_classes", "FsmnVADStreaming")
+class FsmnVADStreaming(nn.Module):
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+    Deep-FSMN for Large Vocabulary Continuous Speech Recognition
+    https://arxiv.org/abs/1803.05030
+    """
+    def __init__(self,
+                 encoder: str = None,
+                 encoder_conf: Optional[Dict] = None,
+                 vad_post_args: Dict[str, Any] = None,
+                 **kwargs,
+                 ):
+        super().__init__()
+        self.vad_opts = VADXOptions(**kwargs)
+        self.windows_detector = WindowDetector(self.vad_opts.window_size_ms,
+                                               self.vad_opts.sil_to_speech_time_thres,
+                                               self.vad_opts.speech_to_sil_time_thres,
+                                               self.vad_opts.frame_in_ms)
+
+        encoder_class = tables.encoder_classes.get(encoder.lower())
+        encoder = encoder_class(**encoder_conf)
+        self.encoder = encoder
+        # init variables
+
self.data_buf_start_frame = 0 + self.frm_cnt = 0 + self.latest_confirmed_speech_frame = 0 + self.lastest_confirmed_silence_frame = -1 + self.continous_silence_frame_count = 0 + self.vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected + self.confirmed_start_frame = -1 + self.confirmed_end_frame = -1 + self.number_end_time_detected = 0 + self.sil_frame = 0 + self.sil_pdf_ids = self.vad_opts.sil_pdf_ids + self.noise_average_decibel = -100.0 + self.pre_end_silence_detected = False + self.next_seg = True + + self.output_data_buf = [] + self.output_data_buf_offset = 0 + self.frame_probs = [] + self.max_end_sil_frame_cnt_thresh = self.vad_opts.max_end_silence_time - self.vad_opts.speech_to_sil_time_thres + self.speech_noise_thres = self.vad_opts.speech_noise_thres + self.scores = None + self.max_time_out = False + self.decibel = [] + self.data_buf = None + self.data_buf_all = None + self.waveform = None + self.last_drop_frames = 0 + + def AllResetDetection(self): + self.data_buf_start_frame = 0 + self.frm_cnt = 0 + self.latest_confirmed_speech_frame = 0 + self.lastest_confirmed_silence_frame = -1 + self.continous_silence_frame_count = 0 + self.vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected + self.confirmed_start_frame = -1 + self.confirmed_end_frame = -1 + self.number_end_time_detected = 0 + self.sil_frame = 0 + self.sil_pdf_ids = self.vad_opts.sil_pdf_ids + self.noise_average_decibel = -100.0 + self.pre_end_silence_detected = False + self.next_seg = True + + self.output_data_buf = [] + self.output_data_buf_offset = 0 + self.frame_probs = [] + self.max_end_sil_frame_cnt_thresh = self.vad_opts.max_end_silence_time - self.vad_opts.speech_to_sil_time_thres + self.speech_noise_thres = self.vad_opts.speech_noise_thres + self.scores = None + self.max_time_out = False + self.decibel = [] + self.data_buf = None + self.data_buf_all = None + self.waveform = None + self.last_drop_frames = 0 + self.windows_detector.Reset() + + def ResetDetection(self): + self.continous_silence_frame_count = 0 + self.latest_confirmed_speech_frame = 0 + self.lastest_confirmed_silence_frame = -1 + self.confirmed_start_frame = -1 + self.confirmed_end_frame = -1 + self.vad_state_machine = VadStateMachine.kVadInStateStartPointNotDetected + self.windows_detector.Reset() + self.sil_frame = 0 + self.frame_probs = [] + + if self.output_data_buf: + assert self.output_data_buf[-1].contain_seg_end_point == True + drop_frames = int(self.output_data_buf[-1].end_ms / self.vad_opts.frame_in_ms) + real_drop_frames = drop_frames - self.last_drop_frames + self.last_drop_frames = drop_frames + self.data_buf_all = self.data_buf_all[real_drop_frames * int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):] + self.decibel = self.decibel[real_drop_frames:] + self.scores = self.scores[:, real_drop_frames:, :] + + def ComputeDecibel(self) -> None: + frame_sample_length = int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000) + frame_shift_length = int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000) + if self.data_buf_all is None: + self.data_buf_all = self.waveform[0] # self.data_buf is pointed to self.waveform[0] + self.data_buf = self.data_buf_all + else: + self.data_buf_all = torch.cat((self.data_buf_all, self.waveform[0])) + for offset in range(0, self.waveform.shape[1] - frame_sample_length + 1, frame_shift_length): + self.decibel.append( + 10 * math.log10((self.waveform[0][offset: offset + frame_sample_length]).square().sum() + \ + 0.000001)) + + def 
ComputeScores(self, feats: torch.Tensor, cache: Dict[str, torch.Tensor]) -> None:
+        scores = self.encoder(feats, cache).to('cpu')  # return B * T * D
+        assert scores.shape[1] == feats.shape[1], "The shape between feats and scores does not match"
+        self.vad_opts.nn_eval_block_size = scores.shape[1]
+        self.frm_cnt += scores.shape[1]  # count total frames
+        if self.scores is None:
+            self.scores = scores  # the first calculation
+        else:
+            self.scores = torch.cat((self.scores, scores), dim=1)
+
+    def PopDataBufTillFrame(self, frame_idx: int) -> None:  # need check again
+        while self.data_buf_start_frame < frame_idx:
+            if len(self.data_buf) >= int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):
+                self.data_buf_start_frame += 1
+                self.data_buf = self.data_buf_all[(self.data_buf_start_frame - self.last_drop_frames) * int(
+                    self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):]
+
+    def PopDataToOutputBuf(self, start_frm: int, frm_cnt: int, first_frm_is_start_point: bool,
+                           last_frm_is_end_point: bool, end_point_is_sent_end: bool) -> None:
+        self.PopDataBufTillFrame(start_frm)
+        expected_sample_number = int(frm_cnt * self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000)
+        if last_frm_is_end_point:
+            extra_sample = max(0, int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000 -
+                                      self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000))
+            expected_sample_number += int(extra_sample)
+        if end_point_is_sent_end:
+            expected_sample_number = max(expected_sample_number, len(self.data_buf))
+        if len(self.data_buf) < expected_sample_number:
+            print('error in calling pop data_buf\n')
+
+        if len(self.output_data_buf) == 0 or first_frm_is_start_point:
+            self.output_data_buf.append(E2EVadSpeechBufWithDoa())
+            self.output_data_buf[-1].Reset()
+            self.output_data_buf[-1].start_ms = start_frm * self.vad_opts.frame_in_ms
+            self.output_data_buf[-1].end_ms = self.output_data_buf[-1].start_ms
+            self.output_data_buf[-1].doa = 0
+        cur_seg = self.output_data_buf[-1]
+        if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms:
+            print('warning\n')
+        out_pos = len(cur_seg.buffer)  # cur_seg.buffer is not modified here (kept for parity with the C++ version)
+        data_to_pop = 0
+        if end_point_is_sent_end:
+            data_to_pop = expected_sample_number
+        else:
+            data_to_pop = int(frm_cnt * self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)
+        if data_to_pop > len(self.data_buf):
+            print('VAD data_to_pop is bigger than self.data_buf.size()!!!\n')
+            data_to_pop = len(self.data_buf)
+            expected_sample_number = len(self.data_buf)
+
+        cur_seg.doa = 0
+        for sample_cpy_out in range(0, data_to_pop):
+            # cur_seg.buffer[out_pos ++] = data_buf_.back();
+            out_pos += 1
+        for sample_cpy_out in range(data_to_pop, expected_sample_number):
+            # cur_seg.buffer[out_pos++] = data_buf_.back()
+            out_pos += 1
+        if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms:
+            print('Something wrong with the VAD algorithm\n')
+        self.data_buf_start_frame += frm_cnt
+        cur_seg.end_ms = (start_frm + frm_cnt) * self.vad_opts.frame_in_ms
+        if first_frm_is_start_point:
+            cur_seg.contain_seg_start_point = True
+        if last_frm_is_end_point:
+            cur_seg.contain_seg_end_point = True
+
+    def OnSilenceDetected(self, valid_frame: int):
+        self.lastest_confirmed_silence_frame = valid_frame
+        if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
+            self.PopDataBufTillFrame(valid_frame)
+        # silence_detected_callback_
+        # pass
+
+    def OnVoiceDetected(self, valid_frame: int) -> None:
+        self.latest_confirmed_speech_frame = valid_frame
+        self.PopDataToOutputBuf(valid_frame, 1, False, False, False)
+
+    def OnVoiceStart(self, start_frame: int, fake_result: bool = False) -> None:
+        if self.vad_opts.do_start_point_detection:
+            pass
+        if self.confirmed_start_frame != -1:
+            print('not reset vad properly\n')
+        else:
+            self.confirmed_start_frame = start_frame
+
+        if not fake_result and self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
+            self.PopDataToOutputBuf(self.confirmed_start_frame, 1, True, False, False)
+
+    def OnVoiceEnd(self, end_frame: int, fake_result: bool, is_last_frame: bool) -> None:
+        for t in range(self.latest_confirmed_speech_frame + 1, end_frame):
+            self.OnVoiceDetected(t)
+        if self.vad_opts.do_end_point_detection:
+            pass
+        if self.confirmed_end_frame != -1:
+            print('not reset vad properly\n')
+        else:
+            self.confirmed_end_frame = end_frame
+        if not fake_result:
+            self.sil_frame = 0
+            self.PopDataToOutputBuf(self.confirmed_end_frame, 1, False, True, is_last_frame)
+        self.number_end_time_detected += 1
+
+    def MaybeOnVoiceEndIfLastFrame(self, is_final_frame: bool, cur_frm_idx: int) -> None:
+        if is_final_frame:
+            self.OnVoiceEnd(cur_frm_idx, False, True)
+            self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
+
+    def GetLatency(self) -> int:
+        return int(self.LatencyFrmNumAtStartPoint() * self.vad_opts.frame_in_ms)
+
+    def LatencyFrmNumAtStartPoint(self) -> int:
+        vad_latency = self.windows_detector.GetWinSize()
+        if self.vad_opts.do_extend:
+            vad_latency += int(self.vad_opts.lookback_time_start_point / self.vad_opts.frame_in_ms)
+        return vad_latency
+
+    def GetFrameState(self, t: int):
+        frame_state = FrameState.kFrameStateInvalid
+        cur_decibel = self.decibel[t]
+        cur_snr = cur_decibel - self.noise_average_decibel
+        # for each frame, calc log posterior probability of each state
+        if cur_decibel < self.vad_opts.decibel_thres:
+            frame_state = FrameState.kFrameStateSil
+            self.DetectOneFrame(frame_state, t, False)
+            return frame_state
+
+        sum_score = 0.0
+        noise_prob = 0.0
+        assert len(self.sil_pdf_ids) == self.vad_opts.silence_pdf_num
+        if len(self.sil_pdf_ids) > 0:
+            assert len(self.scores) == 1  # only batch_size = 1 is supported
+            sil_pdf_scores = [self.scores[0][t][sil_pdf_id] for sil_pdf_id in self.sil_pdf_ids]
+            sum_score = sum(sil_pdf_scores)
+            noise_prob = math.log(sum_score) * self.vad_opts.speech_2_noise_ratio
+            total_score = 1.0
+            sum_score = total_score - sum_score
+        speech_prob = math.log(sum_score)
+        if self.vad_opts.output_frame_probs:
+            frame_prob = E2EVadFrameProb()
+            frame_prob.noise_prob = noise_prob
+            frame_prob.speech_prob = speech_prob
+            frame_prob.score = sum_score
+            frame_prob.frame_id = t
+            self.frame_probs.append(frame_prob)
+        if math.exp(speech_prob) >= math.exp(noise_prob) + self.speech_noise_thres:
+            if cur_snr >= self.vad_opts.snr_thres and cur_decibel >= self.vad_opts.decibel_thres:
+                frame_state = FrameState.kFrameStateSpeech
+            else:
+                frame_state = FrameState.kFrameStateSil
+        else:
+            frame_state = FrameState.kFrameStateSil
+            if self.noise_average_decibel < -99.9:
+                self.noise_average_decibel = cur_decibel
+            else:
+                self.noise_average_decibel = (cur_decibel + self.noise_average_decibel * (
+                        self.vad_opts.noise_frame_num_used_for_snr - 1)) / self.vad_opts.noise_frame_num_used_for_snr
+
+        return frame_state
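GetFrameState reduces the encoder softmax to a per-frame decision: the scores at sil_pdf_ids are summed into a silence probability, the remainder counts as speech, and speech wins only when it beats silence by speech_noise_thres (and then passes the decibel/SNR gates). A worked example with the default threshold of 0.6 and made-up scores:

sil_prob = 0.25                  # sum of scores at sil_pdf_ids, e.g. [0]
speech_prob = 1.0 - sil_prob     # 0.75
# with speech_2_noise_ratio = 1.0 the exp/log decision reduces to:
is_speech = speech_prob >= sil_prob + 0.6  # 0.75 >= 0.85 -> False, frame stays silence
# sil_prob = 0.1 would give 0.9 >= 0.7 -> True, frame is speech if the dB/SNR gates pass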
+
+    def forward(self, feats: torch.Tensor, waveform: torch.Tensor, cache: Dict[str, torch.Tensor] = dict(),
+                is_final: bool = False
+                ):
+        if not cache:
+            self.AllResetDetection()
+        self.waveform = waveform  # compute decibel for each frame
+        self.ComputeDecibel()
+        self.ComputeScores(feats, cache)
+        if not is_final:
+            self.DetectCommonFrames()
+        else:
+            self.DetectLastFrames()
+        segments = []
+        for batch_num in range(0, feats.shape[0]):  # only support batch_size = 1 now
+            segment_batch = []
+            if len(self.output_data_buf) > 0:
+                for i in range(self.output_data_buf_offset, len(self.output_data_buf)):
+                    if not is_final and (not self.output_data_buf[i].contain_seg_start_point
+                                         or not self.output_data_buf[i].contain_seg_end_point):
+                        continue
+                    segment = [self.output_data_buf[i].start_ms, self.output_data_buf[i].end_ms]
+                    segment_batch.append(segment)
+                    self.output_data_buf_offset += 1  # need update this parameter
+            if segment_batch:
+                segments.append(segment_batch)
+        if is_final:
+            # reset class variables and clear the dict for the next query
+            self.AllResetDetection()
+        return segments, cache
+
+    def init_cache(self, cache: dict = {}, **kwargs):
+        cache["frontend"] = {}
+        cache["prev_samples"] = torch.empty(0)
+
+        return cache
+
+    def generate(self,
+                 data_in,
+                 data_lengths=None,
+                 key: list = None,
+                 tokenizer=None,
+                 frontend=None,
+                 cache: dict = {},
+                 **kwargs,
+                 ):
+
+        if len(cache) == 0:
+            self.init_cache(cache, **kwargs)
+
+        meta_data = {}
+        chunk_size = kwargs.get("chunk_size", 50)  # 50ms
+        chunk_stride_samples = chunk_size * 16
+
+        time1 = time.perf_counter()
+        cfg = {"is_final": kwargs.get("is_final", False)}
+        audio_sample_list = load_audio_text_image_video(data_in,
+                                                        fs=frontend.fs,
+                                                        audio_fs=kwargs.get("fs", 16000),
+                                                        data_type=kwargs.get("data_type", "sound"),
+                                                        tokenizer=tokenizer,
+                                                        **cfg,
+                                                        )
+        _is_final = cfg["is_final"]  # if data_in is a file or url, set is_final=True
+
+        time2 = time.perf_counter()
+        meta_data["load_data"] = f"{time2 - time1:0.3f}"
+        assert len(audio_sample_list) == 1, "batch_size must be set 1"
+
+        audio_sample = torch.cat((cache["prev_samples"], audio_sample_list[0]))
+
+        n = len(audio_sample) // chunk_stride_samples + int(_is_final)
+        m = len(audio_sample) % chunk_stride_samples * (1 - int(_is_final))
+        tokens = []
+        for i in range(n):
+            kwargs["is_final"] = _is_final and i == n - 1
+            audio_sample_i = audio_sample[i * chunk_stride_samples:(i + 1) * chunk_stride_samples]
+
+            # extract fbank feats
+            speech, speech_lengths = extract_fbank([audio_sample_i], data_type=kwargs.get("data_type", "sound"),
+                                                   frontend=frontend, cache=cache["frontend"],
+                                                   is_final=kwargs["is_final"])
+            time3 = time.perf_counter()
+            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+            meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
+
+        # NOTE: the block below re-runs feature extraction over the whole input and
+        # overrides the chunked features computed above (offline-style processing)
+        meta_data = {}
+        audio_sample_list = [data_in]
+        if isinstance(data_in, torch.Tensor):  # fbank
+            speech, speech_lengths = data_in, data_lengths
+            if len(speech.shape) < 3:
+                speech = speech[None, :, :]
+            if speech_lengths is None:
+                speech_lengths = speech.shape[1]
+        else:
+            # extract fbank feats
+            time1 = time.perf_counter()
+            audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000))
+            time2 = time.perf_counter()
+            meta_data["load_data"] = f"{time2 - time1:0.3f}"
+            speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
+                                                   frontend=frontend)
+            time3 = time.perf_counter()
+            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+            meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
+
+        speech = speech.to(device=kwargs["device"])
+        if isinstance(speech_lengths, torch.Tensor):
+            speech_lengths = speech_lengths.to(device=kwargs["device"])
+
+        # b. Forward Encoder streaming
+        t_offset = 0
+        feats = speech
+        feats_len = speech_lengths.max().item()
+        waveform = pad_sequence(audio_sample_list, batch_first=True).to(device=kwargs["device"])  # data: [batch, N]
+        cache = kwargs.get("cache", {})
+        batch_size = kwargs.get("batch_size", 1)
+        step = min(feats_len, 6000)
+        segments = [[] for _ in range(batch_size)]
+
+        for t_offset in range(0, feats_len, min(step, feats_len - t_offset)):
+            if t_offset + step >= feats_len - 1:
+                step = feats_len - t_offset
+                is_final = True
+            else:
+                is_final = False
+            batch = {
+                "feats": feats[:, t_offset:t_offset + step, :],
+                "waveform": waveform[:, t_offset * 160:min(waveform.shape[-1], (t_offset + step - 1) * 160 + 400)],
+                "is_final": is_final,
+                "cache": cache
+            }
+
+            segments_part, cache = self.forward(**batch)
+            if segments_part:
+                for batch_num in range(0, batch_size):
+                    segments[batch_num] += segments_part[batch_num]
+
+        ibest_writer = None
+        if ibest_writer is None and kwargs.get("output_dir") is not None:
+            writer = DatadirWriter(kwargs.get("output_dir"))
+            ibest_writer = writer[f"{1}best_recog"]
+
+        results = []
+        for i in range(batch_size):
+            result_i = {"key": key[i], "value": segments[i]}
+            if "MODELSCOPE_ENVIRONMENT" in os.environ and os.environ["MODELSCOPE_ENVIRONMENT"] == "eas":
+                result_i = json.dumps(result_i)
+
+            if ibest_writer is not None:
+                ibest_writer["text"][key[i]] = segments[i]
+
+            results.append(result_i)
+
+        return results, meta_data
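generate() keeps all chunk state in the cache dict, so the intended client-side loop mirrors the paraformer streaming demo: slice the waveform, reuse one cache, and mark the last chunk with is_final. A sketch of that usage (the file name and 200 ms chunking are illustrative, not part of this patch):

import soundfile
from funasr import AutoModel

chunk_size = 200  # ms
model = AutoModel(model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch", model_revision="v2.0.0")

speech, sample_rate = soundfile.read("vad_example.wav")
chunk_stride = int(chunk_size * sample_rate / 1000)

cache = {}
for i in range(0, len(speech), chunk_stride):
    is_final = i + chunk_stride >= len(speech)
    res = model(input=speech[i:i + chunk_stride], cache=cache,
                is_final=is_final, chunk_size=chunk_size)
    if res and res[0]["value"]:
        print(res[0]["value"])  # completed [start_ms, end_ms] segments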
+
+    def DetectCommonFrames(self) -> int:
+        if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
+            return 0
+        for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
+            frame_state = FrameState.kFrameStateInvalid
+            frame_state = self.GetFrameState(self.frm_cnt - 1 - i - self.last_drop_frames)
+            self.DetectOneFrame(frame_state, self.frm_cnt - 1 - i, False)
+
+        return 0
+
+    def DetectLastFrames(self) -> int:
+        if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
+            return 0
+        for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1):
+            frame_state = FrameState.kFrameStateInvalid
+            frame_state = self.GetFrameState(self.frm_cnt - 1 - i - self.last_drop_frames)
+            if i != 0:
+                self.DetectOneFrame(frame_state, self.frm_cnt - 1 - i, False)
+            else:
+                self.DetectOneFrame(frame_state, self.frm_cnt - 1, True)
+
+        return 0
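In DetectOneFrame below, the Sil2Sil branch closes a segment once confirmed trailing silence reaches max_end_sil_frame_cnt_thresh, then backdates the end point. With the defaults this works out to (worked numbers only, no new logic):

frm_shift_in_ms = 10
max_end_sil_frames = 650 // frm_shift_in_ms   # 65 silent frames trigger the end point
lookback = 65 - 100 // frm_shift_in_ms - 1    # minus lookahead_time_end_point, minus 1 -> 54
# with do_extend=1 the end point is placed 54 frames (540 ms) before the current frame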
+
+    def DetectOneFrame(self, cur_frm_state: FrameState, cur_frm_idx: int, is_final_frame: bool) -> None:
+        tmp_cur_frm_state = FrameState.kFrameStateInvalid
+        if cur_frm_state == FrameState.kFrameStateSpeech:
+            if math.fabs(1.0) > self.vad_opts.fe_prior_thres:
+                tmp_cur_frm_state = FrameState.kFrameStateSpeech
+            else:
+                tmp_cur_frm_state = FrameState.kFrameStateSil
+        elif cur_frm_state == FrameState.kFrameStateSil:
+            tmp_cur_frm_state = FrameState.kFrameStateSil
+        state_change = self.windows_detector.DetectOneFrame(tmp_cur_frm_state, cur_frm_idx)
+        frm_shift_in_ms = self.vad_opts.frame_in_ms
+        if AudioChangeState.kChangeStateSil2Speech == state_change:
+            silence_frame_count = self.continous_silence_frame_count
+            self.continous_silence_frame_count = 0
+            self.pre_end_silence_detected = False
+            start_frame = 0
+            if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
+                start_frame = max(self.data_buf_start_frame, cur_frm_idx - self.LatencyFrmNumAtStartPoint())
+                self.OnVoiceStart(start_frame)
+                self.vad_state_machine = VadStateMachine.kVadInStateInSpeechSegment
+                for t in range(start_frame + 1, cur_frm_idx + 1):
+                    self.OnVoiceDetected(t)
+            elif self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
+                for t in range(self.latest_confirmed_speech_frame + 1, cur_frm_idx):
+                    self.OnVoiceDetected(t)
+                if cur_frm_idx - self.confirmed_start_frame + 1 > \
+                        self.vad_opts.max_single_segment_time / frm_shift_in_ms:
+                    self.OnVoiceEnd(cur_frm_idx, False, False)
+                    self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
+                elif not is_final_frame:
+                    self.OnVoiceDetected(cur_frm_idx)
+                else:
+                    self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx)
+            else:
+                pass
+        elif AudioChangeState.kChangeStateSpeech2Sil == state_change:
+            self.continous_silence_frame_count = 0
+            if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
+                pass
+            elif self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
+                if cur_frm_idx - self.confirmed_start_frame + 1 > \
+                        self.vad_opts.max_single_segment_time / frm_shift_in_ms:
+                    self.OnVoiceEnd(cur_frm_idx, False, False)
+                    self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
+                elif not is_final_frame:
+                    self.OnVoiceDetected(cur_frm_idx)
+                else:
+                    self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx)
+            else:
+                pass
+        elif AudioChangeState.kChangeStateSpeech2Speech == state_change:
+            self.continous_silence_frame_count = 0
+            if self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
+                if cur_frm_idx - self.confirmed_start_frame + 1 > \
+                        self.vad_opts.max_single_segment_time / frm_shift_in_ms:
+                    self.max_time_out = True
+                    self.OnVoiceEnd(cur_frm_idx, False, False)
+                    self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
+                elif not is_final_frame:
+                    self.OnVoiceDetected(cur_frm_idx)
+                else:
+                    self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx)
+            else:
+                pass
+        elif AudioChangeState.kChangeStateSil2Sil == state_change:
+            self.continous_silence_frame_count += 1
+            if self.vad_state_machine == VadStateMachine.kVadInStateStartPointNotDetected:
+                # silence timeout, return zero length decision
+                if ((self.vad_opts.detect_mode == VadDetectMode.kVadSingleUtteranceDetectMode.value) and (
+                        self.continous_silence_frame_count * frm_shift_in_ms > self.vad_opts.max_start_silence_time)) \
+                        or (is_final_frame and self.number_end_time_detected == 0):
+                    for t in range(self.lastest_confirmed_silence_frame + 1, cur_frm_idx):
+                        self.OnSilenceDetected(t)
+                    self.OnVoiceStart(0, True)
+                    self.OnVoiceEnd(0, True, False)
+                    self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
+                else:
+                    if cur_frm_idx >= self.LatencyFrmNumAtStartPoint():
+                        self.OnSilenceDetected(cur_frm_idx - self.LatencyFrmNumAtStartPoint())
+            elif self.vad_state_machine == VadStateMachine.kVadInStateInSpeechSegment:
+                if self.continous_silence_frame_count * frm_shift_in_ms >= self.max_end_sil_frame_cnt_thresh:
+                    lookback_frame = int(self.max_end_sil_frame_cnt_thresh / frm_shift_in_ms)
+                    if self.vad_opts.do_extend:
+                        lookback_frame -= int(self.vad_opts.lookahead_time_end_point / frm_shift_in_ms)
+                        lookback_frame -= 1
+                        lookback_frame = max(0, lookback_frame)
+                    self.OnVoiceEnd(cur_frm_idx - lookback_frame, False, False)
+                    self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
+                elif cur_frm_idx - self.confirmed_start_frame + 1 > \
+                        self.vad_opts.max_single_segment_time / frm_shift_in_ms:
+                    self.OnVoiceEnd(cur_frm_idx, False, False)
+                    self.vad_state_machine = VadStateMachine.kVadInStateEndPointDetected
+                elif self.vad_opts.do_extend and not is_final_frame:
+                    if self.continous_silence_frame_count <= int(
+                            self.vad_opts.lookahead_time_end_point /
frm_shift_in_ms): + self.OnVoiceDetected(cur_frm_idx) + else: + self.MaybeOnVoiceEndIfLastFrame(is_final_frame, cur_frm_idx) + else: + pass + + if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected and \ + self.vad_opts.detect_mode == VadDetectMode.kVadMutipleUtteranceDetectMode.value: + self.ResetDetection() + + + diff --git a/funasr/models/fsmn_vad_streaming/template.yaml b/funasr/models/fsmn_vad_streaming/template.yaml new file mode 100644 index 000000000..e8a3a4f30 --- /dev/null +++ b/funasr/models/fsmn_vad_streaming/template.yaml @@ -0,0 +1,62 @@ +# This is an example that demonstrates how to configure a model file. +# You can modify the configuration according to your own requirements. + +# to print the register_table: +# from funasr.register import tables +# tables.print() + +# network architecture +model: FsmnVADStreaming +model_conf: + sample_rate: 16000 + detect_mode: 1 + snr_mode: 0 + max_end_silence_time: 800 + max_start_silence_time: 3000 + do_start_point_detection: True + do_end_point_detection: True + window_size_ms: 200 + sil_to_speech_time_thres: 150 + speech_to_sil_time_thres: 150 + speech_2_noise_ratio: 1.0 + do_extend: 1 + lookback_time_start_point: 200 + lookahead_time_end_point: 100 + max_single_segment_time: 60000 + snr_thres: -100.0 + noise_frame_num_used_for_snr: 100 + decibel_thres: -100.0 + speech_noise_thres: 0.6 + fe_prior_thres: 0.0001 + silence_pdf_num: 1 + sil_pdf_ids: [0] + speech_noise_thresh_low: -0.1 + speech_noise_thresh_high: 0.3 + output_frame_probs: False + frame_in_ms: 10 + frame_length_ms: 25 + +encoder: FSMN +encoder_conf: + input_dim: 400 + input_affine_dim: 140 + fsmn_layers: 4 + linear_dim: 250 + proj_dim: 128 + lorder: 20 + rorder: 0 + lstride: 1 + rstride: 0 + output_affine_dim: 140 + output_dim: 248 + +frontend: WavFrontend +frontend_conf: + fs: 16000 + window: hamming + n_mels: 80 + frame_length: 25 + frame_shift: 10 + dither: 0.0 + lfr_m: 5 + lfr_n: 1 diff --git a/funasr/models/paraformer_streaming/model.py b/funasr/models/paraformer_streaming/model.py index 927b09199..fdc0c9312 100644 --- a/funasr/models/paraformer_streaming/model.py +++ b/funasr/models/paraformer_streaming/model.py @@ -519,16 +519,23 @@ class ParaformerStreaming(Paraformer): if len(cache) == 0: self.init_cache(cache, **kwargs) - _is_final = kwargs.get("is_final", False) + meta_data = {} chunk_size = kwargs.get("chunk_size", [0, 10, 5]) chunk_stride_samples = chunk_size[1] * 960 # 600ms time1 = time.perf_counter() - audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000), - data_type=kwargs.get("data_type", "sound"), - tokenizer=tokenizer) + cfg = {"is_final": kwargs.get("is_final", False)} + audio_sample_list = load_audio_text_image_video(data_in, + fs=frontend.fs, + audio_fs=kwargs.get("fs", 16000), + data_type=kwargs.get("data_type", "sound"), + tokenizer=tokenizer, + **cfg, + ) + _is_final = cfg["is_final"] # if data_in is a file or url, set is_final=True + time2 = time.perf_counter() meta_data["load_data"] = f"{time2 - time1:0.3f}" assert len(audio_sample_list) == 1, "batch_size must be set 1" diff --git a/funasr/models/paraformer_streaming/template.yaml b/funasr/models/paraformer_streaming/template.yaml new file mode 100644 index 000000000..d1300ac79 --- /dev/null +++ b/funasr/models/paraformer_streaming/template.yaml @@ -0,0 +1,143 @@ +# This is an example that demonstrates how to configure a model file. +# You can modify the configuration according to your own requirements. 
+
+# to print the register_table:
+# from funasr.register import tables
+# tables.print()
+
+# network architecture
+model: ParaformerStreaming
+model_conf:
+  ctc_weight: 0.0
+  lsm_weight: 0.1
+  length_normalized_loss: true
+  predictor_weight: 1.0
+  predictor_bias: 1
+  sampling_ratio: 0.75
+
+# encoder
+encoder: SANMEncoderChunkOpt
+encoder_conf:
+  output_size: 512
+  attention_heads: 4
+  linear_units: 2048
+  num_blocks: 50
+  dropout_rate: 0.1
+  positional_dropout_rate: 0.1
+  attention_dropout_rate: 0.1
+  input_layer: pe_online
+  pos_enc_class: SinusoidalPositionEncoder
+  normalize_before: true
+  kernel_size: 11
+  sanm_shfit: 0
+  selfattention_layer_type: sanm
+  chunk_size:
+  - 12
+  - 15
+  stride:
+  - 8
+  - 10
+  pad_left:
+  - 0
+  - 0
+  encoder_att_look_back_factor:
+  - 4
+  - 4
+  decoder_att_look_back_factor:
+  - 1
+  - 1
+
+# decoder
+decoder: ParaformerSANMDecoder
+decoder_conf:
+  attention_heads: 4
+  linear_units: 2048
+  num_blocks: 16
+  dropout_rate: 0.1
+  positional_dropout_rate: 0.1
+  self_attention_dropout_rate: 0.1
+  src_attention_dropout_rate: 0.1
+  att_layer_num: 16
+  kernel_size: 11
+  sanm_shfit: 5
+
+predictor: CifPredictorV2
+predictor_conf:
+  idim: 512
+  threshold: 1.0
+  l_order: 1
+  r_order: 1
+  tail_threshold: 0.45
+
+# frontend related
+frontend: WavFrontendOnline
+frontend_conf:
+  fs: 16000
+  window: hamming
+  n_mels: 80
+  frame_length: 25
+  frame_shift: 10
+  lfr_m: 7
+  lfr_n: 6
+
+specaug: SpecAugLFR
+specaug_conf:
+  apply_time_warp: false
+  time_warp_window: 5
+  time_warp_mode: bicubic
+  apply_freq_mask: true
+  freq_mask_width_range:
+  - 0
+  - 30
+  lfr_rate: 6
+  num_freq_mask: 1
+  apply_time_mask: true
+  time_mask_width_range:
+  - 0
+  - 12
+  num_time_mask: 1
+
+train_conf:
+  accum_grad: 1
+  grad_clip: 5
+  max_epoch: 150
+  val_scheduler_criterion:
+  - valid
+  - acc
+  best_model_criterion:
+  - - valid
+    - acc
+    - max
+  keep_nbest_models: 10
+  log_interval: 50
+
+optim: adam
+optim_conf:
+  lr: 0.0005
+scheduler: warmuplr
+scheduler_conf:
+  warmup_steps: 30000
+
+dataset: AudioDataset
+dataset_conf:
+  index_ds: IndexDSJsonl
+  batch_sampler: DynamicBatchLocalShuffleSampler
+  batch_type: example # example or length
+  batch_size: 1 # if batch_type is example, batch_size is the number of samples; if length, batch_size is source_token_len+target_token_len
+  max_token_length: 2048 # filter samples if source_token_len+target_token_len > max_token_length
+  buffer_size: 500
+  shuffle: True
+  num_workers: 0
+
+tokenizer: CharTokenizer
+tokenizer_conf:
+  unk_symbol: <unk>
+  split_with_space: true
+
+
+ctc_conf:
+  dropout_rate: 0.0
+  ctc_type: builtin
+  reduce: true
+  ignore_nan_grad: true
+normalize: null
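For reference, the streaming chunk arithmetic implied by this config: with frame_shift 10 ms and lfr_n 6, one encoder frame covers 60 ms, i.e. 960 samples at 16 kHz, which is where the chunk_size[1] * 960 stride in ParaformerStreaming.generate comes from. A worked check with chunk_size [0, 10, 5], the default used by the online demo:

samples_per_lfr_frame = 10 * 6 * 16                # frame_shift(ms) * lfr_n * 16 samples/ms = 960
chunk_stride_samples = 10 * samples_per_lfr_frame  # chunk_size[1] = 10 -> 9600 samples
chunk_ms = chunk_stride_samples / 16               # 600 ms per streaming chunk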
diff --git a/funasr/utils/load_utils.py b/funasr/utils/load_utils.py
index bb9cf01b9..638e0ac4f 100644
--- a/funasr/utils/load_utils.py
+++ b/funasr/utils/load_utils.py
@@ -16,7 +16,7 @@ except:
 
-def load_audio_text_image_video(data_or_path_or_list, fs: int = 16000, audio_fs: int = 16000, data_type="sound", tokenizer=None):
+def load_audio_text_image_video(data_or_path_or_list, fs: int = 16000, audio_fs: int = 16000, data_type="sound", tokenizer=None, **kwargs):
     if isinstance(data_or_path_or_list, (list, tuple)):
         if data_type is not None and isinstance(data_type, (list, tuple)):
@@ -26,20 +26,29 @@ def load_audio_text_image_video(data_or_path_or_list, fs: int = 16000, audio_fs:
             for j, (data_type_j, data_or_path_or_list_j) in enumerate(zip(data_type_i, data_or_path_or_list_i)):
-                data_or_path_or_list_j = load_audio_text_image_video(data_or_path_or_list_j, fs=fs, audio_fs=audio_fs, data_type=data_type_j, tokenizer=tokenizer)
+                data_or_path_or_list_j = load_audio_text_image_video(data_or_path_or_list_j, fs=fs, audio_fs=audio_fs, data_type=data_type_j, tokenizer=tokenizer, **kwargs)
                 data_or_path_or_list_ret[j].append(data_or_path_or_list_j)
             return data_or_path_or_list_ret
         else:
-            return [load_audio_text_image_video(audio, fs=fs, audio_fs=audio_fs, data_type=data_type) for audio in data_or_path_or_list]
-    if isinstance(data_or_path_or_list, str) and data_or_path_or_list.startswith('http'):
+            return [load_audio_text_image_video(audio, fs=fs, audio_fs=audio_fs, data_type=data_type, **kwargs) for audio in data_or_path_or_list]
+
+    if isinstance(data_or_path_or_list, str) and data_or_path_or_list.startswith('http'):  # download url to local file
         data_or_path_or_list = download_from_url(data_or_path_or_list)
-    if isinstance(data_or_path_or_list, str) and os.path.exists(data_or_path_or_list):
+
+    if isinstance(data_or_path_or_list, str) and os.path.exists(data_or_path_or_list):  # local file
         if data_type is None or data_type == "sound":
             data_or_path_or_list, audio_fs = torchaudio.load(data_or_path_or_list)
             data_or_path_or_list = data_or_path_or_list[0, :]
-        # elif data_type == "text" and tokenizer is not None:
-        #     data_or_path_or_list = tokenizer.encode(data_or_path_or_list)
+        elif data_type == "text" and tokenizer is not None:
+            data_or_path_or_list = tokenizer.encode(data_or_path_or_list)
+        elif data_type == "image":  # todo: not implemented yet
+            pass
+        elif data_type == "video":  # todo: not implemented yet
+            pass
+
+        # if data_in is a file or url, set is_final=True
+        kwargs["is_final"] = True
     elif isinstance(data_or_path_or_list, str) and data_type == "text" and tokenizer is not None:
         data_or_path_or_list = tokenizer.encode(data_or_path_or_list)
     elif isinstance(data_or_path_or_list, np.ndarray):  # audio sample point