This commit is contained in:
speech_asr 2023-03-14 15:54:28 +08:00
parent ad2ef72341
commit 141a4737f7
2 changed files with 3 additions and 15 deletions

View File

@ -6,6 +6,7 @@ import numpy as np
import torch
import torchaudio.compliance.kaldi as kaldi
from funasr.models.frontend.abs_frontend import AbsFrontend
import funasr.models.frontend.eend_ola_feature as eend_ola_feature
from torch.nn.utils.rnn import pad_sequence
from typeguard import check_argument_types
from typing import Tuple
@ -213,33 +214,18 @@ class WavFrontendMel23(AbsFrontend):
def __init__(
self,
fs: int = 16000,
window: str = 'hamming',
n_mels: int = 80,
frame_length: int = 25,
frame_shift: int = 10,
filter_length_min: int = -1,
filter_length_max: int = -1,
lfr_m: int = 1,
lfr_n: int = 1,
dither: float = 1.0,
snip_edges: bool = True,
upsacle_samples: bool = True,
):
assert check_argument_types()
super().__init__()
self.fs = fs
self.window = window
self.n_mels = n_mels
self.frame_length = frame_length
self.frame_shift = frame_shift
self.filter_length_min = filter_length_min
self.filter_length_max = filter_length_max
self.lfr_m = lfr_m
self.lfr_n = lfr_n
self.cmvn_file = cmvn_file
self.dither = dither
self.snip_edges = snip_edges
self.upsacle_samples = upsacle_samples
def output_size(self) -> int:
return self.n_mels * self.lfr_m

View File

@ -23,6 +23,7 @@ from funasr.layers.global_mvn import GlobalMVN
from funasr.layers.label_aggregation import LabelAggregate
from funasr.layers.utterance_mvn import UtteranceMVN
from funasr.models.e2e_diar_sond import DiarSondModel
from funasr.models.e2e_diar_eend_ola import DiarEENDOLAModel
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.encoder.conformer_encoder import ConformerEncoder
from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
@ -103,6 +104,7 @@ model_choices = ClassChoices(
"model",
classes=dict(
sond=DiarSondModel,
eend_ola=DiarEENDOLAModel,
),
type_check=AbsESPnetModel,
default="sond",