#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Layer normalization module."""

import torch
import torch.nn as nn


class LayerNorm(torch.nn.LayerNorm):
    """Layer normalization module.

    Args:
        nout (int): Output dimension size.
        dim (int): Dimension to be normalized.
    """

    def __init__(self, nout, dim=-1):
        """Construct a LayerNorm object."""
        super().__init__(nout, eps=1e-12)
        self.dim = dim

    def forward(self, x):
        """Apply layer normalization.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Normalized tensor.
        """
        if self.dim == -1:
            return super().forward(x)
        # Move the target dimension to the last position, normalize,
        # then move it back.
        return (
            super()
            .forward(x.transpose(self.dim, -1))
            .transpose(self.dim, -1)
        )
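

# A minimal usage sketch (shapes here are illustrative assumptions): with
# dim=1, LayerNorm normalizes the channel axis of a [batch, channel, time]
# tensor by transposing it to the last position and back.
def _layer_norm_example():
    x = torch.randn(4, 8, 100)  # [N, C, L]
    ln = LayerNorm(8, dim=1)    # normalize over the C axis (size 8)
    y = ln(x)
    assert y.shape == x.shape   # the layout is unchanged
    return y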


class GlobalLayerNorm(nn.Module):
    """Calculate Global Layer Normalization (gLN).

    Arguments
    ---------
    dim : int or list or torch.Size
        Size of the normalized (channel) dimension.
    shape : int
        Number of dimensions of the expected input:
        3 for [N, C, L] or 4 for [N, C, K, S].
    eps : float
        A value added to the denominator for numerical stability.
    elementwise_affine : bool
        A boolean value that when set to True,
        this module has learnable per-element affine parameters
        initialized to ones (for weights) and zeros (for biases).

    Example
    -------
    >>> x = torch.randn(5, 10, 20)
    >>> GLN = GlobalLayerNorm(10, 3)
    >>> x_norm = GLN(x)
    """

    def __init__(self, dim, shape, eps=1e-8, elementwise_affine=True):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.elementwise_affine = elementwise_affine

        if self.elementwise_affine:
            # The affine parameters broadcast over the time axes, so their
            # trailing singleton dimensions depend on the input rank.
            if shape == 3:
                self.weight = nn.Parameter(torch.ones(self.dim, 1))
                self.bias = nn.Parameter(torch.zeros(self.dim, 1))
            elif shape == 4:
                self.weight = nn.Parameter(torch.ones(self.dim, 1, 1))
                self.bias = nn.Parameter(torch.zeros(self.dim, 1, 1))
        else:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)

    def forward(self, x):
        """Returns the normalized tensor.

        Arguments
        ---------
        x : torch.Tensor
            Tensor of size [N, C, K, S] or [N, C, L].
        """
        # gLN pools mean and variance over every axis except the batch
        # axis, so the statistics have shape N x 1 x 1 (x 1).
        if x.dim() == 3:
            mean = torch.mean(x, (1, 2), keepdim=True)
            var = torch.mean((x - mean) ** 2, (1, 2), keepdim=True)
            if self.elementwise_affine:
                x = (
                    self.weight * (x - mean) / torch.sqrt(var + self.eps)
                    + self.bias
                )
            else:
                x = (x - mean) / torch.sqrt(var + self.eps)

        if x.dim() == 4:
            mean = torch.mean(x, (1, 2, 3), keepdim=True)
            var = torch.mean((x - mean) ** 2, (1, 2, 3), keepdim=True)
            if self.elementwise_affine:
                x = (
                    self.weight * (x - mean) / torch.sqrt(var + self.eps)
                    + self.bias
                )
            else:
                x = (x - mean) / torch.sqrt(var + self.eps)
        return x
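

# A minimal usage sketch of the 4-D case (shapes here are illustrative
# assumptions): gLN pools statistics over all non-batch axes of [N, C, K, S].
def _global_layer_norm_example():
    x = torch.randn(5, 10, 20, 30)  # [N, C, K, S]
    gln = GlobalLayerNorm(10, shape=4)
    y = gln(x)
    assert y.shape == x.shape
    return y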


class CumulativeLayerNorm(nn.LayerNorm):
    """Calculate Cumulative Layer Normalization.

    Arguments
    ---------
    dim : int
        Dimension that you want to normalize.
    elementwise_affine : bool
        When set to True, this module has learnable per-element affine
        parameters.

    Example
    -------
    >>> x = torch.randn(5, 10, 20)
    >>> CLN = CumulativeLayerNorm(10)
    >>> x_norm = CLN(x)
    """

    def __init__(self, dim, elementwise_affine=True):
        super().__init__(dim, elementwise_affine=elementwise_affine, eps=1e-8)

    def forward(self, x):
        """Returns the normalized tensor.

        Arguments
        ---------
        x : torch.Tensor
            Tensor of size [N, C, K, S] or [N, C, L].
        """
        # Move the channel axis to the last position so the inherited
        # nn.LayerNorm normalizes over channels only, then move it back.
        if x.dim() == 4:
            x = x.permute(0, 2, 3, 1).contiguous()  # N x K x S x C
            x = super().forward(x)
            x = x.permute(0, 3, 1, 2).contiguous()  # N x C x K x S

        if x.dim() == 3:
            x = torch.transpose(x, 1, 2)  # N x L x C
            x = super().forward(x)
            x = torch.transpose(x, 1, 2)  # N x C x L

        return x
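

# A minimal usage sketch of the 4-D case (shapes here are illustrative
# assumptions): unlike gLN, the statistics are taken over the channel axis
# only, independently at every (k, s) position.
def _cumulative_layer_norm_example():
    x = torch.randn(5, 10, 20, 30)  # [N, C, K, S]
    cln = CumulativeLayerNorm(10)
    y = cln(x)
    assert y.shape == x.shape
    return y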


class ScaleNorm(nn.Module):
    """Scale normalization: rescale the input so that every vector along
    the last dimension has a single learned norm.

    Arguments
    ---------
    dim : int
        Size of the last (normalized) dimension.
    eps : float
        Lower bound on the norm, for numerical stability.
    """

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.scale = dim ** -0.5
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1))

    def forward(self, x):
        # RMS-style statistic: the L2 norm over the last axis, scaled by
        # 1 / sqrt(dim) and clamped away from zero.
        norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
        return x / norm.clamp(min=self.eps) * self.g
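

# A minimal usage sketch (shapes here are illustrative assumptions): at
# initialization (g = 1), every vector along the last axis is rescaled to
# norm sqrt(dim).
def _scale_norm_example():
    x = torch.randn(4, 16, 64)
    sn = ScaleNorm(64)
    y = sn(x)
    assert y.shape == x.shape
    return y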