From 31eed1834f9ff17d6246008f64d3e061f58ef80a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=87=8C=E5=8C=80?= Date: Mon, 27 Feb 2023 13:33:55 +0800 Subject: [PATCH] in_cache & support soundfile read --- .../bin/asr_inference_paraformer_vad_punc.py | 96 +------------------ funasr/bin/vad_inference.py | 26 +++-- funasr/models/e2e_vad.py | 34 ++++--- funasr/models/encoder/fsmn_encoder.py | 44 ++++----- 4 files changed, 54 insertions(+), 146 deletions(-) diff --git a/funasr/bin/asr_inference_paraformer_vad_punc.py b/funasr/bin/asr_inference_paraformer_vad_punc.py index 96f70eff7..13208778f 100644 --- a/funasr/bin/asr_inference_paraformer_vad_punc.py +++ b/funasr/bin/asr_inference_paraformer_vad_punc.py @@ -43,6 +43,7 @@ from funasr.utils.types import str_or_none from funasr.utils import asr_utils, wav_utils, postprocess_utils from funasr.models.frontend.wav_frontend import WavFrontend from funasr.tasks.vad import VADTask +from funasr.bin.vad_inference import Speech2VadSegment from funasr.utils.timestamp_tools import time_stamp_lfr6_pl from funasr.bin.punctuation_infer import Text2Punc from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer @@ -364,101 +365,6 @@ class Speech2Text: hotword_list = None return hotword_list -class Speech2VadSegment: - """Speech2VadSegment class - - Examples: - >>> import soundfile - >>> speech2segment = Speech2VadSegment("vad_config.yml", "vad.pt") - >>> audio, rate = soundfile.read("speech.wav") - >>> speech2segment(audio) - [[10, 230], [245, 450], ...] - - """ - - def __init__( - self, - vad_infer_config: Union[Path, str] = None, - vad_model_file: Union[Path, str] = None, - vad_cmvn_file: Union[Path, str] = None, - device: str = "cpu", - batch_size: int = 1, - dtype: str = "float32", - **kwargs, - ): - assert check_argument_types() - - # 1. Build vad model - vad_model, vad_infer_args = VADTask.build_model_from_file( - vad_infer_config, vad_model_file, device - ) - frontend = None - if vad_infer_args.frontend is not None: - frontend = WavFrontend(cmvn_file=vad_cmvn_file, **vad_infer_args.frontend_conf) - - # logging.info("vad_model: {}".format(vad_model)) - # logging.info("vad_infer_args: {}".format(vad_infer_args)) - vad_model.to(dtype=getattr(torch, dtype)).eval() - - self.vad_model = vad_model - self.vad_infer_args = vad_infer_args - self.device = device - self.dtype = dtype - self.frontend = frontend - self.batch_size = batch_size - - @torch.no_grad() - def __call__( - self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None - ) -> List[List[int]]: - """Inference - - Args: - speech: Input speech data - Returns: - text, token, token_int, hyp - - """ - assert check_argument_types() - - # Input as audio signal - if isinstance(speech, np.ndarray): - speech = torch.tensor(speech) - - if self.frontend is not None: - self.frontend.filter_length_max = math.inf - fbanks, fbanks_len = self.frontend.forward_fbank(speech, speech_lengths) - feats, feats_len = self.frontend.forward_lfr_cmvn(fbanks, fbanks_len) - fbanks = to_device(fbanks, device=self.device) - feats = to_device(feats, device=self.device) - feats_len = feats_len.int() - else: - raise Exception("Need to extract feats first, please configure frontend configuration") - - # b. 
Forward Encoder streaming - t_offset = 0 - step = min(feats_len, 6000) - segments = [[]] * self.batch_size - for t_offset in range(0, feats_len, min(step, feats_len - t_offset)): - if t_offset + step >= feats_len - 1: - step = feats_len - t_offset - is_final_send = True - else: - is_final_send = False - batch = { - "feats": feats[:, t_offset:t_offset + step, :], - "waveform": speech[:, t_offset * 160:min(speech.shape[-1], (t_offset + step - 1) * 160 + 400)], - "is_final_send": is_final_send - } - # a. To device - batch = to_device(batch, device=self.device) - segments_part = self.vad_model(**batch) - if segments_part: - for batch_num in range(0, self.batch_size): - segments[batch_num] += segments_part[batch_num] - - return fbanks, segments - def inference( maxlenratio: float, diff --git a/funasr/bin/vad_inference.py b/funasr/bin/vad_inference.py index 607f131dd..258b38b33 100644 --- a/funasr/bin/vad_inference.py +++ b/funasr/bin/vad_inference.py @@ -11,6 +11,7 @@ from typing import Tuple from typing import Union from typing import Dict +import math import numpy as np import torch from typeguard import check_argument_types @@ -86,7 +87,7 @@ class Speech2VadSegment: @torch.no_grad() def __call__( self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None - ) -> List[List[int]]: + ) -> Tuple[List[List[int]], Dict[str, torch.Tensor]]: """Inference Args: @@ -102,7 +103,10 @@ class Speech2VadSegment: speech = torch.tensor(speech) if self.frontend is not None: - feats, feats_len = self.frontend.forward(speech, speech_lengths) + self.frontend.filter_length_max = math.inf + fbanks, fbanks_len = self.frontend.forward_fbank(speech, speech_lengths) + feats, feats_len = self.frontend.forward_lfr_cmvn(fbanks, fbanks_len) + fbanks = to_device(fbanks, device=self.device) feats = to_device(feats, device=self.device) feats_len = feats_len.int() else: @@ -110,18 +114,18 @@ class Speech2VadSegment: # b. Forward Encoder streaming t_offset = 0 - step = min(feats_len, 6000) + step = min(feats_len.max(), 6000) segments = [[]] * self.batch_size for t_offset in range(0, feats_len, min(step, feats_len - t_offset)): if t_offset + step >= feats_len - 1: step = feats_len - t_offset - is_final_send = True + is_final = True else: - is_final_send = False + is_final = False batch = { "feats": feats[:, t_offset:t_offset + step, :], "waveform": speech[:, t_offset * 160:min(speech.shape[-1], (t_offset + step - 1) * 160 + 400)], - "is_final_send": is_final_send + "is_final": is_final } # a. To device batch = to_device(batch, device=self.device) @@ -129,7 +133,7 @@ class Speech2VadSegment: if segments_part: for batch_num in range(0, self.batch_size): segments[batch_num] += segments_part[batch_num] - return segments + return fbanks, segments def inference( @@ -219,9 +223,13 @@ def inference_modelscope( raw_inputs: Union[np.ndarray, torch.Tensor] = None, output_dir_v2: Optional[str] = None, fs: dict = None, - param_dict: dict = None, + param_dict: dict = None ): # 3. 
Build data-iterator + if data_path_and_name_and_type is None and raw_inputs is not None: + if isinstance(raw_inputs, torch.Tensor): + raw_inputs = raw_inputs.numpy() + data_path_and_name_and_type = [raw_inputs, "speech", "waveform"] loader = VADTask.build_streaming_iterator( data_path_and_name_and_type, dtype=dtype, @@ -254,7 +262,7 @@ def inference_modelscope( assert len(keys) == _bs, f"{len(keys)} != {_bs}" # do vad segment - results = speech2vadsegment(**batch) + _, results = speech2vadsegment(**batch) for i, _ in enumerate(keys): results[i] = json.dumps(results[i]) item = {'key': keys[i], 'value': results[i]} diff --git a/funasr/models/e2e_vad.py b/funasr/models/e2e_vad.py index b64c677f3..c21be1b95 100755 --- a/funasr/models/e2e_vad.py +++ b/funasr/models/e2e_vad.py @@ -201,7 +201,7 @@ class E2EVadModel(nn.Module): self.vad_opts.frame_in_ms) self.encoder = encoder # init variables - self.is_final_send = False + self.is_final = False self.data_buf_start_frame = 0 self.frm_cnt = 0 self.latest_confirmed_speech_frame = 0 @@ -230,8 +230,7 @@ class E2EVadModel(nn.Module): self.ResetDetection() def AllResetDetection(self): - self.encoder.cache_reset() # reset the in_cache in self.encoder for next query or next long sentence - self.is_final_send = False + self.is_final = False self.data_buf_start_frame = 0 self.frm_cnt = 0 self.latest_confirmed_speech_frame = 0 @@ -283,8 +282,8 @@ class E2EVadModel(nn.Module): 10 * math.log10((self.waveform[0][offset: offset + frame_sample_length]).square().sum() + \ 0.000001)) - def ComputeScores(self, feats: torch.Tensor) -> None: - scores = self.encoder(feats) # return B * T * D + def ComputeScores(self, feats: torch.Tensor, in_cache: Dict[str, torch.Tensor]) -> None: + scores = self.encoder(feats, in_cache) # return B * T * D assert scores.shape[1] == feats.shape[1], "The shape between feats and scores does not match" self.vad_opts.nn_eval_block_size = scores.shape[1] self.frm_cnt += scores.shape[1] # count total frames @@ -306,7 +305,7 @@ class E2EVadModel(nn.Module): expected_sample_number = int(frm_cnt * self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000) if last_frm_is_end_point: extra_sample = max(0, int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000 - \ - self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000)) + self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000)) expected_sample_number += int(extra_sample) if end_point_is_sent_end: expected_sample_number = max(expected_sample_number, len(self.data_buf)) @@ -443,11 +442,13 @@ class E2EVadModel(nn.Module): return frame_state - def forward(self, feats: torch.Tensor, waveform: torch.tensor, is_final_send: bool = False) -> List[List[List[int]]]: + def forward(self, feats: torch.Tensor, waveform: torch.tensor, in_cache: Dict[str, torch.Tensor] = dict(), + is_final: bool = False + ) -> List[List[List[int]]]: self.waveform = waveform # compute decibel for each frame self.ComputeDecibel() - self.ComputeScores(feats) - if not is_final_send: + self.ComputeScores(feats, in_cache) + if not is_final: self.DetectCommonFrames() else: self.DetectLastFrames() @@ -456,15 +457,18 @@ class E2EVadModel(nn.Module): segment_batch = [] if len(self.output_data_buf) > 0: for i in range(self.output_data_buf_offset, len(self.output_data_buf)): - if self.output_data_buf[i].contain_seg_start_point and self.output_data_buf[ + if not self.output_data_buf[i].contain_seg_start_point or not self.output_data_buf[ i].contain_seg_end_point: - segment = [self.output_data_buf[i].start_ms, 
self.output_data_buf[i].end_ms] - segment_batch.append(segment) - self.output_data_buf_offset += 1 # need update this parameter + continue + segment = [self.output_data_buf[i].start_ms, self.output_data_buf[i].end_ms] + segment_batch.append(segment) + self.output_data_buf_offset += 1 # need update this parameter if segment_batch: segments.append(segment_batch) - if is_final_send: - self.AllResetDetection() + if is_final: + # reset class variables and clear the dict for the next query + self.AllResetDetection() + in_cache.clear() return segments def DetectCommonFrames(self) -> int: diff --git a/funasr/models/encoder/fsmn_encoder.py b/funasr/models/encoder/fsmn_encoder.py index 54a113ddd..c749dc438 100755 --- a/funasr/models/encoder/fsmn_encoder.py +++ b/funasr/models/encoder/fsmn_encoder.py @@ -79,14 +79,12 @@ class FSMNBlock(nn.Module): else: self.conv_right = None - def forward(self, input: torch.Tensor, in_cache=None): + def forward(self, input: torch.Tensor, cache: torch.Tensor): x = torch.unsqueeze(input, 1) x_per = x.permute(0, 3, 2, 1) # B D T C - if in_cache is None: # offline - y_left = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride, 0]) - else: - y_left = torch.cat((in_cache, x_per), dim=2) - in_cache = y_left[:, :, -(self.lorder - 1) * self.lstride:, :] + + y_left = torch.cat((cache, x_per), dim=2) + cache = y_left[:, :, -(self.lorder - 1) * self.lstride:, :] y_left = self.conv_left(y_left) out = x_per + y_left @@ -100,7 +98,7 @@ class FSMNBlock(nn.Module): out_per = out.permute(0, 3, 2, 1) output = out_per.squeeze(1) - return output, in_cache + return output, cache class BasicBlock(nn.Sequential): @@ -124,28 +122,25 @@ class BasicBlock(nn.Sequential): self.affine = AffineTransform(proj_dim, linear_dim) self.relu = RectifiedLinear(linear_dim, linear_dim) - def forward(self, input: torch.Tensor, in_cache=None): + def forward(self, input: torch.Tensor, in_cache: Dict[str, torch.Tensor]): x1 = self.linear(input) # B T D - if in_cache is not None: # Dict[str, tensor.Tensor] - cache_layer_name = 'cache_layer_{}'.format(self.stack_layer) - if cache_layer_name not in in_cache: - in_cache[cache_layer_name] = torch.zeros(x1.shape[0], x1.shape[-1], (self.lorder - 1) * self.lstride, 1) - x2, in_cache[cache_layer_name] = self.fsmn_block(x1, in_cache[cache_layer_name]) - else: - x2, _ = self.fsmn_block(x1) + cache_layer_name = 'cache_layer_{}'.format(self.stack_layer) + if cache_layer_name not in in_cache: + in_cache[cache_layer_name] = torch.zeros(x1.shape[0], x1.shape[-1], (self.lorder - 1) * self.lstride, 1) + x2, in_cache[cache_layer_name] = self.fsmn_block(x1, in_cache[cache_layer_name]) x3 = self.affine(x2) x4 = self.relu(x3) - return x4, in_cache + return x4 class FsmnStack(nn.Sequential): def __init__(self, *args): super(FsmnStack, self).__init__(*args) - def forward(self, input: torch.Tensor, in_cache=None): + def forward(self, input: torch.Tensor, in_cache: Dict[str, torch.Tensor]): x = input for module in self._modules.values(): - x, in_cache = module(x, in_cache) + x = module(x, in_cache) return x @@ -174,8 +169,7 @@ class FSMN(nn.Module): lstride: int, rstride: int, output_affine_dim: int, - output_dim: int, - streaming=False + output_dim: int ): super(FSMN, self).__init__() @@ -186,8 +180,6 @@ class FSMN(nn.Module): self.proj_dim = proj_dim self.output_affine_dim = output_affine_dim self.output_dim = output_dim - self.in_cache_original = dict() if streaming else None - self.in_cache = copy.deepcopy(self.in_cache_original) self.in_linear1 = AffineTransform(input_dim, 
input_affine_dim) self.in_linear2 = AffineTransform(input_affine_dim, linear_dim) @@ -201,12 +193,10 @@ class FSMN(nn.Module): def fuse_modules(self): pass - def cache_reset(self): - self.in_cache = copy.deepcopy(self.in_cache_original) - def forward( self, input: torch.Tensor, + in_cache: Dict[str, torch.Tensor] ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: """ Args: @@ -218,7 +208,7 @@ class FSMN(nn.Module): x1 = self.in_linear1(input) x2 = self.in_linear2(x1) x3 = self.relu(x2) - x4 = self.fsmn(x3, self.in_cache) # if in_cache is not None, self.fsmn is streaming's format, it will update automatically in self.fsmn + x4 = self.fsmn(x3, in_cache) # self.in_cache will update automatically in self.fsmn x5 = self.out_linear1(x4) x6 = self.out_linear2(x5) x7 = self.softmax(x6) @@ -307,4 +297,4 @@ if __name__ == '__main__': print('input shape: {}'.format(x.shape)) print('output shape: {}'.format(y.shape)) - print(fsmn.to_kaldi_net()) + print(fsmn.to_kaldi_net()) \ No newline at end of file
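
Below is a minimal, self-contained sketch of the caching pattern this patch introduces: the FSMN encoder no longer owns its left-context history; instead the caller threads an in_cache dict (one zero-initialized entry per layer, keyed cache_layer_{i}) through successive chunk-wise forward calls and clears it once the final chunk has been sent. TinyFsmnBlock and TinyStreamingEncoder are hypothetical stand-ins, not the actual FunASR FSMNBlock/FSMN modules, and the sketch assumes a unit left stride (lstride = 1).

import torch


class TinyFsmnBlock(torch.nn.Module):
    # Stand-in for FSMNBlock: a depthwise left-context convolution whose history
    # lives in an externally owned cache tensor of shape B x D x (lorder - 1) x 1.
    def __init__(self, dim: int, lorder: int):
        super().__init__()
        self.lorder = lorder
        self.conv_left = torch.nn.Conv2d(dim, dim, (lorder, 1), groups=dim, bias=False)

    def forward(self, x: torch.Tensor, cache: torch.Tensor):
        # x: B x T x D  ->  B x D x T x 1, with the cached left context prepended
        x_per = x.unsqueeze(1).permute(0, 3, 2, 1)
        y = torch.cat((cache, x_per), dim=2)
        new_cache = y[:, :, -(self.lorder - 1):, :]  # newest left context for the next chunk
        out = (x_per + self.conv_left(y)).permute(0, 3, 2, 1).squeeze(1)
        return out, new_cache


class TinyStreamingEncoder(torch.nn.Module):
    # Stand-in for FSMN: each layer reads and writes its own slot in the in_cache dict.
    def __init__(self, dim: int = 8, layers: int = 2, lorder: int = 3):
        super().__init__()
        self.blocks = torch.nn.ModuleList([TinyFsmnBlock(dim, lorder) for _ in range(layers)])
        self.dim = dim
        self.lorder = lorder

    def forward(self, feats: torch.Tensor, in_cache: dict) -> torch.Tensor:
        x = feats
        for i, block in enumerate(self.blocks):
            key = "cache_layer_{}".format(i)
            if key not in in_cache:  # first chunk of a stream: zero history
                in_cache[key] = torch.zeros(x.shape[0], self.dim, self.lorder - 1, 1)
            x, in_cache[key] = block(x, in_cache[key])
        return x


if __name__ == "__main__":
    torch.manual_seed(0)
    encoder = TinyStreamingEncoder().eval()
    feats = torch.randn(1, 20, 8)  # B x T x D
    in_cache = {}

    # Chunk-wise (streaming) pass: the caller owns in_cache and clears it after the
    # final chunk, mirroring the AllResetDetection() / in_cache.clear() step above.
    outputs = []
    step = 5
    with torch.no_grad():
        for t in range(0, feats.shape[1], step):
            is_final = t + step >= feats.shape[1]
            outputs.append(encoder(feats[:, t:t + step, :], in_cache))
            if is_final:
                in_cache.clear()

        # With a correctly carried cache, the chunked result matches a single
        # full-utterance pass that starts from an empty cache.
        full = encoder(feats, {})
    print(torch.allclose(torch.cat(outputs, dim=1), full, atol=1e-5))  # expected: True

Making the caller own the cache keeps the encoder stateless between utterances, which is why the patch can drop cache_reset() from FSMN and simply clear the dict in E2EVadModel.forward when is_final is True.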