  1. #!/usr/bin/env python3
  2. """Initialize modules for espnet2 neural networks."""
  3. import math
  4. import torch
  5. from typeguard import check_argument_types
  6. def initialize(model: torch.nn.Module, init: str):
  7. """Initialize weights of a neural network module.
  8. Parameters are initialized using the given method or distribution.
  9. Custom initialization routines can be implemented into submodules
  10. as function `espnet_initialization_fn` within the custom module.
  11. Args:
  12. model: Target.
  13. init: Method of initialization.
  14. """
  15. assert check_argument_types()
  16. if init == "chainer":
  17. # 1. lecun_normal_init_parameters
  18. for p in model.parameters():
  19. data = p.data
  20. if data.dim() == 1:
  21. # bias
  22. data.zero_()
  23. elif data.dim() == 2:
  24. # linear weight
  25. n = data.size(1)
  26. stdv = 1.0 / math.sqrt(n)
  27. data.normal_(0, stdv)
  28. elif data.dim() in (3, 4):
  29. # conv weight
  30. n = data.size(1)
  31. for k in data.size()[2:]:
  32. n *= k
  33. stdv = 1.0 / math.sqrt(n)
  34. data.normal_(0, stdv)
  35. else:
  36. raise NotImplementedError
  37. for mod in model.modules():
  38. # 2. embed weight ~ Normal(0, 1)
  39. if isinstance(mod, torch.nn.Embedding):
  40. mod.weight.data.normal_(0, 1)
  41. # 3. forget-bias = 1.0
  42. elif isinstance(mod, torch.nn.RNNCellBase):
  43. n = mod.bias_ih.size(0)
  44. mod.bias_ih.data[n // 4 : n // 2].fill_(1.0)
  45. elif isinstance(mod, torch.nn.RNNBase):
  46. for name, param in mod.named_parameters():
  47. if "bias" in name:
  48. n = param.size(0)
  49. param.data[n // 4 : n // 2].fill_(1.0)
  50. if hasattr(mod, "espnet_initialization_fn"):
  51. mod.espnet_initialization_fn()
  52. else:
  53. # weight init
  54. for p in model.parameters():
  55. if p.dim() > 1:
  56. if init == "xavier_uniform":
  57. torch.nn.init.xavier_uniform_(p.data)
  58. elif init == "xavier_normal":
  59. torch.nn.init.xavier_normal_(p.data)
  60. elif init == "kaiming_uniform":
  61. torch.nn.init.kaiming_uniform_(p.data, nonlinearity="relu")
  62. elif init == "kaiming_normal":
  63. torch.nn.init.kaiming_normal_(p.data, nonlinearity="relu")
  64. else:
  65. raise ValueError("Unknown initialization: " + init)
  66. # bias init
  67. for p in model.parameters():
  68. if p.dim() == 1:
  69. p.data.zero_()
  70. # reset some modules with default init
  71. for m in model.modules():
  72. if isinstance(
  73. m, (torch.nn.Embedding, torch.nn.LayerNorm, torch.nn.GroupNorm)
  74. ):
  75. m.reset_parameters()
  76. if hasattr(m, "espnet_initialization_fn"):
  77. m.espnet_initialization_fn()
  78. # TODO(xkc): Hacking s3prl_frontend and wav2vec2encoder initialization
  79. if getattr(model, "encoder", None) and getattr(
  80. model.encoder, "reload_pretrained_parameters", None
  81. ):
  82. model.encoder.reload_pretrained_parameters()
  83. if getattr(model, "frontend", None) and getattr(
  84. model.frontend, "reload_pretrained_parameters", None
  85. ):
  86. model.frontend.reload_pretrained_parameters()
  87. if getattr(model, "postencoder", None) and getattr(
  88. model.postencoder, "reload_pretrained_parameters", None
  89. ):
  90. model.postencoder.reload_pretrained_parameters()