This commit is contained in:
speech_asr 2023-04-20 16:59:26 +08:00
parent 3e77fd4430
commit eac9f111b5
5 changed files with 131 additions and 37 deletions

View File

@@ -79,7 +79,62 @@ def build_args(args):
default=None,
help="The file path of noise scp file.",
)
elif args.task_name == "pretrain":
from funasr.utils.build_pretrain_model import class_choices_list
for class_choices in class_choices_list:
# Append --<name> and --<name>_conf.
# e.g. --encoder and --encoder_conf
class_choices.add_arguments(parser)
parser.add_argument(
"--init",
type=lambda x: str_or_none(x.lower()),
default=None,
help="The initialization method",
choices=[
"chainer",
"xavier_uniform",
"xavier_normal",
"kaiming_uniform",
"kaiming_normal",
None,
],
)
parser.add_argument(
"--input_size",
type=int_or_none,
default=None,
help="The number of input dimension of the feature",
)
parser.add_argument(
"--feats_type",
type=str,
default='fbank',
help="feats type, e.g. fbank, wav, ark_wav(needed to be scale normalization)",
)
parser.add_argument(
"--noise_db_range",
type=str,
default="13_15",
help="The range of noise decibel level.",
)
parser.add_argument(
"--pred_masked_weight",
type=float,
default=1.0,
help="weight for predictive loss for masked frames",
)
parser.add_argument(
"--pred_nomask_weight",
type=float,
default=0.0,
help="weight for predictive loss for unmasked frames",
)
parser.add_argument(
"--loss_weights",
type=float,
default=0.0,
help="weights for additional loss terms (not first one)",
)
else:
raise NotImplementedError("Not supported task: {}".format(args.task_name))

View File

@@ -345,6 +345,7 @@ def build_asr_model(args):
else:
raise NotImplementedError("Not supported model: {}".format(args.model))
# initialize
if args.init is not None:
initialize(model, args.init)

View File

@@ -0,0 +1,34 @@
from funasr.lm.abs_model import AbsLM
from funasr.lm.seq_rnn_lm import SequentialRNNLM
from funasr.lm.transformer_lm import TransformerLM
from funasr.torch_utils.initialize import initialize
from funasr.train.class_choices import ClassChoices
# Registry of selectable language-model implementations.
# ClassChoices exposes the --lm / --lm_conf CLI options and maps the
# chosen name to its class, type-checked against the AbsLM interface.
lm_choices = ClassChoices(
    "lm",
    classes=dict(
        seq_rnn=SequentialRNNLM,
        transformer=TransformerLM,
    ),
    type_check=AbsLM,
    default="seq_rnn",
)
# Iterated by the argument builder to auto-register each registry's
# CLI options (e.g. --lm and --lm_conf) on the parser.
class_choices_list = [
    # --lm and --lm_conf
    lm_choices
]
def build_pretrain_model(args):
    """Build a pretraining model from parsed command-line args.

    Reads ``args.token_list`` (a text file, one token per line) to
    determine the vocabulary size, then returns the constructed model.

    NOTE(review): as visible here this function cannot run to completion —
    ``model`` is returned but never assigned, and ``logging`` is used
    without an import in this file. The model-construction code appears
    to be missing from this revision; confirm against the full file.
    """
    # token_list: load the vocabulary file, one token per line.
    if args.token_list is not None:
        with open(args.token_list) as f:
            token_list = [line.rstrip() for line in f]
        # Re-store as a concrete list so later consumers need not re-read the file.
        args.token_list = list(token_list)
        vocab_size = len(token_list)
        # NOTE(review): `logging` is not imported in this file — NameError at runtime.
        logging.info(f"Vocabulary size: {vocab_size}")
    else:
        # No token list supplied; vocabulary size is unknown.
        vocab_size = None
    # NOTE(review): `model` is never defined in this function — NameError at runtime.
    return model

View File

@@ -7,6 +7,8 @@ def build_model(args):
model = build_asr_model(args)
elif args.task_name == "pretrain":
model = build_pretrain_model(args)
elif args.task_name == "lm":
model = build_lm_model(args)
else:
raise NotImplementedError("Not supported task: {}".format(args.task_name))

View File

@@ -57,39 +57,39 @@ class_choices_list = [
def build_pretrain_model(args):
# frontend
if args.input_size is None:
frontend_class = frontend_choices.get_class(args.frontend)
frontend = frontend_class(**args.frontend_conf)
input_size = frontend.output_size()
else:
args.frontend = None
args.frontend_conf = {}
frontend = None
input_size = args.input_size
# data augmentation for spectrogram
if args.specaug is not None:
specaug_class = specaug_choices.get_class(args.specaug)
specaug = specaug_class(**args.specaug_conf)
else:
specaug = None
# normalization layer
if args.normalize is not None:
normalize_class = normalize_choices.get_class(args.normalize)
normalize = normalize_class(**args.normalize_conf)
else:
normalize = None
# encoder
encoder_class = encoder_choices.get_class(args.encoder)
encoder = encoder_class(
input_size=input_size,
**args.encoder_conf,
)
if args.model_name == "data2vec":
# frontend
if args.input_size is None:
frontend_class = frontend_choices.get_class(args.frontend)
frontend = frontend_class(**args.frontend_conf)
input_size = frontend.output_size()
else:
args.frontend = None
args.frontend_conf = {}
frontend = None
input_size = args.input_size
# data augmentation for spectrogram
if args.specaug is not None:
specaug_class = specaug_choices.get_class(args.specaug)
specaug = specaug_class(**args.specaug_conf)
else:
specaug = None
# normalization layer
if args.normalize is not None:
normalize_class = normalize_choices.get_class(args.normalize)
normalize = normalize_class(**args.normalize_conf)
else:
normalize = None
# encoder
encoder_class = encoder_choices.get_class(args.encoder)
encoder = encoder_class(
input_size=input_size,
**args.encoder_conf,
)
model_class = model_choices.get_class("data2vec")
model = model_class(
frontend=frontend,
@@ -97,9 +97,11 @@ def build_pretrain_model(args):
normalize=normalize,
encoder=encoder,
)
else:
raise NotImplementedError("Not supported model: {}".format(args.model))
# 7. Initialize
if args.init is not None:
initialize(model, args.init)
# initialize
if args.init is not None:
initialize(model, args.init)
return model
return model