| 1234567891011121314151617181920212223242526272829303132333435363738394041424344 |
- import torch
- import torch.multiprocessing
- import torch.nn
- import torch.optim
- from funasr.schedulers.noam_lr import NoamLR
- from funasr.schedulers.tri_stage_scheduler import TriStageLR
- from funasr.schedulers.warmup_lr import WarmupLR
def build_scheduler(args, optimizers):
    """Build one learning-rate scheduler per optimizer.

    For the i-th optimizer (1-based), the scheduler name and its keyword
    arguments are read from ``args.scheduler{suf}`` and
    ``args.scheduler{suf}_conf``. Here ``suf`` is ``""`` for the first
    optimizer and ``str(i)`` otherwise, giving ``scheduler``, ``scheduler2``,
    ``scheduler3``, and so on.

    Args:
        args: Namespace-like object exposing the ``scheduler{suf}`` and
            ``scheduler{suf}_conf`` attributes for every optimizer.
        optimizers: Iterable of optimizers. Each one is passed as the first
            positional argument to its scheduler class.

    Returns:
        A list with one entry per optimizer, in order. Each entry is the
        constructed scheduler, or ``None`` when ``args.scheduler{suf}`` is
        None.

    Raises:
        ValueError: If a scheduler name is not registered. Names are matched
            exactly first, then case-insensitively.
    """
    scheduler_classes = dict(
        ReduceLROnPlateau=torch.optim.lr_scheduler.ReduceLROnPlateau,
        lambdalr=torch.optim.lr_scheduler.LambdaLR,
        steplr=torch.optim.lr_scheduler.StepLR,
        multisteplr=torch.optim.lr_scheduler.MultiStepLR,
        exponentiallr=torch.optim.lr_scheduler.ExponentialLR,
        CosineAnnealingLR=torch.optim.lr_scheduler.CosineAnnealingLR,
        noamlr=NoamLR,
        warmuplr=WarmupLR,
        tri_stage=TriStageLR,
        cycliclr=torch.optim.lr_scheduler.CyclicLR,
        onecyclelr=torch.optim.lr_scheduler.OneCycleLR,
        CosineAnnealingWarmRestarts=torch.optim.lr_scheduler.CosineAnnealingWarmRestarts,
    )
    # The registry mixes CamelCase keys (torch class names) with lowercase
    # aliases. A lowercase index lets names such as "reducelronplateau" or
    # "NoamLR" resolve as well. All lowercased keys are distinct.
    lowered_classes = {key.lower(): cls for key, cls in scheduler_classes.items()}

    schedulers = []
    for i, optim in enumerate(optimizers, 1):
        suf = "" if i == 1 else str(i)
        name = getattr(args, f"scheduler{suf}")
        conf = getattr(args, f"scheduler{suf}_conf")
        if name is None:
            # No scheduler configured for this optimizer; keep the list
            # aligned with `optimizers`.
            schedulers.append(None)
            continue

        cls_ = scheduler_classes.get(name)
        if cls_ is None and isinstance(name, str):
            cls_ = lowered_classes.get(name.lower())
        if cls_ is None:
            raise ValueError(
                f"must be one of {list(scheduler_classes)}: {name}"
            )

        # The conf entry may be None (e.g. an empty config section); treat it
        # as "no extra keyword arguments" instead of failing on `**None`.
        scheduler = cls_(optim, **(conf or {}))
        schedulers.append(scheduler)
    return schedulers
|