build_scheduler.py

import torch
import torch.multiprocessing
import torch.nn
import torch.optim

from funasr.schedulers.noam_lr import NoamLR
from funasr.schedulers.tri_stage_scheduler import TriStageLR
from funasr.schedulers.warmup_lr import WarmupLR


def build_scheduler(args, optimizers):
    """Build one LR scheduler (or None) per optimizer.

    For the i-th optimizer, the scheduler name and its keyword arguments are
    read from ``args.scheduler{i}`` and ``args.scheduler{i}_conf``, where the
    suffix is empty for the first optimizer ("scheduler", "scheduler2", ...).
    """
    scheduler_classes = dict(
        ReduceLROnPlateau=torch.optim.lr_scheduler.ReduceLROnPlateau,
        lambdalr=torch.optim.lr_scheduler.LambdaLR,
        steplr=torch.optim.lr_scheduler.StepLR,
        multisteplr=torch.optim.lr_scheduler.MultiStepLR,
        exponentiallr=torch.optim.lr_scheduler.ExponentialLR,
        CosineAnnealingLR=torch.optim.lr_scheduler.CosineAnnealingLR,
        noamlr=NoamLR,
        warmuplr=WarmupLR,
        tri_stage=TriStageLR,
        cycliclr=torch.optim.lr_scheduler.CyclicLR,
        onecyclelr=torch.optim.lr_scheduler.OneCycleLR,
        CosineAnnealingWarmRestarts=torch.optim.lr_scheduler.CosineAnnealingWarmRestarts,
    )
    schedulers = []
    for i, optim in enumerate(optimizers, 1):
        # The first optimizer uses the bare attribute names; later ones are
        # suffixed with their 1-based index (scheduler2, scheduler2_conf, ...).
        suf = "" if i == 1 else str(i)
        name = getattr(args, f"scheduler{suf}")
        conf = getattr(args, f"scheduler{suf}_conf")
        if name is not None:
            cls_ = scheduler_classes.get(name)
            if cls_ is None:
                raise ValueError(f"must be one of {list(scheduler_classes)}: {name}")
            scheduler = cls_(optim, **conf)
        else:
            scheduler = None
        schedulers.append(scheduler)
    return schedulers
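
# --- Usage sketch (illustrative only, not part of the original file) ---
# Assumes an args namespace exposing `scheduler` / `scheduler_conf` for a
# single optimizer, as FunASR's training entry point would provide; the
# model and scheduler settings below are hypothetical stand-ins.
if __name__ == "__main__":
    import argparse

    model = torch.nn.Linear(8, 2)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    args = argparse.Namespace(
        scheduler="steplr",
        scheduler_conf={"step_size": 10, "gamma": 0.5},
    )

    (scheduler,) = build_scheduler(args, [optimizer])
    for _ in range(3):
        optimizer.step()
        scheduler.step()  # advance the LR schedule after each optimizer step
    print(scheduler.get_last_lr())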