diff --git a/funasr/bin/vad_inference.py b/funasr/bin/vad_inference.py index 9f1d0f310..0d9659401 100644 --- a/funasr/bin/vad_inference.py +++ b/funasr/bin/vad_inference.py @@ -1,6 +1,7 @@ import argparse import logging import sys +import json from pathlib import Path from typing import Any from typing import List @@ -105,19 +106,34 @@ class Speech2VadSegment: feats_len = feats_len.int() else: raise Exception("Need to extract feats first, please configure frontend configuration") - batch = {"feats": feats, "feats_lengths": feats_len, "waveform": speech} + # batch = {"feats": feats, "waveform": speech, "is_final_send": True} + # segments = self.vad_model(**batch) - # a. To device - batch = to_device(batch, device=self.device) - - # b. Forward Encoder - segments = self.vad_model(**batch) + # b. Forward Encoder sreaming + segments = [] + step = 6000 + t_offset = 0 + for t_offset in range(0, feats_len, min(step, feats_len - t_offset)): + if t_offset + step >= feats_len - 1: + step = feats_len - t_offset + is_final_send = True + else: + is_final_send = False + batch = { + "feats": feats[:, t_offset:t_offset + step, :], + "waveform": speech[:, t_offset * 160:min(speech.shape[-1], (t_offset + step - 1) * 160 + 400)], + "is_final_send": is_final_send + } + # a. To device + batch = to_device(batch, device=self.device) + segments_part = self.vad_model(**batch) + if segments_part: + segments += segments_part + #print(segments) return segments - - def inference( batch_size: int, ngpu: int, @@ -152,11 +168,12 @@ def inference( ) return inference_pipeline(data_path_and_name_and_type, raw_inputs) + def inference_modelscope( batch_size: int, ngpu: int, log_level: Union[int, str], - #data_path_and_name_and_type, + # data_path_and_name_and_type, vad_infer_config: Optional[str], vad_model_file: Optional[str], vad_cmvn_file: Optional[str] = None, @@ -167,7 +184,6 @@ def inference_modelscope( dtype: str = "float32", seed: int = 0, num_workers: int = 1, - param_dict: dict = None, **kwargs, ): assert check_argument_types() @@ -201,11 +217,11 @@ def inference_modelscope( speech2vadsegment = Speech2VadSegment(**speech2vadsegment_kwargs) def _forward( - data_path_and_name_and_type, - raw_inputs: Union[np.ndarray, torch.Tensor] = None, - output_dir_v2: Optional[str] = None, - fs: dict = None, - param_dict: dict = None, + data_path_and_name_and_type, + raw_inputs: Union[np.ndarray, torch.Tensor] = None, + output_dir_v2: Optional[str] = None, + fs: dict = None, + param_dict: dict = None, ): # 3. Build data-iterator loader = VADTask.build_streaming_iterator( @@ -243,9 +259,11 @@ def inference_modelscope( # do vad segment results = speech2vadsegment(**batch) for i, _ in enumerate(keys): + results[i] = json.dumps(results[i]) item = {'key': keys[i], 'value': results[i]} vad_results.append(item) if writer is not None: + results[i] = json.loads(results[i]) ibest_writer["text"][keys[i]] = "{}".format(results[i]) return vad_results diff --git a/funasr/bin/vad_inference_launch.py b/funasr/bin/vad_inference_launch.py index 54bf31984..42c5c1e12 100644 --- a/funasr/bin/vad_inference_launch.py +++ b/funasr/bin/vad_inference_launch.py @@ -107,14 +107,16 @@ def get_parser(): def inference_launch(mode, **kwargs): - if mode == "vad": + if mode == "offline": from funasr.bin.vad_inference import inference_modelscope return inference_modelscope(**kwargs) + elif mode == "online": + from funasr.bin.vad_inference_online import inference_modelscope + return inference_modelscope(**kwargs) else: logging.info("Unknown decoding mode: {}".format(mode)) return None - def main(cmd=None): print(get_commandline_args(), file=sys.stderr) parser = get_parser() diff --git a/funasr/models/e2e_vad.py b/funasr/models/e2e_vad.py index 98504d6b4..8afc8db6d 100755 --- a/funasr/models/e2e_vad.py +++ b/funasr/models/e2e_vad.py @@ -5,7 +5,6 @@ import torch from torch import nn import math from funasr.models.encoder.fsmn_encoder import FSMN -# from checkpoint import load_checkpoint class VadStateMachine(Enum): @@ -136,7 +135,7 @@ class WindowDetector(object): self.win_size_frame = int(window_size_ms / frame_size_ms) self.win_sum = 0 - self.win_state = [0 for i in range(0, self.win_size_frame)] # 初始化窗 + self.win_state = [0] * self.win_size_frame # 初始化窗 self.cur_win_pos = 0 self.pre_frame_state = FrameState.kFrameStateSil @@ -151,7 +150,7 @@ class WindowDetector(object): def Reset(self) -> None: self.cur_win_pos = 0 self.win_sum = 0 - self.win_state = [0 for i in range(0, self.win_size_frame)] + self.win_state = [0] * self.win_size_frame self.pre_frame_state = FrameState.kFrameStateSil self.cur_frame_state = FrameState.kFrameStateSil self.voice_last_frame_count = 0 @@ -192,8 +191,8 @@ class WindowDetector(object): return int(self.frame_size_ms) -class E2EVadModel(torch.nn.Module): - def __init__(self, encoder: FSMN, vad_post_args: Dict[str, Any]): +class E2EVadModel(nn.Module): + def __init__(self, encoder: FSMN, vad_post_args: Dict[str, Any], streaming=False): super(E2EVadModel, self).__init__() self.vad_opts = VADXOptions(**vad_post_args) self.windows_detector = WindowDetector(self.vad_opts.window_size_ms, @@ -212,13 +211,13 @@ class E2EVadModel(torch.nn.Module): self.confirmed_start_frame = -1 self.confirmed_end_frame = -1 self.number_end_time_detected = 0 - self.is_callback_with_sign = False self.sil_frame = 0 self.sil_pdf_ids = self.vad_opts.sil_pdf_ids self.noise_average_decibel = -100.0 self.pre_end_silence_detected = False self.output_data_buf = [] + self.output_data_buf_offset = 0 self.frame_probs = [] self.max_end_sil_frame_cnt_thresh = self.vad_opts.max_end_silence_time - self.vad_opts.speech_to_sil_time_thres self.speech_noise_thres = self.vad_opts.speech_noise_thres @@ -226,10 +225,13 @@ class E2EVadModel(torch.nn.Module): self.max_time_out = False self.decibel = [] self.data_buf = None + self.data_buf_all = None self.waveform = None + self.streaming = streaming self.ResetDetection() def AllResetDetection(self): + self.encoder.cache_reset() # reset the in_cache in self.encoder for next query or next long sentence self.is_final_send = False self.data_buf_start_frame = 0 self.frm_cnt = 0 @@ -240,13 +242,13 @@ class E2EVadModel(torch.nn.Module): self.confirmed_start_frame = -1 self.confirmed_end_frame = -1 self.number_end_time_detected = 0 - self.is_callback_with_sign = False self.sil_frame = 0 self.sil_pdf_ids = self.vad_opts.sil_pdf_ids self.noise_average_decibel = -100.0 self.pre_end_silence_detected = False self.output_data_buf = [] + self.output_data_buf_offset = 0 self.frame_probs = [] self.max_end_sil_frame_cnt_thresh = self.vad_opts.max_end_silence_time - self.vad_opts.speech_to_sil_time_thres self.speech_noise_thres = self.vad_opts.speech_noise_thres @@ -254,6 +256,7 @@ class E2EVadModel(torch.nn.Module): self.max_time_out = False self.decibel = [] self.data_buf = None + self.data_buf_all = None self.waveform = None self.ResetDetection() @@ -271,26 +274,32 @@ class E2EVadModel(torch.nn.Module): def ComputeDecibel(self) -> None: frame_sample_length = int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000) frame_shift_length = int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000) - self.data_buf = self.waveform[0] # 指向self.waveform[0] + if self.data_buf_all is None: + self.data_buf_all = self.waveform[0] # self.data_buf is pointed to self.waveform[0] + self.data_buf = self.data_buf_all + else: + self.data_buf_all = torch.cat((self.data_buf_all, self.waveform[0])) for offset in range(0, self.waveform.shape[1] - frame_sample_length + 1, frame_shift_length): self.decibel.append( 10 * math.log10((self.waveform[0][offset: offset + frame_sample_length]).square().sum() + \ 0.000001)) - def ComputeScores(self, feats: torch.Tensor, feats_lengths: int) -> None: - self.scores = self.encoder(feats) # return B * T * D - self.frm_cnt = feats_lengths # frame - # return self.scores + def ComputeScores(self, feats: torch.Tensor) -> None: + scores = self.encoder(feats) # return B * T * D + assert scores.shape[1] == feats.shape[1], "The shape between feats and scores does not match" + self.vad_opts.nn_eval_block_size = scores.shape[1] + self.frm_cnt += scores.shape[1] # count total frames + if self.scores is None: + self.scores = scores # the first calculation + else: + self.scores = torch.cat((self.scores, scores), dim=1) def PopDataBufTillFrame(self, frame_idx: int) -> None: # need check again while self.data_buf_start_frame < frame_idx: if len(self.data_buf) >= int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000): self.data_buf_start_frame += 1 - self.data_buf = self.waveform[0][self.data_buf_start_frame * int( + self.data_buf = self.data_buf_all[self.data_buf_start_frame * int( self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000):] - # for i in range(0, int(self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000)): - # self.data_buf.popleft() - # self.data_buf_start_frame += 1 def PopDataToOutputBuf(self, start_frm: int, frm_cnt: int, first_frm_is_start_point: bool, last_frm_is_end_point: bool, end_point_is_sent_end: bool) -> None: @@ -301,8 +310,9 @@ class E2EVadModel(torch.nn.Module): self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000)) expected_sample_number += int(extra_sample) if end_point_is_sent_end: - # expected_sample_number = max(expected_sample_number, len(self.data_buf)) - pass + expected_sample_number = max(expected_sample_number, len(self.data_buf)) + if len(self.data_buf) < expected_sample_number: + print('error in calling pop data_buf\n') if len(self.output_data_buf) == 0 or first_frm_is_start_point: self.output_data_buf.append(E2EVadSpeechBufWithDoa()) @@ -312,15 +322,18 @@ class E2EVadModel(torch.nn.Module): self.output_data_buf[-1].doa = 0 cur_seg = self.output_data_buf[-1] if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms: - print('warning') + print('warning\n') out_pos = len(cur_seg.buffer) # cur_seg.buff现在没做任何操作 data_to_pop = 0 if end_point_is_sent_end: data_to_pop = expected_sample_number else: data_to_pop = int(frm_cnt * self.vad_opts.frame_in_ms * self.vad_opts.sample_rate / 1000) - # if data_to_pop > len(self.data_buf_) - # pass + if data_to_pop > len(self.data_buf): + print('VAD data_to_pop is bigger than self.data_buf.size()!!!\n') + data_to_pop = len(self.data_buf) + expected_sample_number = len(self.data_buf) + cur_seg.doa = 0 for sample_cpy_out in range(0, data_to_pop): # cur_seg.buffer[out_pos ++] = data_buf_.back(); @@ -329,7 +342,7 @@ class E2EVadModel(torch.nn.Module): # cur_seg.buffer[out_pos++] = data_buf_.back() out_pos += 1 if cur_seg.end_ms != start_frm * self.vad_opts.frame_in_ms: - print('warning') + print('Something wrong with the VAD algorithm\n') self.data_buf_start_frame += frm_cnt cur_seg.end_ms = (start_frm + frm_cnt) * self.vad_opts.frame_in_ms if first_frm_is_start_point: @@ -346,14 +359,13 @@ class E2EVadModel(torch.nn.Module): def OnVoiceDetected(self, valid_frame: int) -> None: self.latest_confirmed_speech_frame = valid_frame - if True: # is_new_api_enable_ = True - self.PopDataToOutputBuf(valid_frame, 1, False, False, False) + self.PopDataToOutputBuf(valid_frame, 1, False, False, False) def OnVoiceStart(self, start_frame: int, fake_result: bool = False) -> None: if self.vad_opts.do_start_point_detection: pass if self.confirmed_start_frame != -1: - print('warning') + print('not reset vad properly\n') else: self.confirmed_start_frame = start_frame @@ -366,7 +378,7 @@ class E2EVadModel(torch.nn.Module): if self.vad_opts.do_end_point_detection: pass if self.confirmed_end_frame != -1: - print('warning') + print('not reset vad properly\n') else: self.confirmed_end_frame = end_frame if not fake_result: @@ -406,7 +418,6 @@ class E2EVadModel(torch.nn.Module): sil_pdf_scores = [self.scores[0][t][sil_pdf_id] for sil_pdf_id in self.sil_pdf_ids] sum_score = sum(sil_pdf_scores) noise_prob = math.log(sum_score) * self.vad_opts.speech_2_noise_ratio - # total_score = sum(self.scores[0][t][:]) total_score = 1.0 sum_score = total_score - sum_score speech_prob = math.log(sum_score) @@ -433,23 +444,57 @@ class E2EVadModel(torch.nn.Module): return frame_state - def forward(self, feats: torch.Tensor, feats_lengths: int, waveform: torch.tensor) -> List[List[List[int]]]: - self.AllResetDetection() + def forward(self, feats: torch.Tensor, waveform: torch.tensor, is_final_send: bool = False) -> List[List[List[int]]]: self.waveform = waveform # compute decibel for each frame self.ComputeDecibel() - self.ComputeScores(feats, feats_lengths) - assert len(self.decibel) == len(self.scores[0]) # 保证帧数一致 - self.DetectLastFrames() + self.ComputeScores(feats) + if not is_final_send: + self.DetectCommonFrames() + else: + if self.streaming: + self.DetectLastFrames() + else: + self.AllResetDetection() + self.DetectAllFrames() # offline decode and is_final_send == True segments = [] for batch_num in range(0, feats.shape[0]): # only support batch_size = 1 now segment_batch = [] - for i in range(0, len(self.output_data_buf)): - segment = [self.output_data_buf[i].start_ms, self.output_data_buf[i].end_ms] - segment_batch.append(segment) - segments.append(segment_batch) + if len(self.output_data_buf) > 0: + for i in range(self.output_data_buf_offset, len(self.output_data_buf)): + if self.output_data_buf[i].contain_seg_start_point and self.output_data_buf[ + i].contain_seg_end_point: + segment = [self.output_data_buf[i].start_ms, self.output_data_buf[i].end_ms] + segment_batch.append(segment) + self.output_data_buf_offset += 1 # need update this parameter + if segment_batch: + segments.append(segment_batch) + return segments + def DetectCommonFrames(self) -> int: + if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected: + return 0 + for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1): + frame_state = FrameState.kFrameStateInvalid + frame_state = self.GetFrameState(self.frm_cnt - 1 - i) + self.DetectOneFrame(frame_state, self.frm_cnt - 1 - i, False) + + return 0 + def DetectLastFrames(self) -> int: + if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected: + return 0 + for i in range(self.vad_opts.nn_eval_block_size - 1, -1, -1): + frame_state = FrameState.kFrameStateInvalid + frame_state = self.GetFrameState(self.frm_cnt - 1 - i) + if i != 0: + self.DetectOneFrame(frame_state, self.frm_cnt - 1 - i, False) + else: + self.DetectOneFrame(frame_state, self.frm_cnt - 1, True) + + return 0 + + def DetectAllFrames(self) -> int: if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected: return 0 if self.vad_opts.nn_eval_block_size != self.vad_opts.dcd_block_size: diff --git a/funasr/models/encoder/fsmn_encoder.py b/funasr/models/encoder/fsmn_encoder.py index 643cefc54..54a113ddd 100755 --- a/funasr/models/encoder/fsmn_encoder.py +++ b/funasr/models/encoder/fsmn_encoder.py @@ -1,57 +1,52 @@ +from typing import Tuple, Dict +import copy + import numpy as np import torch import torch.nn as nn import torch.nn.functional as F -from typing import Tuple - - class LinearTransform(nn.Module): - def __init__(self, input_dim, output_dim, quantize=0): + def __init__(self, input_dim, output_dim): super(LinearTransform, self).__init__() self.input_dim = input_dim self.output_dim = output_dim self.linear = nn.Linear(input_dim, output_dim, bias=False) - self.quantize = quantize - self.quant = torch.quantization.QuantStub() - self.dequant = torch.quantization.DeQuantStub() def forward(self, input): - if self.quantize: - output = self.quant(input) - else: - output = input - output = self.linear(output) - if self.quantize: - output = self.dequant(output) + output = self.linear(input) return output class AffineTransform(nn.Module): - def __init__(self, input_dim, output_dim, quantize=0): + def __init__(self, input_dim, output_dim): super(AffineTransform, self).__init__() self.input_dim = input_dim self.output_dim = output_dim - self.quantize = quantize self.linear = nn.Linear(input_dim, output_dim) - self.quant = torch.quantization.QuantStub() - self.dequant = torch.quantization.DeQuantStub() def forward(self, input): - if self.quantize: - output = self.quant(input) - else: - output = input - output = self.linear(output) - if self.quantize: - output = self.dequant(output) + output = self.linear(input) return output +class RectifiedLinear(nn.Module): + + def __init__(self, input_dim, output_dim): + super(RectifiedLinear, self).__init__() + self.dim = input_dim + self.relu = nn.ReLU() + self.dropout = nn.Dropout(0.1) + + def forward(self, input): + out = self.relu(input) + return out + + class FSMNBlock(nn.Module): def __init__( @@ -62,7 +57,6 @@ class FSMNBlock(nn.Module): rorder=None, lstride=1, rstride=1, - quantize=0 ): super(FSMNBlock, self).__init__() @@ -84,71 +78,75 @@ class FSMNBlock(nn.Module): self.dim, self.dim, [rorder, 1], dilation=[rstride, 1], groups=self.dim, bias=False) else: self.conv_right = None - self.quantize = quantize - self.quant = torch.quantization.QuantStub() - self.dequant = torch.quantization.DeQuantStub() - def forward(self, input): + def forward(self, input: torch.Tensor, in_cache=None): x = torch.unsqueeze(input, 1) - x_per = x.permute(0, 3, 2, 1) - - y_left = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride, 0]) - if self.quantize: - y_left = self.quant(y_left) + x_per = x.permute(0, 3, 2, 1) # B D T C + if in_cache is None: # offline + y_left = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride, 0]) + else: + y_left = torch.cat((in_cache, x_per), dim=2) + in_cache = y_left[:, :, -(self.lorder - 1) * self.lstride:, :] y_left = self.conv_left(y_left) - if self.quantize: - y_left = self.dequant(y_left) out = x_per + y_left if self.conv_right is not None: + # maybe need to check y_right = F.pad(x_per, [0, 0, 0, self.rorder * self.rstride]) y_right = y_right[:, :, self.rstride:, :] - if self.quantize: - y_right = self.quant(y_right) y_right = self.conv_right(y_right) - if self.quantize: - y_right = self.dequant(y_right) out += y_right out_per = out.permute(0, 3, 2, 1) output = out_per.squeeze(1) - return output + return output, in_cache -class RectifiedLinear(nn.Module): +class BasicBlock(nn.Sequential): + def __init__(self, + linear_dim: int, + proj_dim: int, + lorder: int, + rorder: int, + lstride: int, + rstride: int, + stack_layer: int + ): + super(BasicBlock, self).__init__() + self.lorder = lorder + self.rorder = rorder + self.lstride = lstride + self.rstride = rstride + self.stack_layer = stack_layer + self.linear = LinearTransform(linear_dim, proj_dim) + self.fsmn_block = FSMNBlock(proj_dim, proj_dim, lorder, rorder, lstride, rstride) + self.affine = AffineTransform(proj_dim, linear_dim) + self.relu = RectifiedLinear(linear_dim, linear_dim) - def __init__(self, input_dim, output_dim): - super(RectifiedLinear, self).__init__() - self.dim = input_dim - self.relu = nn.ReLU() - self.dropout = nn.Dropout(0.1) - - def forward(self, input): - out = self.relu(input) - # out = self.dropout(out) - return out + def forward(self, input: torch.Tensor, in_cache=None): + x1 = self.linear(input) # B T D + if in_cache is not None: # Dict[str, tensor.Tensor] + cache_layer_name = 'cache_layer_{}'.format(self.stack_layer) + if cache_layer_name not in in_cache: + in_cache[cache_layer_name] = torch.zeros(x1.shape[0], x1.shape[-1], (self.lorder - 1) * self.lstride, 1) + x2, in_cache[cache_layer_name] = self.fsmn_block(x1, in_cache[cache_layer_name]) + else: + x2, _ = self.fsmn_block(x1) + x3 = self.affine(x2) + x4 = self.relu(x3) + return x4, in_cache -def _build_repeats( - fsmn_layers: int, - linear_dim: int, - proj_dim: int, - lorder: int, - rorder: int, - lstride=1, - rstride=1, -): - repeats = [ - nn.Sequential( - LinearTransform(linear_dim, proj_dim), - FSMNBlock(proj_dim, proj_dim, lorder, rorder, 1, 1), - AffineTransform(proj_dim, linear_dim), - RectifiedLinear(linear_dim, linear_dim)) - for i in range(fsmn_layers) - ] +class FsmnStack(nn.Sequential): + def __init__(self, *args): + super(FsmnStack, self).__init__(*args) - return nn.Sequential(*repeats) + def forward(self, input: torch.Tensor, in_cache=None): + x = input + for module in self._modules.values(): + x, in_cache = module(x, in_cache) + return x ''' @@ -177,6 +175,7 @@ class FSMN(nn.Module): rstride: int, output_affine_dim: int, output_dim: int, + streaming=False ): super(FSMN, self).__init__() @@ -185,23 +184,16 @@ class FSMN(nn.Module): self.fsmn_layers = fsmn_layers self.linear_dim = linear_dim self.proj_dim = proj_dim - self.lorder = lorder - self.rorder = rorder - self.lstride = lstride - self.rstride = rstride self.output_affine_dim = output_affine_dim self.output_dim = output_dim + self.in_cache_original = dict() if streaming else None + self.in_cache = copy.deepcopy(self.in_cache_original) self.in_linear1 = AffineTransform(input_dim, input_affine_dim) self.in_linear2 = AffineTransform(input_affine_dim, linear_dim) self.relu = RectifiedLinear(linear_dim, linear_dim) - - self.fsmn = _build_repeats(fsmn_layers, - linear_dim, - proj_dim, - lorder, rorder, - lstride, rstride) - + self.fsmn = FsmnStack(*[BasicBlock(linear_dim, proj_dim, lorder, rorder, lstride, rstride, i) for i in + range(fsmn_layers)]) self.out_linear1 = AffineTransform(linear_dim, output_affine_dim) self.out_linear2 = AffineTransform(output_affine_dim, output_dim) self.softmax = nn.Softmax(dim=-1) @@ -209,27 +201,29 @@ class FSMN(nn.Module): def fuse_modules(self): pass + def cache_reset(self): + self.in_cache = copy.deepcopy(self.in_cache_original) + def forward( self, input: torch.Tensor, - in_cache: torch.Tensor = torch.zeros(0, 0, 0, dtype=torch.float) - ) -> torch.Tensor: + ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: """ Args: input (torch.Tensor): Input tensor (B, T, D) - in_cache(torhc.Tensor): (B, D, C), C is the accumulated cache size + in_cache: when in_cache is not None, the forward is in streaming. The type of in_cache is a dict, egs, + {'cache_layer_1': torch.Tensor(B, T1, D)}, T1 is equal to self.lorder. It is {} for the 1st frame """ x1 = self.in_linear1(input) x2 = self.in_linear2(x1) x3 = self.relu(x2) - x4 = self.fsmn(x3) + x4 = self.fsmn(x3, self.in_cache) # if in_cache is not None, self.fsmn is streaming's format, it will update automatically in self.fsmn x5 = self.out_linear1(x4) x6 = self.out_linear2(x5) x7 = self.softmax(x6) return x7 - # return x6, in_cache ''' diff --git a/funasr/tasks/vad.py b/funasr/tasks/vad.py index dfd07c441..e2a912394 100644 --- a/funasr/tasks/vad.py +++ b/funasr/tasks/vad.py @@ -235,7 +235,7 @@ class VADTask(AbsTask): cls, args: argparse.Namespace, train: bool ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]: assert check_argument_types() - #if args.use_preprocessor: + # if args.use_preprocessor: # retval = CommonPreprocessor( # train=train, # # NOTE(kamo): Check attribute existence for backward compatibility @@ -254,7 +254,7 @@ class VADTask(AbsTask): # if hasattr(args, "rir_scp") # else None, # ) - #else: + # else: # retval = None retval = None assert check_return_type(retval) @@ -291,7 +291,8 @@ class VADTask(AbsTask): model_class = model_choices.get_class(args.model) except AttributeError: model_class = model_choices.get_class("e2evad") - model = model_class(encoder=encoder, vad_post_args=args.vad_post_conf) + model = model_class(encoder=encoder, vad_post_args=args.vad_post_conf, + streaming=args.encoder_conf.get('streaming', False)) return model