FunASR/funasr/build_utils/build_pretrain_model.py
2023-07-29 15:18:02 +08:00

113 lines
3.0 KiB
Python

from funasr.layers.global_mvn import GlobalMVN
from funasr.layers.utterance_mvn import UtteranceMVN
from funasr.models.data2vec import Data2VecPretrainModel
from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
from funasr.models.frontend.default import DefaultFrontend
from funasr.models.frontend.windowing import SlidingWindow
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.models.specaug.specaug import SpecAug
from funasr.torch_utils.initialize import initialize
from funasr.train.class_choices import ClassChoices
# Frontend classes selectable by name through args.frontend;
# "default" is passed to ClassChoices as the default choice name.
frontend_choices = ClassChoices(
    name="frontend",
    classes={
        "default": DefaultFrontend,
        "sliding_window": SlidingWindow,
        "wav_frontend": WavFrontend,
    },
    default="default",
)
# Optional spectrogram augmentation, selected by args.specaug (no default;
# build_pretrain_model skips it when args.specaug is None).
specaug_choices = ClassChoices(
    name="specaug",
    classes={"specaug": SpecAug},
    default=None,
    optional=True,
)
# Optional feature normalization layer (GlobalMVN or UtteranceMVN),
# selected by args.normalize; there is no default choice.
normalize_choices = ClassChoices(
    name="normalize",
    classes={
        "global_mvn": GlobalMVN,
        "utterance_mvn": UtteranceMVN,
    },
    default=None,
    optional=True,
)
# Encoder classes selectable by args.encoder; only the data2vec encoder
# is registered here.
encoder_choices = ClassChoices(
    name="encoder",
    classes={"data2vec_encoder": Data2VecEncoder},
    default="data2vec_encoder",
)
# Top-level model classes; build_pretrain_model currently accepts only
# args.model == "data2vec".
model_choices = ClassChoices(
    name="model",
    classes={"data2vec": Data2VecPretrainModel},
    default="data2vec",
)
# Every component whose class is chosen by name. Each entry corresponds to a
# --<name> option plus a --<name>_conf option (presumably registered with the
# argument parser elsewhere).
class_choices_list = [
    frontend_choices,  # --frontend, --frontend_conf
    specaug_choices,  # --specaug, --specaug_conf
    normalize_choices,  # --normalize, --normalize_conf
    encoder_choices,  # --encoder, --encoder_conf
    model_choices,  # --model, --model_conf
]
def _instantiate_choice(choices, name, conf, **extra_kwargs):
    """Look up ``name`` in a ``ClassChoices`` registry and construct it.

    Args:
        choices: registry whose ``get_class(name)`` returns the class to build.
        name: selected choice name (e.g. ``args.specaug``).
        conf: constructor keyword arguments taken from the config. ``None`` is
            treated as "no arguments": an empty ``*_conf:`` key in a YAML config
            loads as ``None``, and ``**None`` would raise ``TypeError``.
        **extra_kwargs: additional keyword arguments supplied by the caller
            (e.g. ``input_size`` for the encoder). A key present in both
            ``extra_kwargs`` and ``conf`` still raises ``TypeError``.

    Returns:
        The constructed component instance.
    """
    component_class = choices.get_class(name)
    return component_class(**extra_kwargs, **(conf or {}))


def build_pretrain_model(args):
    """Build the pretraining model described by ``args``.

    Reads ``args.input_size``, ``args.frontend``/``args.frontend_conf``,
    ``args.specaug``/``args.specaug_conf``, ``args.normalize``/
    ``args.normalize_conf``, ``args.encoder``/``args.encoder_conf``,
    ``args.model`` and ``args.init``. Any ``*_conf`` value that is ``None`` is
    treated as an empty mapping.

    Side effect: when ``args.input_size`` is set, ``args.frontend`` is reset to
    ``None`` and ``args.frontend_conf`` to ``{}``.

    Returns:
        The model instance built from ``model_choices`` for ``"data2vec"``.

    Raises:
        NotImplementedError: if ``args.model`` is not ``"data2vec"``.
    """
    # Frontend: with no explicit input size, build the frontend and use its
    # output dimension as the encoder's input size.
    if args.input_size is None:
        frontend = _instantiate_choice(
            frontend_choices, args.frontend, args.frontend_conf
        )
        input_size = frontend.output_size()
    else:
        # No frontend module is built; the encoder is sized from
        # args.input_size and the frontend options on args are cleared.
        args.frontend = None
        args.frontend_conf = {}
        frontend = None
        input_size = args.input_size

    # Optional spectrogram augmentation.
    specaug = None
    if args.specaug is not None:
        specaug = _instantiate_choice(specaug_choices, args.specaug, args.specaug_conf)

    # Optional normalization layer.
    normalize = None
    if args.normalize is not None:
        normalize = _instantiate_choice(
            normalize_choices, args.normalize, args.normalize_conf
        )

    # Encoder; input_size is supplied here, and every other constructor
    # argument comes from args.encoder_conf.
    encoder = _instantiate_choice(
        encoder_choices, args.encoder, args.encoder_conf, input_size=input_size
    )

    # Checked after the components are built, matching the original order.
    if args.model != "data2vec":
        raise NotImplementedError("Not supported model: {}".format(args.model))

    # NOTE(review): args.model_conf is not read by this function, so any
    # model_conf entries in the config are ignored here — confirm intended.
    model_class = model_choices.get_class("data2vec")
    model = model_class(
        frontend=frontend,
        specaug=specaug,
        normalize=normalize,
        encoder=encoder,
    )

    # Optional parameter initialization scheme.
    if args.init is not None:
        initialize(model, args.init)
    return model