This commit is contained in:
speech_asr 2023-04-19 14:49:36 +08:00
parent 58fb22cb2b
commit 680cdb55bb
4 changed files with 63 additions and 18 deletions

View File

@ -9,6 +9,8 @@ from funasr.utils import config_argparse
from funasr.utils.build_dataloader import build_dataloader from funasr.utils.build_dataloader import build_dataloader
from funasr.utils.build_distributed import build_distributed from funasr.utils.build_distributed import build_distributed
from funasr.utils.prepare_data import prepare_data from funasr.utils.prepare_data import prepare_data
from funasr.utils.build_optimizer import build_optimizer
from funasr.utils.build_scheduler import build_scheduler
from funasr.utils.types import str2bool from funasr.utils.types import str2bool
@ -355,20 +357,6 @@ if __name__ == '__main__':
distributed_option.dist_rank, distributed_option.dist_rank,
distributed_option.local_rank)) distributed_option.local_rank))
# optimizers = cls.build_optimizers(args, model=model) model = build_model(args)
# schedulers = [] optimizers = build_optimizer(args, model=model)
# for i, optim in enumerate(optimizers, 1): schedule = build_scheduler(args)
# suf = "" if i == 1 else str(i)
# name = getattr(args, f"scheduler{suf}")
# conf = getattr(args, f"scheduler{suf}_conf")
# if name is not None:
# cls_ = scheduler_classes.get(name)
# if cls_ is None:
# raise ValueError(
# f"must be one of {list(scheduler_classes)}: {name}"
# )
# scheduler = cls_(optim, **conf)
# else:
# scheduler = None
#
# schedulers.append(scheduler)

View File

@ -2,7 +2,7 @@ import logging
def build_model(args): def build_model(args):
if args.token_list is not None: if args.token_list is not None:
with open(args.token_list, encoding="utf-8") as f: with open(args.token_list) as f:
token_list = [line.rstrip() for line in f] token_list = [line.rstrip() for line in f]
args.token_list = list(token_list) args.token_list = list(token_list)
vocab_size = len(token_list) vocab_size = len(token_list)

View File

@ -0,0 +1,26 @@
import torch
from funasr.optimizers.fairseq_adam import FairseqAdam
from funasr.optimizers.sgd import SGD
def build_optimizer(args, model):
optim_classes = dict(
adam=torch.optim.Adam,
fairseq_adam=FairseqAdam,
adamw=torch.optim.AdamW,
sgd=SGD,
adadelta=torch.optim.Adadelta,
adagrad=torch.optim.Adagrad,
adamax=torch.optim.Adamax,
asgd=torch.optim.ASGD,
lbfgs=torch.optim.LBFGS,
rmsprop=torch.optim.RMSprop,
rprop=torch.optim.Rprop,
)
optim_class = optim_classes.get(args.optim)
if optim_class is None:
raise ValueError(f"must be one of {list(optim_classes)}: {args.optim}")
optimizer = optim_class(model.parameters(), **args.optim_conf)
return optimizer

View File

@ -0,0 +1,31 @@
import torch
import torch.multiprocessing
import torch.nn
import torch.optim
from funasr.schedulers.noam_lr import NoamLR
from funasr.schedulers.tri_stage_scheduler import TriStageLR
from funasr.schedulers.warmup_lr import WarmupLR
def build_scheduler(args, optimizer):
scheduler_classes = dict(
ReduceLROnPlateau=torch.optim.lr_scheduler.ReduceLROnPlateau,
lambdalr=torch.optim.lr_scheduler.LambdaLR,
steplr=torch.optim.lr_scheduler.StepLR,
multisteplr=torch.optim.lr_scheduler.MultiStepLR,
exponentiallr=torch.optim.lr_scheduler.ExponentialLR,
CosineAnnealingLR=torch.optim.lr_scheduler.CosineAnnealingLR,
noamlr=NoamLR,
warmuplr=WarmupLR,
tri_stage=TriStageLR,
cycliclr=torch.optim.lr_scheduler.CyclicLR,
onecyclelr=torch.optim.lr_scheduler.OneCycleLR,
CosineAnnealingWarmRestarts=torch.optim.lr_scheduler.CosineAnnealingWarmRestarts,
)
scheduler_class = scheduler_classes.get(args.scheduler)
if scheduler_class is None:
raise ValueError(f"must be one of {list(scheduler_classes)}: {args.scheduler}")
scheduler = scheduler_class(optimizer, **args.scheduler_conf)
return scheduler