# Mirror of https://github.com/modelscope/FunASR
# (synced 2025-09-15 14:48:36 +08:00); 108 lines, 2.9 KiB, Python.
from funasr.layers.global_mvn import GlobalMVN
|
|
from funasr.layers.utterance_mvn import UtteranceMVN
|
|
from funasr.models.data2vec import Data2VecPretrainModel
|
|
from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
|
|
from funasr.models.frontend.default import DefaultFrontend
|
|
from funasr.models.frontend.windowing import SlidingWindow
|
|
from funasr.models.specaug.specaug import SpecAug
|
|
from funasr.torch_utils.initialize import initialize
|
|
from funasr.train.class_choices import ClassChoices
|
|
|
|
# Registry of feature-extraction frontends, selected with --frontend and
# configured with --frontend_conf. The choice named "default" is registered
# as the default.
frontend_choices = ClassChoices(
    name="frontend",
    classes={
        "default": DefaultFrontend,
        "sliding_window": SlidingWindow,
    },
    default="default",
)
|
|
# Registry of spectrogram augmentations (--specaug / --specaug_conf).
# No default is registered and optional=True, so this component may be
# left unset (build_pretrain_model then uses no augmentation).
specaug_choices = ClassChoices(
    name="specaug",
    classes={"specaug": SpecAug},
    default=None,
    optional=True,
)
|
|
# Registry of feature-normalization layers (--normalize / --normalize_conf).
# Either global or per-utterance mean/variance normalization can be chosen;
# no default is registered and optional=True, so it may be left unset.
# The registry name is passed by keyword, matching frontend_choices and
# specaug_choices.
normalize_choices = ClassChoices(
    name="normalize",
    classes={
        "global_mvn": GlobalMVN,
        "utterance_mvn": UtteranceMVN,
    },
    default=None,
    optional=True,
)
|
|
# Registry of encoders (--encoder / --encoder_conf). Only the data2vec
# encoder is available, and it is registered as the default.
# The registry name is passed by keyword, matching frontend_choices and
# specaug_choices.
encoder_choices = ClassChoices(
    name="encoder",
    classes={
        "data2vec_encoder": Data2VecEncoder,
    },
    default="data2vec_encoder",
)
|
|
# Registry of top-level pre-training model wrappers (--model / --model_conf).
# Only the data2vec pre-training model is available, and it is the default.
# The registry name is passed by keyword, matching frontend_choices and
# specaug_choices.
model_choices = ClassChoices(
    name="model",
    classes={
        "data2vec": Data2VecPretrainModel,
    },
    default="data2vec",
)
|
|
# All component registries. Each one contributes a "--<name>" option and a
# matching "--<name>_conf" option. Entries follow the order in which
# build_pretrain_model constructs the components.
class_choices_list = [
    frontend_choices,  # --frontend / --frontend_conf
    specaug_choices,  # --specaug / --specaug_conf
    normalize_choices,  # --normalize / --normalize_conf
    encoder_choices,  # --encoder / --encoder_conf
    model_choices,  # --model / --model_conf
]
|
|
|
|
|
|
def build_pretrain_model(args):
    """Assemble a pre-training model from parsed command-line ``args``.

    Components are built in this order: frontend, spectrogram augmentation,
    normalization layer, encoder, and finally the model wrapper. If
    ``args.init`` is set, the finished model is passed to ``initialize``.

    Args:
        args: Namespace-like object. Attributes read here:

            - ``input_size``
            - ``frontend`` and ``frontend_conf``
            - ``specaug`` and ``specaug_conf``
            - ``normalize`` and ``normalize_conf``
            - ``encoder`` and ``encoder_conf``
            - ``model_name`` and ``init``

            Side effect: when ``input_size`` is not None, no frontend is
            built, and ``args.frontend`` / ``args.frontend_conf`` are
            overwritten with ``None`` / ``{}``.

    Returns:
        An instance of the class registered as ``"data2vec"`` in
        ``model_choices``.

    Raises:
        NotImplementedError: If ``args.model_name`` is anything other than
            ``"data2vec"``.
    """
    # frontend: built from the registry unless the caller supplies the
    # input size directly via ``args.input_size``.
    if args.input_size is None:
        frontend_class = frontend_choices.get_class(args.frontend)
        frontend = frontend_class(**args.frontend_conf)
        input_size = frontend.output_size()
    else:
        # No frontend is built on this path; reset the frontend args to match.
        args.frontend = None
        args.frontend_conf = {}
        frontend = None
        input_size = args.input_size

    # data augmentation for spectrogram (optional)
    if args.specaug is not None:
        specaug_class = specaug_choices.get_class(args.specaug)
        specaug = specaug_class(**args.specaug_conf)
    else:
        specaug = None

    # normalization layer (optional)
    if args.normalize is not None:
        normalize_class = normalize_choices.get_class(args.normalize)
        normalize = normalize_class(**args.normalize_conf)
    else:
        normalize = None

    # encoder: receives the frontend's output size (or args.input_size).
    encoder_class = encoder_choices.get_class(args.encoder)
    encoder = encoder_class(
        input_size=input_size,
        **args.encoder_conf,
    )

    if args.model_name == "data2vec":
        model_class = model_choices.get_class("data2vec")
        model = model_class(
            frontend=frontend,
            specaug=specaug,
            normalize=normalize,
            encoder=encoder,
        )
    else:
        # Report the attribute that was actually tested above (model_name).
        # Reading a different attribute could show the wrong value or raise
        # AttributeError instead of NotImplementedError.
        raise NotImplementedError(
            "Not supported model: {}".format(args.model_name)
        )

    # initialize
    if args.init is not None:
        initialize(model, args.init)

    return model
|