layer_norm.py 958 B

123456789101112131415161718192021222324252627282930313233343536373839404142
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. # Copyright 2019 Shigeki Karita
  4. # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
  5. """Layer normalization module."""
  6. import torch
  7. class LayerNorm(torch.nn.LayerNorm):
  8. """Layer normalization module.
  9. Args:
  10. nout (int): Output dim size.
  11. dim (int): Dimension to be normalized.
  12. """
  13. def __init__(self, nout, dim=-1):
  14. """Construct an LayerNorm object."""
  15. super(LayerNorm, self).__init__(nout, eps=1e-12)
  16. self.dim = dim
  17. def forward(self, x):
  18. """Apply layer normalization.
  19. Args:
  20. x (torch.Tensor): Input tensor.
  21. Returns:
  22. torch.Tensor: Normalized tensor.
  23. """
  24. if self.dim == -1:
  25. return super(LayerNorm, self).forward(x)
  26. return (
  27. super(LayerNorm, self)
  28. .forward(x.transpose(self.dim, -1))
  29. .transpose(self.dim, -1)
  30. )