mirror of
https://github.com/modelscope/FunASR
synced 2025-09-15 14:48:36 +08:00
update
This commit is contained in:
parent
3e77fd4430
commit
eac9f111b5
@ -79,7 +79,62 @@ def build_args(args):
|
||||
default=None,
|
||||
help="The file path of noise scp file.",
|
||||
)
|
||||
|
||||
elif args.task_name == "pretrain":
|
||||
from funasr.utils.build_pretrain_model import class_choices_list
|
||||
for class_choices in class_choices_list:
|
||||
# Append --<name> and --<name>_conf.
|
||||
# e.g. --encoder and --encoder_conf
|
||||
class_choices.add_arguments(parser)
|
||||
parser.add_argument(
|
||||
"--init",
|
||||
type=lambda x: str_or_none(x.lower()),
|
||||
default=None,
|
||||
help="The initialization method",
|
||||
choices=[
|
||||
"chainer",
|
||||
"xavier_uniform",
|
||||
"xavier_normal",
|
||||
"kaiming_uniform",
|
||||
"kaiming_normal",
|
||||
None,
|
||||
],
|
||||
)
|
||||
parser.add_argument(
|
||||
"--input_size",
|
||||
type=int_or_none,
|
||||
default=None,
|
||||
help="The number of input dimension of the feature",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--feats_type",
|
||||
type=str,
|
||||
default='fbank',
|
||||
help="feats type, e.g. fbank, wav, ark_wav(needed to be scale normalization)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--noise_db_range",
|
||||
type=str,
|
||||
default="13_15",
|
||||
help="The range of noise decibel level.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pred_masked_weight",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="weight for predictive loss for masked frames",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pred_nomask_weight",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="weight for predictive loss for unmasked frames",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--loss_weights",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="weights for additional loss terms (not first one)",
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("Not supported task: {}".format(args.task_name))
|
||||
|
||||
|
||||
@ -345,6 +345,7 @@ def build_asr_model(args):
|
||||
else:
|
||||
raise NotImplementedError("Not supported model: {}".format(args.model))
|
||||
|
||||
# initialize
|
||||
if args.init is not None:
|
||||
initialize(model, args.init)
|
||||
|
||||
|
||||
34
funasr/utils/build_lm_model.py
Normal file
34
funasr/utils/build_lm_model.py
Normal file
@ -0,0 +1,34 @@
|
||||
from funasr.lm.abs_model import AbsLM
|
||||
from funasr.lm.seq_rnn_lm import SequentialRNNLM
|
||||
from funasr.lm.transformer_lm import TransformerLM
|
||||
from funasr.torch_utils.initialize import initialize
|
||||
from funasr.train.class_choices import ClassChoices
|
||||
|
||||
lm_choices = ClassChoices(
|
||||
"lm",
|
||||
classes=dict(
|
||||
seq_rnn=SequentialRNNLM,
|
||||
transformer=TransformerLM,
|
||||
),
|
||||
type_check=AbsLM,
|
||||
default="seq_rnn",
|
||||
)
|
||||
|
||||
class_choices_list = [
|
||||
# --lm and --lm_conf
|
||||
lm_choices
|
||||
]
|
||||
|
||||
|
||||
def build_pretrain_model(args):
|
||||
# token_list
|
||||
if args.token_list is not None:
|
||||
with open(args.token_list) as f:
|
||||
token_list = [line.rstrip() for line in f]
|
||||
args.token_list = list(token_list)
|
||||
vocab_size = len(token_list)
|
||||
logging.info(f"Vocabulary size: {vocab_size}")
|
||||
else:
|
||||
vocab_size = None
|
||||
|
||||
return model
|
||||
@ -7,6 +7,8 @@ def build_model(args):
|
||||
model = build_asr_model(args)
|
||||
elif args.task_name == "pretrain":
|
||||
model = build_pretrain_model(args)
|
||||
elif args.task_name == "lm":
|
||||
model = build_lm_model(args)
|
||||
else:
|
||||
raise NotImplementedError("Not supported task: {}".format(args.task_name))
|
||||
|
||||
|
||||
@ -57,39 +57,39 @@ class_choices_list = [
|
||||
|
||||
|
||||
def build_pretrain_model(args):
|
||||
# frontend
|
||||
if args.input_size is None:
|
||||
frontend_class = frontend_choices.get_class(args.frontend)
|
||||
frontend = frontend_class(**args.frontend_conf)
|
||||
input_size = frontend.output_size()
|
||||
else:
|
||||
args.frontend = None
|
||||
args.frontend_conf = {}
|
||||
frontend = None
|
||||
input_size = args.input_size
|
||||
|
||||
# data augmentation for spectrogram
|
||||
if args.specaug is not None:
|
||||
specaug_class = specaug_choices.get_class(args.specaug)
|
||||
specaug = specaug_class(**args.specaug_conf)
|
||||
else:
|
||||
specaug = None
|
||||
|
||||
# normalization layer
|
||||
if args.normalize is not None:
|
||||
normalize_class = normalize_choices.get_class(args.normalize)
|
||||
normalize = normalize_class(**args.normalize_conf)
|
||||
else:
|
||||
normalize = None
|
||||
|
||||
# encoder
|
||||
encoder_class = encoder_choices.get_class(args.encoder)
|
||||
encoder = encoder_class(
|
||||
input_size=input_size,
|
||||
**args.encoder_conf,
|
||||
)
|
||||
|
||||
if args.model_name == "data2vec":
|
||||
# frontend
|
||||
if args.input_size is None:
|
||||
frontend_class = frontend_choices.get_class(args.frontend)
|
||||
frontend = frontend_class(**args.frontend_conf)
|
||||
input_size = frontend.output_size()
|
||||
else:
|
||||
args.frontend = None
|
||||
args.frontend_conf = {}
|
||||
frontend = None
|
||||
input_size = args.input_size
|
||||
|
||||
# data augmentation for spectrogram
|
||||
if args.specaug is not None:
|
||||
specaug_class = specaug_choices.get_class(args.specaug)
|
||||
specaug = specaug_class(**args.specaug_conf)
|
||||
else:
|
||||
specaug = None
|
||||
|
||||
# normalization layer
|
||||
if args.normalize is not None:
|
||||
normalize_class = normalize_choices.get_class(args.normalize)
|
||||
normalize = normalize_class(**args.normalize_conf)
|
||||
else:
|
||||
normalize = None
|
||||
|
||||
# encoder
|
||||
encoder_class = encoder_choices.get_class(args.encoder)
|
||||
encoder = encoder_class(
|
||||
input_size=input_size,
|
||||
**args.encoder_conf,
|
||||
)
|
||||
|
||||
model_class = model_choices.get_class("data2vec")
|
||||
model = model_class(
|
||||
frontend=frontend,
|
||||
@ -97,9 +97,11 @@ def build_pretrain_model(args):
|
||||
normalize=normalize,
|
||||
encoder=encoder,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError("Not supported model: {}".format(args.model))
|
||||
|
||||
# 7. Initialize
|
||||
if args.init is not None:
|
||||
initialize(model, args.init)
|
||||
# initialize
|
||||
if args.init is not None:
|
||||
initialize(model, args.init)
|
||||
|
||||
return model
|
||||
return model
|
||||
|
||||
Loading…
Reference in New Issue
Block a user