  1. """Noam learning rate scheduler module."""
  2. from typing import Union
  3. import warnings
  4. import torch
  5. from torch.optim.lr_scheduler import _LRScheduler
  6. from typeguard import check_argument_types
  7. from funasr.schedulers.abs_scheduler import AbsBatchStepScheduler
  8. class NoamLR(_LRScheduler, AbsBatchStepScheduler):
  9. """The LR scheduler proposed by Noam
  10. Ref:
  11. "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf
  12. FIXME(kamo): PyTorch doesn't provide _LRScheduler as public class,
  13. thus the behaviour isn't guaranteed at forward PyTorch version.
  14. NOTE(kamo): The "model_size" in original implementation is derived from
  15. the model, but in this implementation, this parameter is a constant value.
  16. You need to change it if the model is changed.
  17. """
  18. def __init__(
  19. self,
  20. optimizer: torch.optim.Optimizer,
  21. model_size: Union[int, float] = 320,
  22. warmup_steps: Union[int, float] = 25000,
  23. last_epoch: int = -1,
  24. ):
  25. assert check_argument_types()
  26. self.model_size = model_size
  27. self.warmup_steps = warmup_steps
  28. lr = list(optimizer.param_groups)[0]["lr"]
  29. new_lr = self.lr_for_WarmupLR(lr)
  30. warnings.warn(
  31. f"NoamLR is deprecated. "
  32. f"Use WarmupLR(warmup_steps={warmup_steps}) with Optimizer(lr={new_lr})",
  33. )
  34. # __init__() must be invoked before setting field
  35. # because step() is also invoked in __init__()
  36. super().__init__(optimizer, last_epoch)
  37. def lr_for_WarmupLR(self, lr: float) -> float:
  38. return lr / self.model_size**0.5 / self.warmup_steps**0.5
  39. def __repr__(self):
  40. return (
  41. f"{self.__class__.__name__}(model_size={self.model_size}, "
  42. f"warmup_steps={self.warmup_steps})"
  43. )
  44. def get_lr(self):
  45. step_num = self.last_epoch + 1
  46. return [
  47. lr
  48. * self.model_size**-0.5
  49. * min(step_num**-0.5, step_num * self.warmup_steps**-1.5)
  50. for lr in self.base_lrs
  51. ]