#!/usr/bin/env python3
"""Initialize modules for espnet2 neural networks."""
import math

import torch
  5. def initialize(model: torch.nn.Module, init: str):
  6. """Initialize weights of a neural network module.
  7. Parameters are initialized using the given method or distribution.
  8. Custom initialization routines can be implemented into submodules
  9. as function `espnet_initialization_fn` within the custom module.
  10. Args:
  11. model: Target.
  12. init: Method of initialization.
  13. """
  14. # weight init
  15. for p in model.parameters():
  16. if p.dim() > 1:
  17. if init == "xavier_uniform":
  18. torch.nn.init.xavier_uniform_(p.data)
  19. elif init == "xavier_normal":
  20. torch.nn.init.xavier_normal_(p.data)
  21. elif init == "kaiming_uniform":
  22. torch.nn.init.kaiming_uniform_(p.data, nonlinearity="relu")
  23. elif init == "kaiming_normal":
  24. torch.nn.init.kaiming_normal_(p.data, nonlinearity="relu")
  25. else:
  26. raise ValueError("Unknown initialization: " + init)
  27. # bias init
  28. for p in model.parameters():
  29. if p.dim() == 1:
  30. p.data.zero_()
  31. # reset some modules with default init
  32. for m in model.modules():
  33. if isinstance(
  34. m, (torch.nn.Embedding, torch.nn.LayerNorm, torch.nn.GroupNorm)
  35. ):
  36. m.reset_parameters()
  37. if hasattr(m, "espnet_initialization_fn"):
  38. m.espnet_initialization_fn()
  39. # TODO(xkc): Hacking s3prl_frontend and wav2vec2encoder initialization
  40. if getattr(model, "encoder", None) and getattr(
  41. model.encoder, "reload_pretrained_parameters", None
  42. ):
  43. model.encoder.reload_pretrained_parameters()
  44. if getattr(model, "frontend", None) and getattr(
  45. model.frontend, "reload_pretrained_parameters", None
  46. ):
  47. model.frontend.reload_pretrained_parameters()