update ola

This commit is contained in:
speech_asr 2023-03-13 15:30:17 +08:00
parent a4d87a7fff
commit 8762d99735

View File

@ -1,14 +1,14 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Part of the implementation is borrowed from espnet/espnet.
from typing import Tuple
import funasr.models.frontend.eend_ola_feature
import numpy as np
import torch
import torchaudio.compliance.kaldi as kaldi
from funasr.models.frontend.abs_frontend import AbsFrontend
from typeguard import check_argument_types
from torch.nn.utils.rnn import pad_sequence
from typeguard import check_argument_types
from typing import Tuple
def load_cmvn(cmvn_file):
@ -33,9 +33,9 @@ def load_cmvn(cmvn_file):
means = np.array(means_list).astype(np.float)
vars = np.array(vars_list).astype(np.float)
cmvn = np.array([means, vars])
cmvn = torch.as_tensor(cmvn)
return cmvn
cmvn = torch.as_tensor(cmvn)
return cmvn
def apply_cmvn(inputs, cmvn_file): # noqa
"""
@ -78,21 +78,22 @@ def apply_lfr(inputs, lfr_m, lfr_n):
class WavFrontend(AbsFrontend):
"""Conventional frontend structure for ASR.
"""
def __init__(
self,
cmvn_file: str = None,
fs: int = 16000,
window: str = 'hamming',
n_mels: int = 80,
frame_length: int = 25,
frame_shift: int = 10,
filter_length_min: int = -1,
filter_length_max: int = -1,
lfr_m: int = 1,
lfr_n: int = 1,
dither: float = 1.0,
snip_edges: bool = True,
upsacle_samples: bool = True,
self,
cmvn_file: str = None,
fs: int = 16000,
window: str = 'hamming',
n_mels: int = 80,
frame_length: int = 25,
frame_shift: int = 10,
filter_length_min: int = -1,
filter_length_max: int = -1,
lfr_m: int = 1,
lfr_n: int = 1,
dither: float = 1.0,
snip_edges: bool = True,
upsacle_samples: bool = True,
):
assert check_argument_types()
super().__init__()
@ -135,11 +136,11 @@ class WavFrontend(AbsFrontend):
window_type=self.window,
sample_frequency=self.fs,
snip_edges=self.snip_edges)
if self.lfr_m != 1 or self.lfr_n != 1:
mat = apply_lfr(mat, self.lfr_m, self.lfr_n)
if self.cmvn_file is not None:
mat = apply_cmvn(mat, self.cmvn_file)
mat = apply_cmvn(mat, self.cmvn_file)
feat_length = mat.size(0)
feats.append(mat)
feats_lens.append(feat_length)
@ -171,7 +172,6 @@ class WavFrontend(AbsFrontend):
window_type=self.window,
sample_frequency=self.fs)
feat_length = mat.size(0)
feats.append(mat)
feats_lens.append(feat_length)
@ -204,3 +204,68 @@ class WavFrontend(AbsFrontend):
batch_first=True,
padding_value=0.0)
return feats_pad, feats_lens
class WavFrontendMel23(AbsFrontend):
"""Conventional frontend structure for ASR.
"""
def __init__(
self,
fs: int = 16000,
window: str = 'hamming',
n_mels: int = 80,
frame_length: int = 25,
frame_shift: int = 10,
filter_length_min: int = -1,
filter_length_max: int = -1,
lfr_m: int = 1,
lfr_n: int = 1,
dither: float = 1.0,
snip_edges: bool = True,
upsacle_samples: bool = True,
):
assert check_argument_types()
super().__init__()
self.fs = fs
self.window = window
self.n_mels = n_mels
self.frame_length = frame_length
self.frame_shift = frame_shift
self.filter_length_min = filter_length_min
self.filter_length_max = filter_length_max
self.lfr_m = lfr_m
self.lfr_n = lfr_n
self.cmvn_file = cmvn_file
self.dither = dither
self.snip_edges = snip_edges
self.upsacle_samples = upsacle_samples
def output_size(self) -> int:
return self.n_mels * self.lfr_m
def forward(
self,
input: torch.Tensor,
input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size = input.size(0)
feats = []
feats_lens = []
for i in range(batch_size):
waveform_length = input_lengths[i]
waveform = input[i][:waveform_length]
waveform = waveform.unsqueeze(0).numpy()
mat = eend_ola_feature.stft(waveform, self.frame_length, self.frame_shift)
mat = eend_ola_feature.transform(mat)
mat = mat.splice(mat, context_size=self.lfr_m)
mat = mat[::self.lfr_n]
mat = torch.from_numpy(mat)
feat_length = mat.size(0)
feats.append(mat)
feats_lens.append(feat_length)
feats_lens = torch.as_tensor(feats_lens)
feats_pad = pad_sequence(feats,
batch_first=True,
padding_value=0.0)
return feats_pad, feats_lens