#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Layer normalization module."""
import torch
import torch.nn as nn
  8. class LayerNorm(torch.nn.LayerNorm):
  9. """Layer normalization module.
  10. Args:
  11. nout (int): Output dim size.
  12. dim (int): Dimension to be normalized.
  13. """
  14. def __init__(self, nout, dim=-1):
  15. """Construct an LayerNorm object."""
  16. super(LayerNorm, self).__init__(nout, eps=1e-12)
  17. self.dim = dim
  18. def forward(self, x):
  19. """Apply layer normalization.
  20. Args:
  21. x (torch.Tensor): Input tensor.
  22. Returns:
  23. torch.Tensor: Normalized tensor.
  24. """
  25. if self.dim == -1:
  26. return super(LayerNorm, self).forward(x)
  27. return (
  28. super(LayerNorm, self)
  29. .forward(x.transpose(self.dim, -1))
  30. .transpose(self.dim, -1)
  31. )
  32. class GlobalLayerNorm(nn.Module):
  33. """Calculate Global Layer Normalization.
  34. Arguments
  35. ---------
  36. dim : (int or list or torch.Size)
  37. Input shape from an expected input of size.
  38. eps : float
  39. A value added to the denominator for numerical stability.
  40. elementwise_affine : bool
  41. A boolean value that when set to True,
  42. this module has learnable per-element affine parameters
  43. initialized to ones (for weights) and zeros (for biases).
  44. Example
  45. -------
  46. >>> x = torch.randn(5, 10, 20)
  47. >>> GLN = GlobalLayerNorm(10, 3)
  48. >>> x_norm = GLN(x)
  49. """
  50. def __init__(self, dim, shape, eps=1e-8, elementwise_affine=True):
  51. super(GlobalLayerNorm, self).__init__()
  52. self.dim = dim
  53. self.eps = eps
  54. self.elementwise_affine = elementwise_affine
  55. if self.elementwise_affine:
  56. if shape == 3:
  57. self.weight = nn.Parameter(torch.ones(self.dim, 1))
  58. self.bias = nn.Parameter(torch.zeros(self.dim, 1))
  59. if shape == 4:
  60. self.weight = nn.Parameter(torch.ones(self.dim, 1, 1))
  61. self.bias = nn.Parameter(torch.zeros(self.dim, 1, 1))
  62. else:
  63. self.register_parameter("weight", None)
  64. self.register_parameter("bias", None)
  65. def forward(self, x):
  66. """Returns the normalized tensor.
  67. Arguments
  68. ---------
  69. x : torch.Tensor
  70. Tensor of size [N, C, K, S] or [N, C, L].
  71. """
  72. # x = N x C x K x S or N x C x L
  73. # N x 1 x 1
  74. # cln: mean,var N x 1 x K x S
  75. # gln: mean,var N x 1 x 1
  76. if x.dim() == 3:
  77. mean = torch.mean(x, (1, 2), keepdim=True)
  78. var = torch.mean((x - mean) ** 2, (1, 2), keepdim=True)
  79. if self.elementwise_affine:
  80. x = (
  81. self.weight * (x - mean) / torch.sqrt(var + self.eps)
  82. + self.bias
  83. )
  84. else:
  85. x = (x - mean) / torch.sqrt(var + self.eps)
  86. if x.dim() == 4:
  87. mean = torch.mean(x, (1, 2, 3), keepdim=True)
  88. var = torch.mean((x - mean) ** 2, (1, 2, 3), keepdim=True)
  89. if self.elementwise_affine:
  90. x = (
  91. self.weight * (x - mean) / torch.sqrt(var + self.eps)
  92. + self.bias
  93. )
  94. else:
  95. x = (x - mean) / torch.sqrt(var + self.eps)
  96. return x
  97. class CumulativeLayerNorm(nn.LayerNorm):
  98. """Calculate Cumulative Layer Normalization.
  99. Arguments
  100. ---------
  101. dim : int
  102. Dimension that you want to normalize.
  103. elementwise_affine : True
  104. Learnable per-element affine parameters.
  105. Example
  106. -------
  107. >>> x = torch.randn(5, 10, 20)
  108. >>> CLN = CumulativeLayerNorm(10)
  109. >>> x_norm = CLN(x)
  110. """
  111. def __init__(self, dim, elementwise_affine=True):
  112. super(CumulativeLayerNorm, self).__init__(
  113. dim, elementwise_affine=elementwise_affine, eps=1e-8
  114. )
  115. def forward(self, x):
  116. """Returns the normalized tensor.
  117. Arguments
  118. ---------
  119. x : torch.Tensor
  120. Tensor size [N, C, K, S] or [N, C, L]
  121. """
  122. # x: N x C x K x S or N x C x L
  123. # N x K x S x C
  124. if x.dim() == 4:
  125. x = x.permute(0, 2, 3, 1).contiguous()
  126. # N x K x S x C == only channel norm
  127. x = super().forward(x)
  128. # N x C x K x S
  129. x = x.permute(0, 3, 1, 2).contiguous()
  130. if x.dim() == 3:
  131. x = torch.transpose(x, 1, 2)
  132. # N x L x C == only channel norm
  133. x = super().forward(x)
  134. # N x C x L
  135. x = torch.transpose(x, 1, 2)
  136. return x
  137. class ScaleNorm(nn.Module):
  138. def __init__(self, dim, eps = 1e-5):
  139. super().__init__()
  140. self.scale = dim ** -0.5
  141. self.eps = eps
  142. self.g = nn.Parameter(torch.ones(1))
  143. def forward(self, x):
  144. norm = torch.norm(x, dim = -1, keepdim = True) * self.scale
  145. return x / norm.clamp(min = self.eps) * self.g