From 680cdb55bbde415c2f750e58808faedc6d1a6bf3 Mon Sep 17 00:00:00 2001
From: speech_asr
Date: Wed, 19 Apr 2023 14:49:36 +0800
Subject: [PATCH] update

---
 funasr/bin/train.py             | 22 +++++-----------------
 funasr/utils/build_model.py     |  2 +-
 funasr/utils/build_optimizer.py | 26 ++++++++++++++++++++++++++
 funasr/utils/build_scheduler.py | 31 +++++++++++++++++++++++++++++++
 4 files changed, 63 insertions(+), 18 deletions(-)
 create mode 100644 funasr/utils/build_optimizer.py
 create mode 100644 funasr/utils/build_scheduler.py

diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index 2a5dc9815..c0e41575d 100644
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -9,6 +9,8 @@ from funasr.utils import config_argparse
 from funasr.utils.build_dataloader import build_dataloader
 from funasr.utils.build_distributed import build_distributed
 from funasr.utils.prepare_data import prepare_data
+from funasr.utils.build_optimizer import build_optimizer
+from funasr.utils.build_scheduler import build_scheduler
 from funasr.utils.types import str2bool
 
 
@@ -355,20 +357,6 @@ if __name__ == '__main__':
                          distributed_option.dist_rank,
                          distributed_option.local_rank))
 
-    # optimizers = cls.build_optimizers(args, model=model)
-    # schedulers = []
-    # for i, optim in enumerate(optimizers, 1):
-    #     suf = "" if i == 1 else str(i)
-    #     name = getattr(args, f"scheduler{suf}")
-    #     conf = getattr(args, f"scheduler{suf}_conf")
-    #     if name is not None:
-    #         cls_ = scheduler_classes.get(name)
-    #         if cls_ is None:
-    #             raise ValueError(
-    #                 f"must be one of {list(scheduler_classes)}: {name}"
-    #             )
-    #         scheduler = cls_(optim, **conf)
-    #     else:
-    #         scheduler = None
-    #
-    #     schedulers.append(scheduler)
+    model = build_model(args)
+    optimizer = build_optimizer(args, model=model)
+    scheduler = build_scheduler(args, optimizer)
diff --git a/funasr/utils/build_model.py b/funasr/utils/build_model.py
index 9fa980a09..abe384cd8 100644
--- a/funasr/utils/build_model.py
+++ b/funasr/utils/build_model.py
@@ -2,7 +2,7 @@ import logging
 
 def build_model(args):
     if args.token_list is not None:
-        with open(args.token_list) as f:
+        with open(args.token_list, encoding="utf-8") as f:
             token_list = [line.rstrip() for line in f]
         args.token_list = list(token_list)
         vocab_size = len(token_list)
diff --git a/funasr/utils/build_optimizer.py b/funasr/utils/build_optimizer.py
new file mode 100644
index 000000000..3b2799492
--- /dev/null
+++ b/funasr/utils/build_optimizer.py
@@ -0,0 +1,26 @@
+import torch
+
+from funasr.optimizers.fairseq_adam import FairseqAdam
+from funasr.optimizers.sgd import SGD
+
+
+def build_optimizer(args, model):
+    optim_classes = dict(
+        adam=torch.optim.Adam,
+        fairseq_adam=FairseqAdam,
+        adamw=torch.optim.AdamW,
+        sgd=SGD,
+        adadelta=torch.optim.Adadelta,
+        adagrad=torch.optim.Adagrad,
+        adamax=torch.optim.Adamax,
+        asgd=torch.optim.ASGD,
+        lbfgs=torch.optim.LBFGS,
+        rmsprop=torch.optim.RMSprop,
+        rprop=torch.optim.Rprop,
+    )
+
+    optim_class = optim_classes.get(args.optim)
+    if optim_class is None:
+        raise ValueError(f"must be one of {list(optim_classes)}: {args.optim}")
+    optimizer = optim_class(model.parameters(), **args.optim_conf)
+    return optimizer
\ No newline at end of file
diff --git a/funasr/utils/build_scheduler.py b/funasr/utils/build_scheduler.py
new file mode 100644
index 000000000..f0e6d1f54
--- /dev/null
+++ b/funasr/utils/build_scheduler.py
@@ -0,0 +1,31 @@
+import torch
+import torch.multiprocessing
+import torch.nn
+import torch.optim
+
+from funasr.schedulers.noam_lr import NoamLR
+from funasr.schedulers.tri_stage_scheduler import TriStageLR
+from funasr.schedulers.warmup_lr import WarmupLR
+
+
+def build_scheduler(args, optimizer):
+    scheduler_classes = dict(
+        ReduceLROnPlateau=torch.optim.lr_scheduler.ReduceLROnPlateau,
+        lambdalr=torch.optim.lr_scheduler.LambdaLR,
+        steplr=torch.optim.lr_scheduler.StepLR,
+        multisteplr=torch.optim.lr_scheduler.MultiStepLR,
+        exponentiallr=torch.optim.lr_scheduler.ExponentialLR,
+        CosineAnnealingLR=torch.optim.lr_scheduler.CosineAnnealingLR,
+        noamlr=NoamLR,
+        warmuplr=WarmupLR,
+        tri_stage=TriStageLR,
+        cycliclr=torch.optim.lr_scheduler.CyclicLR,
+        onecyclelr=torch.optim.lr_scheduler.OneCycleLR,
+        CosineAnnealingWarmRestarts=torch.optim.lr_scheduler.CosineAnnealingWarmRestarts,
+    )
+
+    scheduler_class = scheduler_classes.get(args.scheduler)
+    if scheduler_class is None:
+        raise ValueError(f"must be one of {list(scheduler_classes)}: {args.scheduler}")
+    scheduler = scheduler_class(optimizer, **args.scheduler_conf)
+    return scheduler
\ No newline at end of file
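
Reviewer note (appended below the diff, not applied by the patch): a minimal sketch of how the two new helpers compose, assuming an argparse-style `args` namespace. The `optim`, `optim_conf`, `scheduler`, and `scheduler_conf` values are illustrative, not repository defaults, and `torch.nn.Linear` stands in for the real ASR model.

    import argparse

    import torch

    from funasr.utils.build_optimizer import build_optimizer
    from funasr.utils.build_scheduler import build_scheduler

    # Hypothetical config: any key of optim_classes / scheduler_classes in the
    # new modules is accepted; these particular values are made up for the demo.
    args = argparse.Namespace(
        optim="adamw",
        optim_conf={"lr": 1e-3, "weight_decay": 0.01},
        scheduler="steplr",
        scheduler_conf={"step_size": 10000, "gamma": 0.9},
    )

    model = torch.nn.Linear(80, 4)                  # stand-in for the ASR model
    optimizer = build_optimizer(args, model=model)  # -> torch.optim.AdamW
    scheduler = build_scheduler(args, optimizer)    # -> torch.optim.lr_scheduler.StepLR

    # PyTorch expects optimizer.step() before scheduler.step() each iteration.
    optimizer.step()
    scheduler.step()

Unknown `optim` or `scheduler` names fail fast with the ValueError listing the accepted keys, so a config typo surfaces at startup rather than mid-training.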