mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
update
This commit is contained in:
parent
d6cc6896e4
commit
3e77fd4430
@ -210,7 +210,6 @@ def build_asr_model(args):
|
||||
|
||||
# frontend
|
||||
if args.input_size is None:
|
||||
# Extract features in the model
|
||||
frontend_class = frontend_choices.get_class(args.frontend)
|
||||
if args.frontend == 'wav_frontend':
|
||||
frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
|
||||
@ -218,7 +217,6 @@ def build_asr_model(args):
|
||||
frontend = frontend_class(**args.frontend_conf)
|
||||
input_size = frontend.output_size()
|
||||
else:
|
||||
# Give features from data-loader
|
||||
args.frontend = None
|
||||
args.frontend_conf = {}
|
||||
frontend = None
|
||||
|
||||
@ -1,9 +1,12 @@
|
||||
from funasr.utils.build_asr_model import build_asr_model
|
||||
from funasr.utils.build_pretrain_model import build_pretrain_model
|
||||
|
||||
|
||||
def build_model(args):
|
||||
if args.task_name == "asr":
|
||||
model = build_asr_model(args)
|
||||
elif args.task_name == "pretrain":
|
||||
model = build_pretrain_model(args)
|
||||
else:
|
||||
raise NotImplementedError("Not supported task: {}".format(args.task_name))
|
||||
|
||||
|
||||
105
funasr/utils/build_pretrain_model.py
Normal file
105
funasr/utils/build_pretrain_model.py
Normal file
@ -0,0 +1,105 @@
|
||||
from funasr.layers.global_mvn import GlobalMVN
|
||||
from funasr.layers.utterance_mvn import UtteranceMVN
|
||||
from funasr.models.data2vec import Data2VecPretrainModel
|
||||
from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
|
||||
from funasr.models.frontend.default import DefaultFrontend
|
||||
from funasr.models.frontend.windowing import SlidingWindow
|
||||
from funasr.models.specaug.specaug import SpecAug
|
||||
from funasr.torch_utils.initialize import initialize
|
||||
from funasr.train.class_choices import ClassChoices
|
||||
|
||||
frontend_choices = ClassChoices(
|
||||
name="frontend",
|
||||
classes=dict(default=DefaultFrontend, sliding_window=SlidingWindow),
|
||||
default="default",
|
||||
)
|
||||
specaug_choices = ClassChoices(
|
||||
name="specaug",
|
||||
classes=dict(specaug=SpecAug),
|
||||
default=None,
|
||||
optional=True,
|
||||
)
|
||||
normalize_choices = ClassChoices(
|
||||
"normalize",
|
||||
classes=dict(
|
||||
global_mvn=GlobalMVN,
|
||||
utterance_mvn=UtteranceMVN,
|
||||
),
|
||||
default=None,
|
||||
optional=True,
|
||||
)
|
||||
encoder_choices = ClassChoices(
|
||||
"encoder",
|
||||
classes=dict(
|
||||
data2vec_encoder=Data2VecEncoder,
|
||||
),
|
||||
default="data2vec_encoder",
|
||||
)
|
||||
model_choices = ClassChoices(
|
||||
"model",
|
||||
classes=dict(
|
||||
data2vec=Data2VecPretrainModel,
|
||||
),
|
||||
default="data2vec",
|
||||
)
|
||||
class_choices_list = [
|
||||
# --frontend and --frontend_conf
|
||||
frontend_choices,
|
||||
# --specaug and --specaug_conf
|
||||
specaug_choices,
|
||||
# --normalize and --normalize_conf
|
||||
normalize_choices,
|
||||
# --encoder and --encoder_conf
|
||||
encoder_choices,
|
||||
# --model and --model_conf
|
||||
model_choices,
|
||||
]
|
||||
|
||||
|
||||
def build_pretrain_model(args):
|
||||
if args.model_name == "data2vec":
|
||||
# frontend
|
||||
if args.input_size is None:
|
||||
frontend_class = frontend_choices.get_class(args.frontend)
|
||||
frontend = frontend_class(**args.frontend_conf)
|
||||
input_size = frontend.output_size()
|
||||
else:
|
||||
args.frontend = None
|
||||
args.frontend_conf = {}
|
||||
frontend = None
|
||||
input_size = args.input_size
|
||||
|
||||
# data augmentation for spectrogram
|
||||
if args.specaug is not None:
|
||||
specaug_class = specaug_choices.get_class(args.specaug)
|
||||
specaug = specaug_class(**args.specaug_conf)
|
||||
else:
|
||||
specaug = None
|
||||
|
||||
# normalization layer
|
||||
if args.normalize is not None:
|
||||
normalize_class = normalize_choices.get_class(args.normalize)
|
||||
normalize = normalize_class(**args.normalize_conf)
|
||||
else:
|
||||
normalize = None
|
||||
|
||||
# encoder
|
||||
encoder_class = encoder_choices.get_class(args.encoder)
|
||||
encoder = encoder_class(
|
||||
input_size=input_size,
|
||||
**args.encoder_conf,
|
||||
)
|
||||
|
||||
model_class = model_choices.get_class("data2vec")
|
||||
model = model_class(
|
||||
frontend=frontend,
|
||||
specaug=specaug,
|
||||
normalize=normalize,
|
||||
encoder=encoder,
|
||||
)
|
||||
|
||||
# 7. Initialize
|
||||
if args.init is not None:
|
||||
initialize(model, args.init)
|
||||
|
||||
return model
|
||||
Loading…
Reference in New Issue
Block a user