# add_gradient_noise.py (987 B)
import torch
  2. def add_gradient_noise(
  3. model: torch.nn.Module,
  4. iteration: int,
  5. duration: float = 100,
  6. eta: float = 1.0,
  7. scale_factor: float = 0.55,
  8. ):
  9. """Adds noise from a standard normal distribution to the gradients.
  10. The standard deviation (`sigma`) is controlled
  11. by the three hyper-parameters below.
  12. `sigma` goes to zero (no noise) with more iterations.
  13. Args:
  14. model: Model.
  15. iteration: Number of iterations.
  16. duration: {100, 1000}: Number of durations to control
  17. the interval of the `sigma` change.
  18. eta: {0.01, 0.3, 1.0}: The magnitude of `sigma`.
  19. scale_factor: {0.55}: The scale of `sigma`.
  20. """
  21. interval = (iteration // duration) + 1
  22. sigma = eta / interval**scale_factor
  23. for param in model.parameters():
  24. if param.grad is not None:
  25. _shape = param.grad.size()
  26. noise = sigma * torch.randn(_shape).to(param.device)
  27. param.grad += noise