abs_scheduler.py

from abc import ABC
from abc import abstractmethod

import torch.optim.lr_scheduler as L


class AbsScheduler(ABC):
    @abstractmethod
    def step(self, epoch: int = None):
        pass

    @abstractmethod
    def state_dict(self):
        pass

    @abstractmethod
    def load_state_dict(self, state):
        pass


# If you need to define a custom scheduler, please inherit these classes
class AbsBatchStepScheduler(AbsScheduler):
    @abstractmethod
    def step(self, epoch: int = None):
        pass

    @abstractmethod
    def state_dict(self):
        pass

    @abstractmethod
    def load_state_dict(self, state):
        pass


class AbsEpochStepScheduler(AbsScheduler):
    @abstractmethod
    def step(self, epoch: int = None):
        pass

    @abstractmethod
    def state_dict(self):
        pass

    @abstractmethod
    def load_state_dict(self, state):
        pass


class AbsValEpochStepScheduler(AbsEpochStepScheduler):
    @abstractmethod
    def step(self, val, epoch: int = None):
        pass

    @abstractmethod
    def state_dict(self):
        pass

    @abstractmethod
    def load_state_dict(self, state):
        pass
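As the comment above says, a custom scheduler is defined by inheriting one of these classes and implementing the three abstract methods. Below is a minimal, hypothetical sketch of a per-batch linear-warmup scheduler; the class name, warmup logic, and default step count are illustrative assumptions, not part of this file.

class LinearWarmupLR(AbsBatchStepScheduler):
    # Hypothetical example: ramps each param group's LR linearly from 0
    # up to its base value over the first `warmup_steps` batches.
    def __init__(self, optimizer, warmup_steps: int = 25000):
        self.optimizer = optimizer
        self.warmup_steps = warmup_steps
        self.step_num = 0
        self.base_lrs = [g["lr"] for g in optimizer.param_groups]

    def step(self, epoch: int = None):
        # Called once per batch by the training loop.
        self.step_num += 1
        scale = min(1.0, self.step_num / self.warmup_steps)
        for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups):
            group["lr"] = base_lr * scale

    def state_dict(self):
        return {"step_num": self.step_num}

    def load_state_dict(self, state):
        self.step_num = state["step_num"]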
# Create alias types to check the type
# Note(kamo): Currently PyTorch doesn't provide a base class
# to judge these classes.
AbsValEpochStepScheduler.register(L.ReduceLROnPlateau)

for s in [
    L.ReduceLROnPlateau,
    L.LambdaLR,
    L.StepLR,
    L.MultiStepLR,
    L.MultiplicativeLR,
    L.ExponentialLR,
    L.CosineAnnealingLR,
]:
    AbsEpochStepScheduler.register(s)

AbsBatchStepScheduler.register(L.CyclicLR)
for s in [
    L.OneCycleLR,
    L.CosineAnnealingWarmRestarts,
]:
    AbsBatchStepScheduler.register(s)
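The register() calls above make the stock PyTorch schedulers virtual subclasses of these ABCs, so training code can dispatch on isinstance without PyTorch itself providing such base classes. A small usage sketch, with a placeholder model, optimizer, and validation metric:

import torch

model = torch.nn.Linear(10, 10)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = L.ReduceLROnPlateau(optimizer)

# ReduceLROnPlateau was registered as a validation-driven epoch scheduler.
assert isinstance(scheduler, AbsValEpochStepScheduler)
assert not isinstance(scheduler, AbsBatchStepScheduler)

val_loss = 1.0  # placeholder validation metric
if isinstance(scheduler, AbsValEpochStepScheduler):
    scheduler.step(val_loss)  # needs a metric, e.g. ReduceLROnPlateau
elif isinstance(scheduler, AbsEpochStepScheduler):
    scheduler.step()  # steps once per epoch, e.g. StepLR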