in_cache & support soundfile read

凌匀 2023-02-27 13:33:55 +08:00
parent dae8f7472d
commit 31eed1834f
4 changed files with 54 additions and 146 deletions
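For context, a minimal usage sketch of what this commit enables, adapted from the Speech2VadSegment docstring further down; the config, model and wav paths are placeholders, and the two-value unpacking reflects the updated __call__ return:

import soundfile
from funasr.bin.vad_inference import Speech2VadSegment

# build the VAD pipeline (placeholder config/checkpoint paths)
speech2segment = Speech2VadSegment("vad_config.yml", "vad.pt")
# read raw audio with soundfile and pass the numpy array directly
audio, rate = soundfile.read("speech.wav")
# __call__ now returns the fbank features together with the segments
fbanks, segments = speech2segment(audio)
print(segments)  # e.g. [[10, 230], [245, 450], ...]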

View File

@ -43,6 +43,7 @@ from funasr.utils.types import str_or_none
from funasr.utils import asr_utils, wav_utils, postprocess_utils
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.tasks.vad import VADTask
from funasr.bin.vad_inference import Speech2VadSegment
from funasr.utils.timestamp_tools import time_stamp_lfr6_pl
from funasr.bin.punctuation_infer import Text2Punc
from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer
@ -364,101 +365,6 @@ class Speech2Text:
hotword_list = None
return hotword_list
class Speech2VadSegment:
"""Speech2VadSegment class
Examples:
>>> import soundfile
>>> speech2segment = Speech2VadSegment("vad_config.yml", "vad.pt")
>>> audio, rate = soundfile.read("speech.wav")
>>> speech2segment(audio)
[[10, 230], [245, 450], ...]
"""
def __init__(
self,
vad_infer_config: Union[Path, str] = None,
vad_model_file: Union[Path, str] = None,
vad_cmvn_file: Union[Path, str] = None,
device: str = "cpu",
batch_size: int = 1,
dtype: str = "float32",
**kwargs,
):
assert check_argument_types()
# 1. Build vad model
vad_model, vad_infer_args = VADTask.build_model_from_file(
vad_infer_config, vad_model_file, device
)
frontend = None
if vad_infer_args.frontend is not None:
frontend = WavFrontend(cmvn_file=vad_cmvn_file, **vad_infer_args.frontend_conf)
# logging.info("vad_model: {}".format(vad_model))
# logging.info("vad_infer_args: {}".format(vad_infer_args))
vad_model.to(dtype=getattr(torch, dtype)).eval()
self.vad_model = vad_model
self.vad_infer_args = vad_infer_args
self.device = device
self.dtype = dtype
self.frontend = frontend
self.batch_size = batch_size
@torch.no_grad()
def __call__(
self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
) -> List[List[int]]:
"""Inference
Args:
speech: Input speech data
Returns:
fbanks, segments
"""
assert check_argument_types()
# Input as audio signal
if isinstance(speech, np.ndarray):
speech = torch.tensor(speech)
if self.frontend is not None:
self.frontend.filter_length_max = math.inf
fbanks, fbanks_len = self.frontend.forward_fbank(speech, speech_lengths)
feats, feats_len = self.frontend.forward_lfr_cmvn(fbanks, fbanks_len)
fbanks = to_device(fbanks, device=self.device)
feats = to_device(feats, device=self.device)
feats_len = feats_len.int()
else:
raise Exception("Need to extract feats first, please configure frontend configuration")
# b. Forward Encoder streaming
t_offset = 0
step = min(feats_len, 6000)
segments = [[]] * self.batch_size
for t_offset in range(0, feats_len, min(step, feats_len - t_offset)):
if t_offset + step >= feats_len - 1:
step = feats_len - t_offset
is_final_send = True
else:
is_final_send = False
batch = {
"feats": feats[:, t_offset:t_offset + step, :],
"waveform": speech[:, t_offset * 160:min(speech.shape[-1], (t_offset + step - 1) * 160 + 400)],
"is_final_send": is_final_send
}
# a. To device
batch = to_device(batch, device=self.device)
segments_part = self.vad_model(**batch)
if segments_part:
for batch_num in range(0, self.batch_size):
segments[batch_num] += segments_part[batch_num]
return fbanks, segments
def inference(
maxlenratio: float,

View File

@ -11,6 +11,7 @@ from typing import Tuple
from typing import Union
from typing import Dict
import math
import numpy as np
import torch
from typeguard import check_argument_types
@ -86,7 +87,7 @@ class Speech2VadSegment:
@torch.no_grad()
def __call__(
self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None
) -> List[List[int]]:
) -> Tuple[List[List[int]], Dict[str, torch.Tensor]]:
"""Inference
Args:
@ -102,7 +103,10 @@ class Speech2VadSegment:
speech = torch.tensor(speech)
if self.frontend is not None:
feats, feats_len = self.frontend.forward(speech, speech_lengths)
self.frontend.filter_length_max = math.inf
fbanks, fbanks_len = self.frontend.forward_fbank(speech, speech_lengths)
feats, feats_len = self.frontend.forward_lfr_cmvn(fbanks, fbanks_len)
fbanks = to_device(fbanks, device=self.device)
feats = to_device(feats, device=self.device)
feats_len = feats_len.int()
else:
@ -110,18 +114,18 @@ class Speech2VadSegment:
# b. Forward Encoder streaming
t_offset = 0
step = min(feats_len, 6000)
step = min(feats_len.max(), 6000)
segments = [[]] * self.batch_size
for t_offset in range(0, feats_len, min(step, feats_len - t_offset)):
if t_offset + step >= feats_len - 1:
step = feats_len - t_offset
is_final_send = True
is_final = True
else:
is_final_send = False
is_final = False
batch = {
"feats": feats[:, t_offset:t_offset + step, :],
"waveform": speech[:, t_offset * 160:min(speech.shape[-1], (t_offset + step - 1) * 160 + 400)],
"is_final_send": is_final_send
"is_final": is_final
}
# a. To device
batch = to_device(batch, device=self.device)
@ -129,7 +133,7 @@ class Speech2VadSegment:
if segments_part:
for batch_num in range(0, self.batch_size):
segments[batch_num] += segments_part[batch_num]
return segments
return fbanks, segments
def inference(
@ -219,9 +223,13 @@ def inference_modelscope(
raw_inputs: Union[np.ndarray, torch.Tensor] = None,
output_dir_v2: Optional[str] = None,
fs: dict = None,
param_dict: dict = None,
param_dict: dict = None
):
# 3. Build data-iterator
if data_path_and_name_and_type is None and raw_inputs is not None:
if isinstance(raw_inputs, torch.Tensor):
raw_inputs = raw_inputs.numpy()
data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
loader = VADTask.build_streaming_iterator(
data_path_and_name_and_type,
dtype=dtype,
@ -254,7 +262,7 @@ def inference_modelscope(
assert len(keys) == _bs, f"{len(keys)} != {_bs}"
# do vad segment
results = speech2vadsegment(**batch)
_, results = speech2vadsegment(**batch)
for i, _ in enumerate(keys):
results[i] = json.dumps(results[i])
item = {'key': keys[i], 'value': results[i]}
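A rough sketch of the new raw_inputs path above, assuming run_vad is the callable returned by inference_modelscope (its builder arguments are not part of this diff and are omitted):

import soundfile
import torch

audio, rate = soundfile.read("speech.wav")
# a numpy array is wrapped internally as [raw_inputs, "speech", "waveform"];
# a torch.Tensor is first converted back to numpy, as shown above
results = run_vad(raw_inputs=audio)
results_from_tensor = run_vad(raw_inputs=torch.from_numpy(audio))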

View File

@ -201,7 +201,7 @@ class E2EVadModel(nn.Module):
self.vad_opts.frame_in_ms)
self.encoder = encoder
# init variables
self.is_final_send = False
self.is_final = False
self.data_buf_start_frame = 0
self.frm_cnt = 0
self.latest_confirmed_speech_frame = 0
@ -230,8 +230,7 @@ class E2EVadModel(nn.Module):
self.ResetDetection()
def AllResetDetection(self):
self.encoder.cache_reset() # reset the in_cache in self.encoder for next query or next long sentence
self.is_final_send = False
self.is_final = False
self.data_buf_start_frame = 0
self.frm_cnt = 0
self.latest_confirmed_speech_frame = 0
@ -283,8 +282,8 @@ class E2EVadModel(nn.Module):
10 * math.log10((self.waveform[0][offset: offset + frame_sample_length]).square().sum() + \
0.000001))
def ComputeScores(self, feats: torch.Tensor) -> None:
scores = self.encoder(feats) # return B * T * D
def ComputeScores(self, feats: torch.Tensor, in_cache: Dict[str, torch.Tensor]) -> None:
scores = self.encoder(feats, in_cache) # return B * T * D
assert scores.shape[1] == feats.shape[1], "The shape between feats and scores does not match"
self.vad_opts.nn_eval_block_size = scores.shape[1]
self.frm_cnt += scores.shape[1] # count total frames
@ -306,7 +305,7 @@ class E2EVadModel(nn.Module):
expected_sample_number = int(frm_cnt * self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000)
if last_frm_is_end_point:
extra_sample = max(0, int(self.vad_opts.frame_length_ms * self.vad_opts.sample_rate / 1000 - \
self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000))
self.vad_opts.sample_rate * self.vad_opts.frame_in_ms / 1000))
expected_sample_number += int(extra_sample)
if end_point_is_sent_end:
expected_sample_number = max(expected_sample_number, len(self.data_buf))
@ -443,11 +442,13 @@ class E2EVadModel(nn.Module):
return frame_state
def forward(self, feats: torch.Tensor, waveform: torch.tensor, is_final_send: bool = False) -> List[List[List[int]]]:
def forward(self, feats: torch.Tensor, waveform: torch.tensor, in_cache: Dict[str, torch.Tensor] = dict(),
is_final: bool = False
) -> List[List[List[int]]]:
self.waveform = waveform # compute decibel for each frame
self.ComputeDecibel()
self.ComputeScores(feats)
if not is_final_send:
self.ComputeScores(feats, in_cache)
if not is_final:
self.DetectCommonFrames()
else:
self.DetectLastFrames()
@ -456,15 +457,18 @@ class E2EVadModel(nn.Module):
segment_batch = []
if len(self.output_data_buf) > 0:
for i in range(self.output_data_buf_offset, len(self.output_data_buf)):
if self.output_data_buf[i].contain_seg_start_point and self.output_data_buf[
if not self.output_data_buf[i].contain_seg_start_point or not self.output_data_buf[
i].contain_seg_end_point:
segment = [self.output_data_buf[i].start_ms, self.output_data_buf[i].end_ms]
segment_batch.append(segment)
self.output_data_buf_offset += 1 # this offset needs to be updated here
continue
segment = [self.output_data_buf[i].start_ms, self.output_data_buf[i].end_ms]
segment_batch.append(segment)
self.output_data_buf_offset += 1 # this offset needs to be updated here
if segment_batch:
segments.append(segment_batch)
if is_final_send:
self.AllResetDetection()
if is_final:
# reset class variables and clear the dict for the next query
self.AllResetDetection()
in_cache.clear()
return segments
def DetectCommonFrames(self) -> int:
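The reworked forward above threads an explicit in_cache dict through the encoder instead of the old self.is_final_send / cache_reset state. A minimal calling sketch, assuming vad_model is an E2EVadModel and the per-chunk feats/waveform tensors are prepared as in Speech2VadSegment.__call__ (chunks stands in for that loop):

in_cache = {}  # filled lazily by the FSMN encoder on the first chunk
all_segments = []
for i, (chunk_feats, chunk_wav) in enumerate(chunks):
    is_final = (i == len(chunks) - 1)
    segments = vad_model(feats=chunk_feats, waveform=chunk_wav, in_cache=in_cache, is_final=is_final)
    if segments:
        all_segments += segments[0]  # batch size 1 assumed here
# on the final chunk the model calls AllResetDetection() and clears in_cache,
# so the same dict can be reused for the next query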

View File

@ -79,14 +79,12 @@ class FSMNBlock(nn.Module):
else:
self.conv_right = None
def forward(self, input: torch.Tensor, in_cache=None):
def forward(self, input: torch.Tensor, cache: torch.Tensor):
x = torch.unsqueeze(input, 1)
x_per = x.permute(0, 3, 2, 1) # B D T C
if in_cache is None: # offline
y_left = F.pad(x_per, [0, 0, (self.lorder - 1) * self.lstride, 0])
else:
y_left = torch.cat((in_cache, x_per), dim=2)
in_cache = y_left[:, :, -(self.lorder - 1) * self.lstride:, :]
y_left = torch.cat((cache, x_per), dim=2)
cache = y_left[:, :, -(self.lorder - 1) * self.lstride:, :]
y_left = self.conv_left(y_left)
out = x_per + y_left
@ -100,7 +98,7 @@ class FSMNBlock(nn.Module):
out_per = out.permute(0, 3, 2, 1)
output = out_per.squeeze(1)
return output, in_cache
return output, cache
class BasicBlock(nn.Sequential):
@ -124,28 +122,25 @@ class BasicBlock(nn.Sequential):
self.affine = AffineTransform(proj_dim, linear_dim)
self.relu = RectifiedLinear(linear_dim, linear_dim)
def forward(self, input: torch.Tensor, in_cache=None):
def forward(self, input: torch.Tensor, in_cache: Dict[str, torch.Tensor]):
x1 = self.linear(input) # B T D
if in_cache is not None: # Dict[str, tensor.Tensor]
cache_layer_name = 'cache_layer_{}'.format(self.stack_layer)
if cache_layer_name not in in_cache:
in_cache[cache_layer_name] = torch.zeros(x1.shape[0], x1.shape[-1], (self.lorder - 1) * self.lstride, 1)
x2, in_cache[cache_layer_name] = self.fsmn_block(x1, in_cache[cache_layer_name])
else:
x2, _ = self.fsmn_block(x1)
cache_layer_name = 'cache_layer_{}'.format(self.stack_layer)
if cache_layer_name not in in_cache:
in_cache[cache_layer_name] = torch.zeros(x1.shape[0], x1.shape[-1], (self.lorder - 1) * self.lstride, 1)
x2, in_cache[cache_layer_name] = self.fsmn_block(x1, in_cache[cache_layer_name])
x3 = self.affine(x2)
x4 = self.relu(x3)
return x4, in_cache
return x4
class FsmnStack(nn.Sequential):
def __init__(self, *args):
super(FsmnStack, self).__init__(*args)
def forward(self, input: torch.Tensor, in_cache=None):
def forward(self, input: torch.Tensor, in_cache: Dict[str, torch.Tensor]):
x = input
for module in self._modules.values():
x, in_cache = module(x, in_cache)
x = module(x, in_cache)
return x
@ -174,8 +169,7 @@ class FSMN(nn.Module):
lstride: int,
rstride: int,
output_affine_dim: int,
output_dim: int,
streaming=False
output_dim: int
):
super(FSMN, self).__init__()
@ -186,8 +180,6 @@ class FSMN(nn.Module):
self.proj_dim = proj_dim
self.output_affine_dim = output_affine_dim
self.output_dim = output_dim
self.in_cache_original = dict() if streaming else None
self.in_cache = copy.deepcopy(self.in_cache_original)
self.in_linear1 = AffineTransform(input_dim, input_affine_dim)
self.in_linear2 = AffineTransform(input_affine_dim, linear_dim)
@ -201,12 +193,10 @@ class FSMN(nn.Module):
def fuse_modules(self):
pass
def cache_reset(self):
self.in_cache = copy.deepcopy(self.in_cache_original)
def forward(
self,
input: torch.Tensor,
in_cache: Dict[str, torch.Tensor]
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
"""
Args:
@ -218,7 +208,7 @@ class FSMN(nn.Module):
x1 = self.in_linear1(input)
x2 = self.in_linear2(x1)
x3 = self.relu(x2)
x4 = self.fsmn(x3, self.in_cache) # if in_cache is not None, self.fsmn is streaming's format, it will update automatically in self.fsmn
x4 = self.fsmn(x3, in_cache) # in_cache is updated in place inside self.fsmn
x5 = self.out_linear1(x4)
x6 = self.out_linear2(x5)
x7 = self.softmax(x6)
@ -307,4 +297,4 @@ if __name__ == '__main__':
print('input shape: {}'.format(x.shape))
print('output shape: {}'.format(y.shape))
print(fsmn.to_kaldi_net())
print(fsmn.to_kaldi_net())
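For reference, a small sketch of how the per-layer cache behaves after this change, assuming fsmn is an already-constructed FSMN instance (the toy tensor shape is a placeholder):

import torch

in_cache = {}  # one 'cache_layer_{i}' entry per BasicBlock, created on first use
chunk1 = torch.randn(1, 100, 400)  # (batch, frames, feature_dim), toy values
chunk2 = torch.randn(1, 100, 400)
scores1 = fsmn(chunk1, in_cache)  # first chunk: each block zero-initialises its cache
                                  # of shape (x1.shape[0], x1.shape[-1], (lorder - 1) * lstride, 1)
scores2 = fsmn(chunk2, in_cache)  # later chunks: left context is read from and written back to the same entries
in_cache.clear()  # E2EVadModel clears the dict after the final chunk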