sgd.py

import torch


class SGD(torch.optim.SGD):
    """Thin wrapper around torch.optim.SGD that binds a default value to 'lr'.

    The arguments of any optimizer invoked by AbsTask.main() must have
    default values, except for 'params'. Unusually, torch.optim.SGD's 'lr'
    has no default, so this subclass supplies one.
    """

    def __init__(
        self,
        params,
        lr: float = 0.1,
        momentum: float = 0.0,
        dampening: float = 0.0,
        weight_decay: float = 0.0,
        nesterov: bool = False,
    ):
        # Forward everything to torch.optim.SGD; only the default for 'lr' is new.
        super().__init__(
            params,
            lr=lr,
            momentum=momentum,
            dampening=dampening,
            weight_decay=weight_decay,
            nesterov=nesterov,
        )
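
For context, a minimal usage sketch showing that this subclass can be constructed without passing 'lr', falling back to the default of 0.1 bound above. The model and the single training step are hypothetical placeholders, not part of the original file:

# Minimal usage sketch; 'model' is a hypothetical placeholder module.
model = torch.nn.Linear(10, 2)

# Unlike bare torch.optim.SGD, 'lr' may be omitted here; it defaults to 0.1.
optimizer = SGD(model.parameters(), momentum=0.9)

# One illustrative update step.
loss = model(torch.randn(4, 10)).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()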