# Source: FunASR/funasr/utils/build_optimizer.py
# Commit: speech_asr 680cdb55bb "update" (2023-04-19 14:49:36 +08:00)
# Original size: 26 lines, 779 B (Python)

import torch
from funasr.optimizers.fairseq_adam import FairseqAdam
from funasr.optimizers.sgd import SGD
def build_optimizer(args, model):
    """Construct the optimizer selected by ``args.optim`` for ``model``.

    Args:
        args: Configuration object. ``args.optim`` is the optimizer name and
            must be one of the keys of ``optim_classes`` below.
            ``args.optim_conf``, when present and not ``None``, is a mapping
            of keyword arguments forwarded to the optimizer constructor.
        model: Object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).
            Everything it yields is passed to the optimizer unfiltered.

    Returns:
        The instantiated optimizer.

    Raises:
        ValueError: If ``args.optim`` is not a supported optimizer name.
    """
    # Supported optimizer names -> classes. Insertion order is what the
    # error message below lists.
    optim_classes = {
        "adam": torch.optim.Adam,
        "fairseq_adam": FairseqAdam,
        "adamw": torch.optim.AdamW,
        "sgd": SGD,
        "adadelta": torch.optim.Adadelta,
        "adagrad": torch.optim.Adagrad,
        "adamax": torch.optim.Adamax,
        "asgd": torch.optim.ASGD,
        "lbfgs": torch.optim.LBFGS,
        "rmsprop": torch.optim.RMSprop,
        "rprop": torch.optim.Rprop,
    }
    optim_class = optim_classes.get(args.optim)
    if optim_class is None:
        raise ValueError(f"must be one of {list(optim_classes)}: {args.optim}")

    # ``optim_conf`` may be absent or None when a config names an optimizer
    # without options; ``**None`` would raise TypeError, so fall back to the
    # optimizer's own defaults instead.
    optim_conf = getattr(args, "optim_conf", None)
    if optim_conf is None:
        optim_conf = {}
    return optim_class(model.parameters(), **optim_conf)