# sgd.py
  1. import torch
  2. from typeguard import check_argument_types
  3. class SGD(torch.optim.SGD):
  4. """Thin inheritance of torch.optim.SGD to bind the required arguments, 'lr'
  5. Note that
  6. the arguments of the optimizer invoked by AbsTask.main()
  7. must have default value except for 'param'.
  8. I can't understand why only SGD.lr doesn't have the default value.
  9. """
  10. def __init__(
  11. self,
  12. params,
  13. lr: float = 0.1,
  14. momentum: float = 0.0,
  15. dampening: float = 0.0,
  16. weight_decay: float = 0.0,
  17. nesterov: bool = False,
  18. ):
  19. assert check_argument_types()
  20. super().__init__(
  21. params,
  22. lr=lr,
  23. momentum=momentum,
  24. dampening=dampening,
  25. weight_decay=weight_decay,
  26. nesterov=nesterov,
  27. )