# Mirror of https://github.com/modelscope/FunASR
# (synced 2025-09-15 14:48:36 +08:00; 44 lines, 1.5 KiB, Python)
import torch
|
|
import torch.multiprocessing
|
|
import torch.nn
|
|
import torch.optim
|
|
|
|
from funasr.schedulers.noam_lr import NoamLR
|
|
from funasr.schedulers.tri_stage_scheduler import TriStageLR
|
|
from funasr.schedulers.warmup_lr import WarmupLR
|
|
|
|
|
|
def build_scheduler(args, optimizers):
    """Build one learning-rate scheduler per optimizer from ``args``.

    For the i-th optimizer (1-based), the scheduler name and its keyword
    arguments are read from ``args.scheduler{suf}`` and
    ``args.scheduler{suf}_conf``. Here ``suf`` is ``""`` for the first
    optimizer and ``str(i)`` for later ones (``scheduler2``,
    ``scheduler2_conf``, ...).

    Name lookup first tries an exact match against the registry below. It then
    falls back to a case-insensitive match, so e.g. ``"cosineannealinglr"``
    resolves to ``CosineAnnealingLR``.

    Args:
        args: Namespace-like object exposing the ``scheduler*`` and
            ``scheduler*_conf`` attributes described above.
        optimizers: Iterable of optimizers; each one is passed as the first
            positional argument to its scheduler class (presumably
            ``torch.optim.Optimizer`` instances).

    Returns:
        A list with one entry per optimizer. Each entry is the constructed
        scheduler, or ``None`` when that optimizer's scheduler name is ``None``.

    Raises:
        ValueError: If a scheduler name matches no registered key.
        AttributeError: If ``args`` lacks a required ``scheduler*`` attribute.
    """
    scheduler_classes = dict(
        ReduceLROnPlateau=torch.optim.lr_scheduler.ReduceLROnPlateau,
        lambdalr=torch.optim.lr_scheduler.LambdaLR,
        steplr=torch.optim.lr_scheduler.StepLR,
        multisteplr=torch.optim.lr_scheduler.MultiStepLR,
        exponentiallr=torch.optim.lr_scheduler.ExponentialLR,
        CosineAnnealingLR=torch.optim.lr_scheduler.CosineAnnealingLR,
        noamlr=NoamLR,
        warmuplr=WarmupLR,
        tri_stage=TriStageLR,
        cycliclr=torch.optim.lr_scheduler.CyclicLR,
        onecyclelr=torch.optim.lr_scheduler.OneCycleLR,
        CosineAnnealingWarmRestarts=torch.optim.lr_scheduler.CosineAnnealingWarmRestarts,
    )
    # The registry mixes CamelCase and lowercase keys. This lowercase index
    # lets configs spell any name in any case; the lowercased keys are unique.
    classes_by_lower_name = {key.lower(): cls for key, cls in scheduler_classes.items()}

    schedulers = []
    for i, optim in enumerate(optimizers, 1):
        suf = "" if i == 1 else str(i)
        name = getattr(args, f"scheduler{suf}")
        conf = getattr(args, f"scheduler{suf}_conf")

        if name is None:
            # No scheduler requested for this optimizer.
            schedulers.append(None)
            continue

        cls_ = scheduler_classes.get(name)
        if cls_ is None and isinstance(name, str):
            cls_ = classes_by_lower_name.get(name.lower())
        if cls_ is None:
            raise ValueError(f"must be one of {list(scheduler_classes)}: {name}")

        # A missing (None) conf means "use the scheduler's defaults" instead
        # of failing with a TypeError on `**None`.
        scheduler = cls_(optim, **(conf or {}))
        schedulers.append(scheduler)

    return schedulers