| 123456789101112131415161718192021222324252627282930313233343536373839404142 |
- #!/usr/bin/env python3
- # -*- coding: utf-8 -*-
- # Copyright 2019 Shigeki Karita
- # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
- """Layer normalization module."""
- import torch
class LayerNorm(torch.nn.LayerNorm):
    """Layer normalization module that can normalize a non-last dimension.

    This is a thin wrapper around ``torch.nn.LayerNorm``.  The parent class
    always normalizes over the trailing dimension; when ``dim`` is not ``-1``
    the input is transposed so that ``dim`` becomes the last axis, normalized,
    and then transposed back.

    Args:
        nout (int): Output dim size.  Passed to ``torch.nn.LayerNorm`` as the
            normalized shape.
        dim (int): Dimension to be normalized.
        eps (float): Value added to the denominator for numerical stability.
            Defaults to ``1e-12``, the value this module has always used.

    """

    def __init__(self, nout: int, dim: int = -1, eps: float = 1e-12):
        """Construct a LayerNorm object."""
        super().__init__(nout, eps=eps)
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply layer normalization.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Normalized tensor, with the same dimension order
                as ``x``.

        """
        if self.dim == -1:
            # Fast path: the parent already normalizes the last dimension.
            return super().forward(x)
        # Move the target dimension to the end, normalize, then swap it back.
        swapped = x.transpose(self.dim, -1)
        normalized = super().forward(swapped)
        return normalized.transpose(self.dim, -1)
|