mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
update data2vec pretrain
This commit is contained in:
parent
933d5afc02
commit
7bea618623
45
funasr/bin/data2vec_train.py
Executable file
45
funasr/bin/data2vec_train.py
Executable file
@ -0,0 +1,45 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
from funasr.tasks.data2vec import Data2VecTask
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args():
|
||||||
|
parser = Data2VecTask.get_parser()
|
||||||
|
parser.add_argument(
|
||||||
|
"--gpu_id",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="local gpu id.",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
|
def main(args=None, cmd=None):
|
||||||
|
# for data2vec Training
|
||||||
|
Data2VecTask.main(args=args, cmd=cmd)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
args = parse_args()
|
||||||
|
|
||||||
|
# setup local gpu_id
|
||||||
|
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
|
||||||
|
|
||||||
|
# DDP settings
|
||||||
|
if args.ngpu > 1:
|
||||||
|
args.distributed = True
|
||||||
|
else:
|
||||||
|
args.distributed = False
|
||||||
|
assert args.num_worker_count == 1
|
||||||
|
|
||||||
|
# re-compute batch size: when dataset type is small
|
||||||
|
if args.dataset_type == "small":
|
||||||
|
if args.batch_size is not None:
|
||||||
|
args.batch_size = args.batch_size * args.ngpu
|
||||||
|
if args.batch_bins is not None:
|
||||||
|
args.batch_bins = args.batch_bins * args.ngpu
|
||||||
|
|
||||||
|
main(args=args)
|
||||||
160
funasr/models/data2vec.py
Normal file
160
funasr/models/data2vec.py
Normal file
@ -0,0 +1,160 @@
|
|||||||
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the MIT license found in the
|
||||||
|
# LICENSE file in the root directory of this source tree.
|
||||||
|
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from distutils.version import LooseVersion
|
||||||
|
from typing import Dict
|
||||||
|
from typing import Optional
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from typeguard import check_argument_types
|
||||||
|
|
||||||
|
from funasr.layers.abs_normalize import AbsNormalize
|
||||||
|
from funasr.models.encoder.abs_encoder import AbsEncoder
|
||||||
|
from funasr.models.frontend.abs_frontend import AbsFrontend
|
||||||
|
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
|
||||||
|
from funasr.models.specaug.abs_specaug import AbsSpecAug
|
||||||
|
from funasr.torch_utils.device_funcs import force_gatherable
|
||||||
|
from funasr.train.abs_espnet_model import AbsESPnetModel
|
||||||
|
|
||||||
|
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
|
||||||
|
from torch.cuda.amp import autocast
|
||||||
|
else:
|
||||||
|
# Nothing to do if torch<1.6.0
|
||||||
|
@contextmanager
|
||||||
|
def autocast(enabled=True):
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
class Data2VecPretrainModel(AbsESPnetModel):
|
||||||
|
"""Data2Vec Pretrain model"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
frontend: Optional[AbsFrontend],
|
||||||
|
specaug: Optional[AbsSpecAug],
|
||||||
|
normalize: Optional[AbsNormalize],
|
||||||
|
preencoder: Optional[AbsPreEncoder],
|
||||||
|
encoder: AbsEncoder,
|
||||||
|
):
|
||||||
|
assert check_argument_types()
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.frontend = frontend
|
||||||
|
self.specaug = specaug
|
||||||
|
self.normalize = normalize
|
||||||
|
self.preencoder = preencoder
|
||||||
|
self.encoder = encoder
|
||||||
|
self.num_updates = 0
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
speech: torch.Tensor,
|
||||||
|
speech_lengths: torch.Tensor,
|
||||||
|
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
|
||||||
|
"""Frontend + Encoder + Calc loss
|
||||||
|
|
||||||
|
Args:
|
||||||
|
speech: (Batch, Length, ...)
|
||||||
|
speech_lengths: (Batch, )
|
||||||
|
"""
|
||||||
|
# Check that batch_size is unified
|
||||||
|
assert (
|
||||||
|
speech.shape[0]
|
||||||
|
== speech_lengths.shape[0]
|
||||||
|
), (speech.shape, speech_lengths.shape)
|
||||||
|
|
||||||
|
self.encoder.set_num_updates(self.num_updates)
|
||||||
|
|
||||||
|
# 1. Encoder
|
||||||
|
encoder_out = self.encode(speech, speech_lengths)
|
||||||
|
|
||||||
|
losses = encoder_out["losses"]
|
||||||
|
loss = sum(losses.values())
|
||||||
|
sample_size = encoder_out["sample_size"]
|
||||||
|
loss = loss.sum() / sample_size
|
||||||
|
|
||||||
|
target_var = float(encoder_out["target_var"])
|
||||||
|
pred_var = float(encoder_out["pred_var"])
|
||||||
|
ema_decay = float(encoder_out["ema_decay"])
|
||||||
|
|
||||||
|
stats = dict(
|
||||||
|
loss=torch.clone(loss.detach()),
|
||||||
|
target_var=target_var,
|
||||||
|
pred_var=pred_var,
|
||||||
|
ema_decay=ema_decay,
|
||||||
|
)
|
||||||
|
|
||||||
|
loss, stats, weight = force_gatherable((loss, stats, sample_size), loss.device)
|
||||||
|
return loss, stats, weight
|
||||||
|
|
||||||
|
def collect_feats(
|
||||||
|
self,
|
||||||
|
speech: torch.Tensor,
|
||||||
|
speech_lengths: torch.Tensor
|
||||||
|
) -> Dict[str, torch.Tensor]:
|
||||||
|
feats, feats_lengths = self._extract_feats(speech, speech_lengths)
|
||||||
|
return {"feats": feats, "feats_lengths": feats_lengths}
|
||||||
|
|
||||||
|
def encode(
|
||||||
|
self,
|
||||||
|
speech: torch.Tensor,
|
||||||
|
speech_lengths: torch.Tensor,
|
||||||
|
):
|
||||||
|
"""Frontend + Encoder.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
speech: (Batch, Length, ...)
|
||||||
|
speech_lengths: (Batch, )
|
||||||
|
"""
|
||||||
|
with autocast(False):
|
||||||
|
# 1. Extract feats
|
||||||
|
feats, feats_lengths = self._extract_feats(speech, speech_lengths)
|
||||||
|
|
||||||
|
# 2. Data augmentation
|
||||||
|
if self.specaug is not None and self.training:
|
||||||
|
feats, feats_lengths = self.specaug(feats, feats_lengths)
|
||||||
|
|
||||||
|
# 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
|
||||||
|
if self.normalize is not None:
|
||||||
|
feats, feats_lengths = self.normalize(feats, feats_lengths)
|
||||||
|
|
||||||
|
# Pre-encoder, e.g. used for raw input data
|
||||||
|
if self.preencoder is not None:
|
||||||
|
feats, feats_lengths = self.preencoder(feats, feats_lengths)
|
||||||
|
|
||||||
|
# 4. Forward encoder
|
||||||
|
if min(speech_lengths) == max(speech_lengths): # for clipping, set speech_lengths as None
|
||||||
|
speech_lengths = None
|
||||||
|
encoder_out = self.encoder(feats, speech_lengths, mask=True, features_only=False)
|
||||||
|
|
||||||
|
return encoder_out
|
||||||
|
|
||||||
|
def _extract_feats(
|
||||||
|
self, speech: torch.Tensor, speech_lengths: torch.Tensor
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
assert speech_lengths.dim() == 1, speech_lengths.shape
|
||||||
|
|
||||||
|
# for data-parallel
|
||||||
|
speech = speech[:, : speech_lengths.max()]
|
||||||
|
|
||||||
|
if self.frontend is not None:
|
||||||
|
# Frontend
|
||||||
|
# e.g. STFT and Feature extract
|
||||||
|
# data_loader may send time-domain signal in this case
|
||||||
|
# speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
|
||||||
|
feats, feats_lengths = self.frontend(speech, speech_lengths)
|
||||||
|
else:
|
||||||
|
# No frontend and no feature extract
|
||||||
|
feats, feats_lengths = speech, speech_lengths
|
||||||
|
return feats, feats_lengths
|
||||||
|
|
||||||
|
def set_num_updates(self, num_updates):
|
||||||
|
self.num_updates = num_updates
|
||||||
|
|
||||||
|
def get_num_updates(self):
|
||||||
|
return self.num_updates
|
||||||
376
funasr/tasks/data2vec.py
Normal file
376
funasr/tasks/data2vec.py
Normal file
@ -0,0 +1,376 @@
|
|||||||
|
import argparse
|
||||||
|
from typing import Callable
|
||||||
|
from typing import Collection
|
||||||
|
from typing import Dict
|
||||||
|
from typing import List
|
||||||
|
from typing import Optional
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from typeguard import check_argument_types
|
||||||
|
from typeguard import check_return_type
|
||||||
|
|
||||||
|
from funasr.datasets.collate_fn import CommonCollateFn
|
||||||
|
from funasr.datasets.preprocessor import CommonPreprocessor
|
||||||
|
from funasr.layers.abs_normalize import AbsNormalize
|
||||||
|
from funasr.layers.global_mvn import GlobalMVN
|
||||||
|
from funasr.layers.utterance_mvn import UtteranceMVN
|
||||||
|
from funasr.models.data2vec import Data2VecPretrainModel
|
||||||
|
from funasr.models.encoder.abs_encoder import AbsEncoder
|
||||||
|
from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
|
||||||
|
from funasr.models.frontend.abs_frontend import AbsFrontend
|
||||||
|
from funasr.models.frontend.default import DefaultFrontend
|
||||||
|
from funasr.models.frontend.windowing import SlidingWindow
|
||||||
|
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
|
||||||
|
from funasr.models.preencoder.sinc import LightweightSincConvs
|
||||||
|
from funasr.models.specaug.abs_specaug import AbsSpecAug
|
||||||
|
from funasr.models.specaug.specaug import SpecAug
|
||||||
|
from funasr.tasks.abs_task import AbsTask
|
||||||
|
from funasr.text.phoneme_tokenizer import g2p_choices
|
||||||
|
from funasr.torch_utils.initialize import initialize
|
||||||
|
from funasr.train.class_choices import ClassChoices
|
||||||
|
from funasr.train.trainer import Trainer
|
||||||
|
from funasr.utils.types import float_or_none
|
||||||
|
from funasr.utils.types import int_or_none
|
||||||
|
from funasr.utils.types import str2bool
|
||||||
|
from funasr.utils.types import str_or_none
|
||||||
|
|
||||||
|
frontend_choices = ClassChoices(
|
||||||
|
name="frontend",
|
||||||
|
classes=dict(default=DefaultFrontend, sliding_window=SlidingWindow),
|
||||||
|
type_check=AbsFrontend,
|
||||||
|
default="default",
|
||||||
|
)
|
||||||
|
specaug_choices = ClassChoices(
|
||||||
|
name="specaug",
|
||||||
|
classes=dict(specaug=SpecAug),
|
||||||
|
type_check=AbsSpecAug,
|
||||||
|
default=None,
|
||||||
|
optional=True,
|
||||||
|
)
|
||||||
|
normalize_choices = ClassChoices(
|
||||||
|
"normalize",
|
||||||
|
classes=dict(
|
||||||
|
global_mvn=GlobalMVN,
|
||||||
|
utterance_mvn=UtteranceMVN,
|
||||||
|
),
|
||||||
|
type_check=AbsNormalize,
|
||||||
|
default=None,
|
||||||
|
optional=True,
|
||||||
|
)
|
||||||
|
preencoder_choices = ClassChoices(
|
||||||
|
name="preencoder",
|
||||||
|
classes=dict(
|
||||||
|
sinc=LightweightSincConvs,
|
||||||
|
),
|
||||||
|
type_check=AbsPreEncoder,
|
||||||
|
default=None,
|
||||||
|
optional=True,
|
||||||
|
)
|
||||||
|
encoder_choices = ClassChoices(
|
||||||
|
"encoder",
|
||||||
|
classes=dict(
|
||||||
|
data2vec_encoder=Data2VecEncoder,
|
||||||
|
),
|
||||||
|
type_check=AbsEncoder,
|
||||||
|
default="data2vec_encoder",
|
||||||
|
)
|
||||||
|
model_choices = ClassChoices(
|
||||||
|
"model",
|
||||||
|
classes=dict(
|
||||||
|
data2vec=Data2VecPretrainModel,
|
||||||
|
),
|
||||||
|
default="data2vec",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Data2VecTask(AbsTask):
|
||||||
|
# If you need more than one optimizers, change this value
|
||||||
|
num_optimizers: int = 1
|
||||||
|
|
||||||
|
# Add variable objects configurations
|
||||||
|
class_choices_list = [
|
||||||
|
# --frontend and --frontend_conf
|
||||||
|
frontend_choices,
|
||||||
|
# --specaug and --specaug_conf
|
||||||
|
specaug_choices,
|
||||||
|
# --normalize and --normalize_conf
|
||||||
|
normalize_choices,
|
||||||
|
# --preencoder and --preencoder_conf
|
||||||
|
preencoder_choices,
|
||||||
|
# --encoder and --encoder_conf
|
||||||
|
encoder_choices,
|
||||||
|
# --model and --model_conf
|
||||||
|
model_choices,
|
||||||
|
]
|
||||||
|
|
||||||
|
# If you need to modify train() or eval() procedures, change Trainer class here
|
||||||
|
trainer = Trainer
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def add_task_arguments(cls, parser: argparse.ArgumentParser):
|
||||||
|
group = parser.add_argument_group(description="Task related")
|
||||||
|
|
||||||
|
# NOTE(kamo): add_arguments(..., required=True) can't be used
|
||||||
|
# to provide --print_config mode. Instead of it, do as
|
||||||
|
group.add_argument(
|
||||||
|
"--token_list",
|
||||||
|
type=str_or_none,
|
||||||
|
default=None,
|
||||||
|
help="A text mapping int-id to token",
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--init",
|
||||||
|
type=lambda x: str_or_none(x.lower()),
|
||||||
|
default=None,
|
||||||
|
help="The initialization method",
|
||||||
|
choices=[
|
||||||
|
"chainer",
|
||||||
|
"xavier_uniform",
|
||||||
|
"xavier_normal",
|
||||||
|
"kaiming_uniform",
|
||||||
|
"kaiming_normal",
|
||||||
|
None,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
group.add_argument(
|
||||||
|
"--input_size",
|
||||||
|
type=int_or_none,
|
||||||
|
default=None,
|
||||||
|
help="The number of input dimension of the feature",
|
||||||
|
)
|
||||||
|
|
||||||
|
group = parser.add_argument_group(description="Preprocess related")
|
||||||
|
group.add_argument(
|
||||||
|
"--use_preprocessor",
|
||||||
|
type=str2bool,
|
||||||
|
default=True,
|
||||||
|
help="Apply preprocessing to data or not",
|
||||||
|
)
|
||||||
|
group.add_argument(
|
||||||
|
"--token_type",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
choices=["bpe", "char", "word", "phn"],
|
||||||
|
help="The text will be tokenized " "in the specified level token",
|
||||||
|
)
|
||||||
|
|
||||||
|
group.add_argument(
|
||||||
|
"--feats_type",
|
||||||
|
type=str,
|
||||||
|
default='fbank',
|
||||||
|
help="feats type, e.g. fbank, wav, ark_wav(needed to be scale normalization)",
|
||||||
|
)
|
||||||
|
|
||||||
|
group.add_argument(
|
||||||
|
"--bpemodel",
|
||||||
|
type=str_or_none,
|
||||||
|
default=None,
|
||||||
|
help="The model file of sentencepiece",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--non_linguistic_symbols",
|
||||||
|
type=str_or_none,
|
||||||
|
help="non_linguistic_symbols file path",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--cleaner",
|
||||||
|
type=str_or_none,
|
||||||
|
choices=[None, "tacotron", "jaconv", "vietnamese"],
|
||||||
|
default=None,
|
||||||
|
help="Apply text cleaning",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--g2p",
|
||||||
|
type=str_or_none,
|
||||||
|
choices=g2p_choices,
|
||||||
|
default=None,
|
||||||
|
help="Specify g2p method if --token_type=phn",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--speech_volume_normalize",
|
||||||
|
type=float_or_none,
|
||||||
|
default=None,
|
||||||
|
help="Scale the maximum amplitude to the given value.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--rir_scp",
|
||||||
|
type=str_or_none,
|
||||||
|
default=None,
|
||||||
|
help="The file path of rir scp file.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--rir_apply_prob",
|
||||||
|
type=float,
|
||||||
|
default=1.0,
|
||||||
|
help="THe probability for applying RIR convolution.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--noise_scp",
|
||||||
|
type=str_or_none,
|
||||||
|
default=None,
|
||||||
|
help="The file path of noise scp file.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--noise_apply_prob",
|
||||||
|
type=float,
|
||||||
|
default=1.0,
|
||||||
|
help="The probability applying Noise adding.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--noise_db_range",
|
||||||
|
type=str,
|
||||||
|
default="13_15",
|
||||||
|
help="The range of noise decibel level.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--pred_masked_weight",
|
||||||
|
type=float,
|
||||||
|
default=1.0,
|
||||||
|
help="weight for predictive loss for masked frames",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--pred_nomask_weight",
|
||||||
|
type=float,
|
||||||
|
default=0.0,
|
||||||
|
help="weight for predictive loss for unmasked frames",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--loss_weights",
|
||||||
|
type=float,
|
||||||
|
default=0.0,
|
||||||
|
help="weights for additional loss terms (not first one)",
|
||||||
|
)
|
||||||
|
|
||||||
|
for class_choices in cls.class_choices_list:
|
||||||
|
# Append --<name> and --<name>_conf.
|
||||||
|
# e.g. --encoder and --encoder_conf
|
||||||
|
class_choices.add_arguments(group)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_collate_fn(
|
||||||
|
cls, args: argparse.Namespace, train: bool
|
||||||
|
) -> Callable[
|
||||||
|
[Collection[Tuple[str, Dict[str, np.ndarray]]]],
|
||||||
|
Tuple[List[str], Dict[str, torch.Tensor]],
|
||||||
|
]:
|
||||||
|
assert check_argument_types()
|
||||||
|
return CommonCollateFn(clipping=True)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_preprocess_fn(
|
||||||
|
cls, args: argparse.Namespace, train: bool
|
||||||
|
) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
|
||||||
|
assert check_argument_types()
|
||||||
|
if args.use_preprocessor:
|
||||||
|
retval = CommonPreprocessor(
|
||||||
|
train=train,
|
||||||
|
bpemodel=args.bpemodel,
|
||||||
|
non_linguistic_symbols=args.non_linguistic_symbols,
|
||||||
|
text_cleaner=args.cleaner,
|
||||||
|
g2p_type=args.g2p,
|
||||||
|
# NOTE(kamo): Check attribute existence for backward compatibility
|
||||||
|
rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None,
|
||||||
|
rir_apply_prob=args.rir_apply_prob
|
||||||
|
if hasattr(args, "rir_apply_prob")
|
||||||
|
else 1.0,
|
||||||
|
noise_scp=args.noise_scp if hasattr(args, "noise_scp") else None,
|
||||||
|
noise_apply_prob=args.noise_apply_prob
|
||||||
|
if hasattr(args, "noise_apply_prob")
|
||||||
|
else 1.0,
|
||||||
|
noise_db_range=args.noise_db_range
|
||||||
|
if hasattr(args, "noise_db_range")
|
||||||
|
else "13_15",
|
||||||
|
speech_volume_normalize=args.speech_volume_normalize
|
||||||
|
if hasattr(args, "rir_scp")
|
||||||
|
else None,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
retval = None
|
||||||
|
assert check_return_type(retval)
|
||||||
|
return retval
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def required_data_names(
|
||||||
|
cls, train: bool = True, inference: bool = False
|
||||||
|
) -> Tuple[str, ...]:
|
||||||
|
# for pre-training
|
||||||
|
retval = ("speech",)
|
||||||
|
return retval
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def optional_data_names(
|
||||||
|
cls, train: bool = True, inference: bool = False
|
||||||
|
) -> Tuple[str, ...]:
|
||||||
|
retval = ()
|
||||||
|
assert check_return_type(retval)
|
||||||
|
return retval
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_model(cls, args: argparse.Namespace):
|
||||||
|
assert check_argument_types()
|
||||||
|
|
||||||
|
# 1. frontend
|
||||||
|
if args.input_size is None:
|
||||||
|
# Extract features in the model
|
||||||
|
frontend_class = frontend_choices.get_class(args.frontend)
|
||||||
|
frontend = frontend_class(**args.frontend_conf)
|
||||||
|
input_size = frontend.output_size()
|
||||||
|
else:
|
||||||
|
# Give features from data-loader
|
||||||
|
args.frontend = None
|
||||||
|
args.frontend_conf = {}
|
||||||
|
frontend = None
|
||||||
|
input_size = args.input_size
|
||||||
|
|
||||||
|
# 2. Data augmentation for spectrogram
|
||||||
|
if args.specaug is not None:
|
||||||
|
specaug_class = specaug_choices.get_class(args.specaug)
|
||||||
|
specaug = specaug_class(**args.specaug_conf)
|
||||||
|
else:
|
||||||
|
specaug = None
|
||||||
|
|
||||||
|
# 3. Normalization layer
|
||||||
|
if args.normalize is not None:
|
||||||
|
normalize_class = normalize_choices.get_class(args.normalize)
|
||||||
|
normalize = normalize_class(**args.normalize_conf)
|
||||||
|
else:
|
||||||
|
normalize = None
|
||||||
|
|
||||||
|
# 4. Pre-encoder input block
|
||||||
|
# NOTE(kan-bayashi): Use getattr to keep the compatibility
|
||||||
|
if getattr(args, "preencoder", None) is not None:
|
||||||
|
preencoder_class = preencoder_choices.get_class(args.preencoder)
|
||||||
|
preencoder = preencoder_class(**args.preencoder_conf)
|
||||||
|
input_size = preencoder.output_size()
|
||||||
|
else:
|
||||||
|
preencoder = None
|
||||||
|
|
||||||
|
# 5. Encoder
|
||||||
|
encoder_class = encoder_choices.get_class(args.encoder)
|
||||||
|
encoder = encoder_class(
|
||||||
|
input_size=input_size,
|
||||||
|
**args.encoder_conf,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 6. Build model
|
||||||
|
try:
|
||||||
|
model_class = model_choices.get_class(args.model)
|
||||||
|
except AttributeError:
|
||||||
|
model_class = model_choices.get_class("data2vec")
|
||||||
|
model = model_class(
|
||||||
|
frontend=frontend,
|
||||||
|
specaug=specaug,
|
||||||
|
normalize=normalize,
|
||||||
|
preencoder=preencoder,
|
||||||
|
encoder=encoder,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 7. Initialize
|
||||||
|
if args.init is not None:
|
||||||
|
initialize(model, args.init)
|
||||||
|
|
||||||
|
assert check_return_type(model)
|
||||||
|
return model
|
||||||
Loading…
Reference in New Issue
Block a user