in_cache & support soundfile read

凌匀 2023-02-27 13:33:55 +08:00
parent dae8f7472d
commit 31eed1834f
4 changed files with 54 additions and 146 deletions


@@ -43,6 +43,7 @@ from funasr.utils.types import str_or_none
 from funasr.utils import asr_utils, wav_utils, postprocess_utils
 from funasr.models.frontend.wav_frontend import WavFrontend
 from funasr.tasks.vad import VADTask
+from funasr.bin.vad_inference import Speech2VadSegment
 from funasr.utils.timestamp_tools import time_stamp_lfr6_pl
 from funasr.bin.punctuation_infer import Text2Punc
 from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer
@@ -364,101 +365,6 @@ class Speech2Text:
         hotword_list = None
         return hotword_list

-class Speech2VadSegment:
-    """Speech2VadSegment class
-
-    Examples:
-        >>> import soundfile
-        >>> speech2segment = Speech2VadSegment("vad_config.yml", "vad.pt")
-        >>> audio, rate = soundfile.read("speech.wav")
-        >>> speech2segment(audio)
-        [[10, 230], [245, 450], ...]
-    """
-
-    def __init__(
-            self,
-            vad_infer_config: Union[Path, str] = None,
-            vad_model_file: Union[Path, str] = None,
-            vad_cmvn_file: Union[Path, str] = None,
-            device: str = "cpu",
-            batch_size: int = 1,
-            dtype: str = "float32",
-            **kwargs,
-    ):
-        assert check_argument_types()
-
-        # 1. Build vad model
-        vad_model, vad_infer_args = VADTask.build_model_from_file(
-            vad_infer_config, vad_model_file, device
-        )
-        frontend = None
-        if vad_infer_args.frontend is not None:
-            frontend = WavFrontend(cmvn_file=vad_cmvn_file, **vad_infer_args.frontend_conf)
-        # logging.info("vad_model: {}".format(vad_model))
-        # logging.info("vad_infer_args: {}".format(vad_infer_args))
-
-        vad_model.to(dtype=getattr(torch, dtype)).eval()
-
-        self.vad_model = vad_model
-        self.vad_infer_args = vad_infer_args
-        self.device = device
-        self.dtype = dtype
-        self.frontend = frontend
-        self.batch_size = batch_size
-
-    @torch.no_grad()
-    def __call__(
-            self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
-    ) -> List[List[int]]:
-        """Inference
-
-        Args:
-            speech: Input speech data
-        Returns:
-            text, token, token_int, hyp
-        """
-        assert check_argument_types()
-
-        # Input as audio signal
-        if isinstance(speech, np.ndarray):
-            speech = torch.tensor(speech)
-
-        if self.frontend is not None:
-            self.frontend.filter_length_max = math.inf
-            fbanks, fbanks_len = self.frontend.forward_fbank(speech, speech_lengths)
-            feats, feats_len = self.frontend.forward_lfr_cmvn(fbanks, fbanks_len)
-            fbanks = to_device(fbanks, device=self.device)
-            feats = to_device(feats, device=self.device)
-            feats_len = feats_len.int()
-        else:
-            raise Exception("Need to extract feats first, please configure frontend configuration")
-
-        # b. Forward Encoder streaming
-        t_offset = 0
-        step = min(feats_len, 6000)
-        segments = [[]] * self.batch_size
-        for t_offset in range(0, feats_len, min(step, feats_len - t_offset)):
-            if t_offset + step >= feats_len - 1:
-                step = feats_len - t_offset
-                is_final_send = True
-            else:
-                is_final_send = False
-            batch = {
-                "feats": feats[:, t_offset:t_offset + step, :],
-                "waveform": speech[:, t_offset * 160:min(speech.shape[-1], (t_offset + step - 1) * 160 + 400)],
-                "is_final_send": is_final_send
-            }
-            # a. To device
-            batch = to_device(batch, device=self.device)
-            segments_part = self.vad_model(**batch)
-            if segments_part:
-                for batch_num in range(0, self.batch_size):
-                    segments[batch_num] += segments_part[batch_num]
-        return fbanks, segments

 def inference(
         maxlenratio: float,
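With this hunk the ASR entry point stops carrying its own copy of Speech2VadSegment and imports the shared implementation from funasr.bin.vad_inference instead. A minimal usage sketch, adapted from the class docstring (the config, checkpoint and CMVN paths are placeholders, and the tuple return follows the updated __call__ in the next file; exact input shapes depend on the frontend configuration):

    import soundfile
    from funasr.bin.vad_inference import Speech2VadSegment

    # placeholder files; point these at your own VAD config, checkpoint and CMVN stats
    vad = Speech2VadSegment(vad_infer_config="vad_config.yml",
                            vad_model_file="vad.pt",
                            vad_cmvn_file="vad.mvn")
    audio, rate = soundfile.read("speech.wav")   # waveform as a numpy array
    fbanks, segments = vad(audio)                # now returns fbank features alongside the segments
    print(segments)                              # e.g. [[10, 230], [245, 450], ...] per the class docstring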


@@ -11,6 +11,7 @@ from typing import Tuple
 from typing import Union
 from typing import Dict
+import math
 import numpy as np
 import torch
 from typeguard import check_argument_types
@@ -86,7 +87,7 @@ class Speech2VadSegment:
     @torch.no_grad()
     def __call__(
             self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
-    ) -> List[List[int]]:
+    ) -> Tuple[List[List[int]], Dict[str, torch.Tensor]]:
         """Inference

         Args:
@@ -102,7 +103,10 @@ class Speech2VadSegment:
             speech = torch.tensor(speech)

         if self.frontend is not None:
-            feats, feats_len = self.frontend.forward(speech, speech_lengths)
+            self.frontend.filter_length_max = math.inf
+            fbanks, fbanks_len = self.frontend.forward_fbank(speech, speech_lengths)
+            feats, feats_len = self.frontend.forward_lfr_cmvn(fbanks, fbanks_len)
+            fbanks = to_device(fbanks, device=self.device)
             feats = to_device(feats, device=self.device)
             feats_len = feats_len.int()
         else:
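Feature extraction is now split into two explicit stages so the raw fbanks survive the LFR/CMVN step. A short sketch of the new flow, assuming `frontend` is the WavFrontend built in __init__ and `speech`/`speech_lengths` are the batched waveform and its lengths:

    import math

    frontend.filter_length_max = math.inf                                  # do not drop long utterances
    fbanks, fbanks_len = frontend.forward_fbank(speech, speech_lengths)    # raw filterbank features
    feats, feats_len = frontend.forward_lfr_cmvn(fbanks, fbanks_len)       # LFR + CMVN for the VAD model

The fbanks are kept and handed back to the caller (see the `return fbanks, segments` change below), presumably so a downstream ASR pass can reuse them without recomputing the frontend.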
@@ -110,18 +114,18 @@ class Speech2VadSegment:

         # b. Forward Encoder streaming
         t_offset = 0
-        step = min(feats_len, 6000)
+        step = min(feats_len.max(), 6000)
         segments = [[]] * self.batch_size
         for t_offset in range(0, feats_len, min(step, feats_len - t_offset)):
             if t_offset + step >= feats_len - 1:
                 step = feats_len - t_offset
-                is_final_send = True
+                is_final = True
             else:
-                is_final_send = False
+                is_final = False
             batch = {
                 "feats": feats[:, t_offset:t_offset + step, :],
                 "waveform": speech[:, t_offset * 160:min(speech.shape[-1], (t_offset + step - 1) * 160 + 400)],
-                "is_final_send": is_final_send
+                "is_final": is_final
             }
             # a. To device
             batch = to_device(batch, device=self.device)
@@ -129,7 +133,7 @@ class Speech2VadSegment:
             if segments_part:
                 for batch_num in range(0, self.batch_size):
                     segments[batch_num] += segments_part[batch_num]
-        return segments
+        return fbanks, segments


 def inference(
@@ -219,9 +223,13 @@ def inference_modelscope(
             raw_inputs: Union[np.ndarray, torch.Tensor] = None,
             output_dir_v2: Optional[str] = None,
             fs: dict = None,
-            param_dict: dict = None,
+            param_dict: dict = None
     ):
         # 3. Build data-iterator
+        if data_path_and_name_and_type is None and raw_inputs is not None:
+            if isinstance(raw_inputs, torch.Tensor):
+                raw_inputs = raw_inputs.numpy()
+            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
         loader = VADTask.build_streaming_iterator(
             data_path_and_name_and_type,
             dtype=dtype,

@@ -254,7 +262,7 @@ def inference_modelscope(
         assert len(keys) == _bs, f"{len(keys)} != {_bs}"

         # do vad segment
-        results = speech2vadsegment(**batch)
+        _, results = speech2vadsegment(**batch)
         for i, _ in enumerate(keys):
             results[i] = json.dumps(results[i])
             item = {'key': keys[i], 'value': results[i]}
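This is the "support soundfile read" half of the commit: the VAD pipeline can now be fed a waveform already loaded in memory (e.g. via soundfile) instead of a wav.scp entry. A minimal sketch, assuming `vad_pipeline` is the callable produced by inference_modelscope(...) (the name is illustrative):

    import soundfile

    audio, rate = soundfile.read("speech.wav")         # numpy array straight from disk
    results = vad_pipeline(data_path_and_name_and_type=None,
                           raw_inputs=audio)           # torch.Tensor inputs are converted to numpy first

Internally the array is wrapped as [raw_inputs, "speech", "waveform"] and handed to VADTask.build_streaming_iterator, and each per-key result is the segments part of the new (fbanks, segments) tuple, serialized with json.dumps as before.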


@@ -201,7 +201,7 @@ class E2EVadModel(nn.Module):
                                               self.vad_opts.frame_in_ms)
         self.encoder = encoder
         # init variables
-        self.is_final_send = False
+        self.is_final = False
         self.data_buf_start_frame = 0
         self.frm_cnt = 0
         self.latest_confirmed_speech_frame = 0
@@ -230,8 +230,7 @@ class E2EVadModel(nn.Module):
         self.ResetDetection()

     def AllResetDetection(self):
-        self.encoder.cache_reset()  # reset the in_cache in self.encoder for next query or next long sentence
-        self.is_final_send = False
+        self.is_final = False
         self.data_buf_start_frame = 0
         self.frm_cnt = 0
         self.latest_confirmed_speech_frame = 0
@@ -283,8 +282,8 @@ class E2EVadModel(nn.Module):
                                        10 * math.log10((self.waveform[0][offset: offset + frame_sample_length]).square().sum() + \
                                                        0.000001))

-    def ComputeScores(self, feats: torch.Tensor) -> None:
-        scores = self.encoder(feats)  # return B * T * D
+    def ComputeScores(self, feats: torch.Tensor, in_cache: Dict[str, torch.Tensor]) -> None:
+        scores = self.encoder(feats, in_cache)  # return B * T * D
         assert scores.shape[1] == feats.shape[1], "The shape between feats and scores does not match"
         self.vad_opts.nn_eval_block_size = scores.shape[1]
         self.frm_cnt += scores.shape[1]  # count total frames
@@ -306,7 +305,7 @@ class E2EVadModel(nn.Module):
         expected_sample_number = int(frm_cnt * self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000)
         if last_frm_is_end_point:
             extra_sample = max(0, int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000 - \
                                       self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000))
             expected_sample_number += int(extra_sample)
         if end_point_is_sent_end:
             expected_sample_number = max(expected_sample_number, len(self.data_buf))
@@ -443,11 +442,13 @@ class E2EVadModel(nn.Module):
         return frame_state

-    def forward(self, feats: torch.Tensor, waveform: torch.tensor, is_final_send: bool = False) -> List[List[List[int]]]:
+    def forward(self, feats: torch.Tensor, waveform: torch.tensor, in_cache: Dict[str, torch.Tensor] = dict(),
+                is_final: bool = False
+                ) -> List[List[List[int]]]:
         self.waveform = waveform  # compute decibel for each frame
         self.ComputeDecibel()
-        self.ComputeScores(feats)
-        if not is_final_send:
+        self.ComputeScores(feats, in_cache)
+        if not is_final:
             self.DetectCommonFrames()
         else:
             self.DetectLastFrames()
@@ -456,15 +457,18 @@ class E2EVadModel(nn.Module):
         segment_batch = []
         if len(self.output_data_buf) > 0:
             for i in range(self.output_data_buf_offset, len(self.output_data_buf)):
-                if self.output_data_buf[i].contain_seg_start_point and self.output_data_buf[
-                    i].contain_seg_end_point:
-                    segment = [self.output_data_buf[i].start_ms, self.output_data_buf[i].end_ms]
-                    segment_batch.append(segment)
-                    self.output_data_buf_offset += 1  # need update this parameter
+                if not self.output_data_buf[i].contain_seg_start_point or not self.output_data_buf[
+                    i].contain_seg_end_point:
+                    continue
+                segment = [self.output_data_buf[i].start_ms, self.output_data_buf[i].end_ms]
+                segment_batch.append(segment)
+                self.output_data_buf_offset += 1  # need update this parameter
         if segment_batch:
             segments.append(segment_batch)
-        if is_final_send:
-            self.AllResetDetection()
+        if is_final:
+            # reset class variables and clear the dict for the next query
+            self.AllResetDetection()
+            in_cache.clear()

         return segments

     def DetectCommonFrames(self) -> int:
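E2EVadModel no longer owns the encoder cache: forward threads an explicit in_cache dict through ComputeScores into the FSMN encoder, and on the final chunk it resets the detector and clears that dict. A sketch of the resulting streaming call pattern (vad_model, chunks and the chunking itself are illustrative, mirroring Speech2VadSegment.__call__ above):

    in_cache = {}                                   # one shared FSMN cache per utterance
    segments = [[]]
    chunks = [(feats1, wave1), (feats2, wave2)]     # feature/waveform chunk pairs produced by the caller
    for idx, (feats, waveform) in enumerate(chunks):
        part = vad_model(feats=feats, waveform=waveform,
                         in_cache=in_cache,
                         is_final=(idx == len(chunks) - 1))
        if part:
            segments[0] += part[0]
    # after the final chunk, AllResetDetection() and in_cache.clear() leave the dict ready for reuse

One caveat worth noting: the `in_cache: Dict[str, torch.Tensor] = dict()` default is evaluated once at definition time, so callers that omit the argument share a single dict across calls; the code relies on in_cache.clear() at the end of each utterance to keep state from leaking into the next one, so callers doing their own chunking should either pass a fresh dict or make sure the last chunk is flagged with is_final=True.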


@@ -79,14 +79,12 @@ class FSMNBlock(nn.Module):
         else:
             self.conv_right = None

-    def forward(self, input: torch.Tensor, in_cache=None):
+    def forward(self, input: torch.Tensor, cache: torch.Tensor):
         x = torch.unsqueeze(input, 1)
         x_per = x.permute(0, 3, 2, 1)  # B D T C
-        if in_cache is None:  # offline
-            y_left = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride, 0])
-        else:
-            y_left = torch.cat((in_cache, x_per), dim=2)
-            in_cache = y_left[:, :, -(self.lorder - 1) * self.lstride:, :]
+        y_left = torch.cat((cache, x_per), dim=2)
+        cache = y_left[:, :, -(self.lorder - 1) * self.lstride:, :]
         y_left = self.conv_left(y_left)
         out = x_per + y_left
@@ -100,7 +98,7 @@ class FSMNBlock(nn.Module):
         out_per = out.permute(0, 3, 2, 1)
         output = out_per.squeeze(1)

-        return output, in_cache
+        return output, cache


 class BasicBlock(nn.Sequential):
@@ -124,28 +122,25 @@ class BasicBlock(nn.Sequential):
         self.affine = AffineTransform(proj_dim, linear_dim)
         self.relu = RectifiedLinear(linear_dim, linear_dim)

-    def forward(self, input: torch.Tensor, in_cache=None):
+    def forward(self, input: torch.Tensor, in_cache: Dict[str, torch.Tensor]):
         x1 = self.linear(input)  # B T D
-        if in_cache is not None:  # Dict[str, tensor.Tensor]
-            cache_layer_name = 'cache_layer_{}'.format(self.stack_layer)
-            if cache_layer_name not in in_cache:
-                in_cache[cache_layer_name] = torch.zeros(x1.shape[0], x1.shape[-1], (self.lorder - 1) * self.lstride, 1)
-            x2, in_cache[cache_layer_name] = self.fsmn_block(x1, in_cache[cache_layer_name])
-        else:
-            x2, _ = self.fsmn_block(x1)
+        cache_layer_name = 'cache_layer_{}'.format(self.stack_layer)
+        if cache_layer_name not in in_cache:
+            in_cache[cache_layer_name] = torch.zeros(x1.shape[0], x1.shape[-1], (self.lorder - 1) * self.lstride, 1)
+        x2, in_cache[cache_layer_name] = self.fsmn_block(x1, in_cache[cache_layer_name])
         x3 = self.affine(x2)
         x4 = self.relu(x3)
-        return x4, in_cache
+        return x4


 class FsmnStack(nn.Sequential):
     def __init__(self, *args):
         super(FsmnStack, self).__init__(*args)

-    def forward(self, input: torch.Tensor, in_cache=None):
+    def forward(self, input: torch.Tensor, in_cache: Dict[str, torch.Tensor]):
         x = input
         for module in self._modules.values():
-            x, in_cache = module(x, in_cache)
+            x = module(x, in_cache)
         return x
@@ -174,8 +169,7 @@ class FSMN(nn.Module):
             lstride: int,
             rstride: int,
             output_affine_dim: int,
-            output_dim: int,
-            streaming=False
+            output_dim: int
     ):
         super(FSMN, self).__init__()
@@ -186,8 +180,6 @@ class FSMN(nn.Module):
         self.proj_dim = proj_dim
         self.output_affine_dim = output_affine_dim
         self.output_dim = output_dim
-        self.in_cache_original = dict() if streaming else None
-        self.in_cache = copy.deepcopy(self.in_cache_original)

         self.in_linear1 = AffineTransform(input_dim, input_affine_dim)
         self.in_linear2 = AffineTransform(input_affine_dim, linear_dim)
@@ -201,12 +193,10 @@ class FSMN(nn.Module):
     def fuse_modules(self):
         pass

-    def cache_reset(self):
-        self.in_cache = copy.deepcopy(self.in_cache_original)
-
     def forward(
             self,
             input: torch.Tensor,
+            in_cache: Dict[str, torch.Tensor]
     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
         """
         Args:
@@ -218,7 +208,7 @@ class FSMN(nn.Module):
         x1 = self.in_linear1(input)
         x2 = self.in_linear2(x1)
         x3 = self.relu(x2)
-        x4 = self.fsmn(x3, self.in_cache)  # if in_cache is not None, self.fsmn is streaming's format, it will update automatically in self.fsmn
+        x4 = self.fsmn(x3, in_cache)  # self.in_cache will update automatically in self.fsmn
         x5 = self.out_linear1(x4)
         x6 = self.out_linear2(x5)
         x7 = self.softmax(x6)
@@ -307,4 +297,4 @@ if __name__ == '__main__':
     print('input shape: {}'.format(x.shape))
     print('output shape: {}'.format(y.shape))

     print(fsmn.to_kaldi_net())
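The FSMN stack now receives the cache from its caller instead of keeping self.in_cache: BasicBlock lazily creates one zero tensor per layer under the key 'cache_layer_<stack_layer>', and FSMNBlock concatenates it in front of the current chunk and keeps the last (lorder - 1) * lstride frames for the next call. A short sketch of the resulting cache layout (fsmn, feats and the shapes are illustrative; the call form follows E2EVadModel.ComputeScores above):

    in_cache = {}                                  # starts empty; each BasicBlock fills its own slot
    scores = fsmn(feats, in_cache)                 # B * T * D, as used by ComputeScores
    # after the first call, in_cache holds one entry per layer, e.g.
    # in_cache['cache_layer_0'] with shape (B, D, (lorder - 1) * lstride, 1)

Because the offline F.pad branch was removed, an empty dict per utterance reproduces the old non-streaming behaviour: the first chunk is padded by the zero-initialized cache, and subsequent chunks see real left context.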