build_optimizer.py 779 B

import torch

from funasr.optimizers.fairseq_adam import FairseqAdam
from funasr.optimizers.sgd import SGD


def build_optimizer(args, model):
    # Map optimizer names (as given in the training config) to their classes.
    optim_classes = dict(
        adam=torch.optim.Adam,
        fairseq_adam=FairseqAdam,
        adamw=torch.optim.AdamW,
        sgd=SGD,
        adadelta=torch.optim.Adadelta,
        adagrad=torch.optim.Adagrad,
        adamax=torch.optim.Adamax,
        asgd=torch.optim.ASGD,
        lbfgs=torch.optim.LBFGS,
        rmsprop=torch.optim.RMSprop,
        rprop=torch.optim.Rprop,
    )
    optim_class = optim_classes.get(args.optim)
    if optim_class is None:
        raise ValueError(f"must be one of {list(optim_classes)}: {args.optim}")
    # Instantiate the optimizer over the model's parameters, forwarding any
    # keyword arguments (lr, weight_decay, ...) from the config.
    optimizer = optim_class(model.parameters(), **args.optim_conf)
    return optimizer
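A minimal usage sketch. The `args` namespace and the `optim_conf` values below are hypothetical stand-ins; in FunASR these fields normally come from the parsed training config rather than being constructed by hand.

from types import SimpleNamespace

import torch.nn as nn

# Hypothetical config object exposing the two attributes build_optimizer reads.
args = SimpleNamespace(optim="adamw", optim_conf={"lr": 1e-3, "weight_decay": 0.01})

model = nn.Linear(10, 2)  # any torch.nn.Module works here
optimizer = build_optimizer(args, model)
print(type(optimizer).__name__)  # AdamW

An unknown `args.optim` (e.g. "lamb") raises the ValueError above, listing the supported names, rather than failing later with an opaque AttributeError.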