noam_lr.py

  1. """Noam learning rate scheduler module."""
  2. from typing import Union
  3. import warnings
  4. import torch
  5. from torch.optim.lr_scheduler import _LRScheduler
  6. from funasr.schedulers.abs_scheduler import AbsBatchStepScheduler
  7. class NoamLR(_LRScheduler, AbsBatchStepScheduler):
  8. """The LR scheduler proposed by Noam
  9. Ref:
  10. "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf
  11. FIXME(kamo): PyTorch doesn't provide _LRScheduler as public class,
  12. thus the behaviour isn't guaranteed at forward PyTorch version.
  13. NOTE(kamo): The "model_size" in original implementation is derived from
  14. the model, but in this implementation, this parameter is a constant value.
  15. You need to change it if the model is changed.
  16. """
    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        model_size: Union[int, float] = 320,
        warmup_steps: Union[int, float] = 25000,
        last_epoch: int = -1,
    ):
        self.model_size = model_size
        self.warmup_steps = warmup_steps

        lr = list(optimizer.param_groups)[0]["lr"]
        new_lr = self.lr_for_WarmupLR(lr)
        warnings.warn(
            "NoamLR is deprecated. "
            f"Use WarmupLR(warmup_steps={warmup_steps}) with Optimizer(lr={new_lr})",
        )

        # model_size and warmup_steps must be set before super().__init__(),
        # because step() (and therefore get_lr()) is invoked inside __init__().
        super().__init__(optimizer, last_epoch)
    def lr_for_WarmupLR(self, lr: float) -> float:
        return lr / self.model_size**0.5 / self.warmup_steps**0.5

    def __repr__(self):
        return (
            f"{self.__class__.__name__}(model_size={self.model_size}, "
            f"warmup_steps={self.warmup_steps})"
        )
    def get_lr(self):
        step_num = self.last_epoch + 1
        return [
            lr
            * self.model_size**-0.5
            * min(step_num**-0.5, step_num * self.warmup_steps**-1.5)
            for lr in self.base_lrs
        ]
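

# --- Usage sketch (added for illustration; not part of the original module). ---
# Shows how a batch-step scheduler like NoamLR is typically driven; the toy model
# and hyper-parameter values below are placeholders, not values mandated by funasr.
if __name__ == "__main__":
    model = torch.nn.Linear(320, 320)  # stands in for a real network
    optimizer = torch.optim.Adam(model.parameters(), lr=1.0)
    scheduler = NoamLR(optimizer, model_size=320, warmup_steps=25000)

    for _ in range(5):
        optimizer.step()   # one optimization step per batch
        scheduler.step()   # NoamLR is an AbsBatchStepScheduler: step it every batch
        print(scheduler.get_last_lr())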