mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
update ola
This commit is contained in:
parent
a4d87a7fff
commit
8762d99735
@ -1,14 +1,14 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
# Part of the implementation is borrowed from espnet/espnet.
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import funasr.models.frontend.eend_ola_feature
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchaudio.compliance.kaldi as kaldi
|
||||
from funasr.models.frontend.abs_frontend import AbsFrontend
|
||||
from typeguard import check_argument_types
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from typeguard import check_argument_types
|
||||
from typing import Tuple
|
||||
|
||||
|
||||
def load_cmvn(cmvn_file):
|
||||
@ -33,9 +33,9 @@ def load_cmvn(cmvn_file):
|
||||
means = np.array(means_list).astype(np.float)
|
||||
vars = np.array(vars_list).astype(np.float)
|
||||
cmvn = np.array([means, vars])
|
||||
cmvn = torch.as_tensor(cmvn)
|
||||
return cmvn
|
||||
|
||||
cmvn = torch.as_tensor(cmvn)
|
||||
return cmvn
|
||||
|
||||
|
||||
def apply_cmvn(inputs, cmvn_file): # noqa
|
||||
"""
|
||||
@ -78,21 +78,22 @@ def apply_lfr(inputs, lfr_m, lfr_n):
|
||||
class WavFrontend(AbsFrontend):
|
||||
"""Conventional frontend structure for ASR.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cmvn_file: str = None,
|
||||
fs: int = 16000,
|
||||
window: str = 'hamming',
|
||||
n_mels: int = 80,
|
||||
frame_length: int = 25,
|
||||
frame_shift: int = 10,
|
||||
filter_length_min: int = -1,
|
||||
filter_length_max: int = -1,
|
||||
lfr_m: int = 1,
|
||||
lfr_n: int = 1,
|
||||
dither: float = 1.0,
|
||||
snip_edges: bool = True,
|
||||
upsacle_samples: bool = True,
|
||||
self,
|
||||
cmvn_file: str = None,
|
||||
fs: int = 16000,
|
||||
window: str = 'hamming',
|
||||
n_mels: int = 80,
|
||||
frame_length: int = 25,
|
||||
frame_shift: int = 10,
|
||||
filter_length_min: int = -1,
|
||||
filter_length_max: int = -1,
|
||||
lfr_m: int = 1,
|
||||
lfr_n: int = 1,
|
||||
dither: float = 1.0,
|
||||
snip_edges: bool = True,
|
||||
upsacle_samples: bool = True,
|
||||
):
|
||||
assert check_argument_types()
|
||||
super().__init__()
|
||||
@ -135,11 +136,11 @@ class WavFrontend(AbsFrontend):
|
||||
window_type=self.window,
|
||||
sample_frequency=self.fs,
|
||||
snip_edges=self.snip_edges)
|
||||
|
||||
|
||||
if self.lfr_m != 1 or self.lfr_n != 1:
|
||||
mat = apply_lfr(mat, self.lfr_m, self.lfr_n)
|
||||
if self.cmvn_file is not None:
|
||||
mat = apply_cmvn(mat, self.cmvn_file)
|
||||
mat = apply_cmvn(mat, self.cmvn_file)
|
||||
feat_length = mat.size(0)
|
||||
feats.append(mat)
|
||||
feats_lens.append(feat_length)
|
||||
@ -171,7 +172,6 @@ class WavFrontend(AbsFrontend):
|
||||
window_type=self.window,
|
||||
sample_frequency=self.fs)
|
||||
|
||||
|
||||
feat_length = mat.size(0)
|
||||
feats.append(mat)
|
||||
feats_lens.append(feat_length)
|
||||
@ -204,3 +204,68 @@ class WavFrontend(AbsFrontend):
|
||||
batch_first=True,
|
||||
padding_value=0.0)
|
||||
return feats_pad, feats_lens
|
||||
|
||||
|
||||
class WavFrontendMel23(AbsFrontend):
|
||||
"""Conventional frontend structure for ASR.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fs: int = 16000,
|
||||
window: str = 'hamming',
|
||||
n_mels: int = 80,
|
||||
frame_length: int = 25,
|
||||
frame_shift: int = 10,
|
||||
filter_length_min: int = -1,
|
||||
filter_length_max: int = -1,
|
||||
lfr_m: int = 1,
|
||||
lfr_n: int = 1,
|
||||
dither: float = 1.0,
|
||||
snip_edges: bool = True,
|
||||
upsacle_samples: bool = True,
|
||||
):
|
||||
assert check_argument_types()
|
||||
super().__init__()
|
||||
self.fs = fs
|
||||
self.window = window
|
||||
self.n_mels = n_mels
|
||||
self.frame_length = frame_length
|
||||
self.frame_shift = frame_shift
|
||||
self.filter_length_min = filter_length_min
|
||||
self.filter_length_max = filter_length_max
|
||||
self.lfr_m = lfr_m
|
||||
self.lfr_n = lfr_n
|
||||
self.cmvn_file = cmvn_file
|
||||
self.dither = dither
|
||||
self.snip_edges = snip_edges
|
||||
self.upsacle_samples = upsacle_samples
|
||||
|
||||
def output_size(self) -> int:
|
||||
return self.n_mels * self.lfr_m
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input: torch.Tensor,
|
||||
input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
batch_size = input.size(0)
|
||||
feats = []
|
||||
feats_lens = []
|
||||
for i in range(batch_size):
|
||||
waveform_length = input_lengths[i]
|
||||
waveform = input[i][:waveform_length]
|
||||
waveform = waveform.unsqueeze(0).numpy()
|
||||
mat = eend_ola_feature.stft(waveform, self.frame_length, self.frame_shift)
|
||||
mat = eend_ola_feature.transform(mat)
|
||||
mat = mat.splice(mat, context_size=self.lfr_m)
|
||||
mat = mat[::self.lfr_n]
|
||||
mat = torch.from_numpy(mat)
|
||||
feat_length = mat.size(0)
|
||||
feats.append(mat)
|
||||
feats_lens.append(feat_length)
|
||||
|
||||
feats_lens = torch.as_tensor(feats_lens)
|
||||
feats_pad = pad_sequence(feats,
|
||||
batch_first=True,
|
||||
padding_value=0.0)
|
||||
return feats_pad, feats_lens
|
||||
|
||||
Loading…
Reference in New Issue
Block a user