FunASR/funasr/build_utils/build_scheduler.py
2023-04-23 17:30:38 +08:00

44 lines
1.5 KiB
Python

import torch
import torch.multiprocessing
import torch.nn
import torch.optim
from funasr.schedulers.noam_lr import NoamLR
from funasr.schedulers.tri_stage_scheduler import TriStageLR
from funasr.schedulers.warmup_lr import WarmupLR
def build_scheduler(args, optimizers):
    """Create one learning-rate scheduler per optimizer from ``args``.

    For the first optimizer the scheduler name and keyword arguments are
    read from ``args.scheduler`` and ``args.scheduler_conf``; for the
    i-th optimizer (i >= 2) they are read from ``args.scheduler{i}`` and
    ``args.scheduler{i}_conf``.

    Args:
        args: Object exposing the ``scheduler*`` / ``scheduler*_conf``
            attributes described above (accessed with ``getattr``).
        optimizers: Iterable of optimizers, in the order matching the
            numbered attributes on ``args``.

    Returns:
        A list with one entry per optimizer: the constructed scheduler,
        or ``None`` when the corresponding scheduler name is ``None``.

    Raises:
        ValueError: If a scheduler name is not a key of the lookup table.
        AttributeError: If ``args`` lacks one of the expected attributes.
    """
    # Lookup is exact-match: note that the keys mix lower-case and
    # CamelCase spellings.
    scheduler_classes = dict(
        ReduceLROnPlateau=torch.optim.lr_scheduler.ReduceLROnPlateau,
        lambdalr=torch.optim.lr_scheduler.LambdaLR,
        steplr=torch.optim.lr_scheduler.StepLR,
        multisteplr=torch.optim.lr_scheduler.MultiStepLR,
        exponentiallr=torch.optim.lr_scheduler.ExponentialLR,
        CosineAnnealingLR=torch.optim.lr_scheduler.CosineAnnealingLR,
        noamlr=NoamLR,
        warmuplr=WarmupLR,
        tri_stage=TriStageLR,
        cycliclr=torch.optim.lr_scheduler.CyclicLR,
        onecyclelr=torch.optim.lr_scheduler.OneCycleLR,
        CosineAnnealingWarmRestarts=torch.optim.lr_scheduler.CosineAnnealingWarmRestarts,
    )

    schedulers = []
    for i, optimizer in enumerate(optimizers, 1):
        # The first optimizer uses the un-suffixed attribute names.
        suf = "" if i == 1 else str(i)
        name = getattr(args, f"scheduler{suf}")
        conf = getattr(args, f"scheduler{suf}_conf")

        if name is None:
            # No scheduler requested for this optimizer.
            schedulers.append(None)
            continue

        cls_ = scheduler_classes.get(name)
        if cls_ is None:
            raise ValueError(
                f"must be one of {list(scheduler_classes)}: {name}"
            )

        # A conf left unset (None) means "no extra keyword arguments";
        # unpacking None directly would raise TypeError.
        schedulers.append(cls_(optimizer, **(conf or {})))

    return schedulers