lightconv.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. """Lightweight Convolution Module."""
  2. import numpy
  3. import torch
  4. from torch import nn
  5. import torch.nn.functional as F
# Most negative finite float32 value, stored as a plain Python float.
MIN_VALUE = float(numpy.finfo(numpy.float32).min)
  7. class LightweightConvolution(nn.Module):
  8. """Lightweight Convolution layer.
  9. This implementation is based on
  10. https://github.com/pytorch/fairseq/tree/master/fairseq
  11. Args:
  12. wshare (int): the number of kernel of convolution
  13. n_feat (int): the number of features
  14. dropout_rate (float): dropout_rate
  15. kernel_size (int): kernel size (length)
  16. use_kernel_mask (bool): Use causal mask or not for convolution kernel
  17. use_bias (bool): Use bias term or not.
  18. """
  19. def __init__(
  20. self,
  21. wshare,
  22. n_feat,
  23. dropout_rate,
  24. kernel_size,
  25. use_kernel_mask=False,
  26. use_bias=False,
  27. ):
  28. """Construct Lightweight Convolution layer."""
  29. super(LightweightConvolution, self).__init__()
  30. assert n_feat % wshare == 0
  31. self.wshare = wshare
  32. self.use_kernel_mask = use_kernel_mask
  33. self.dropout_rate = dropout_rate
  34. self.kernel_size = kernel_size
  35. self.padding_size = int(kernel_size / 2)
  36. # linear -> GLU -> lightconv -> linear
  37. self.linear1 = nn.Linear(n_feat, n_feat * 2)
  38. self.linear2 = nn.Linear(n_feat, n_feat)
  39. self.act = nn.GLU()
  40. # lightconv related
  41. self.weight = nn.Parameter(
  42. torch.Tensor(self.wshare, 1, kernel_size).uniform_(0, 1)
  43. )
  44. self.use_bias = use_bias
  45. if self.use_bias:
  46. self.bias = nn.Parameter(torch.Tensor(n_feat))
  47. # mask of kernel
  48. kernel_mask0 = torch.zeros(self.wshare, int(kernel_size / 2))
  49. kernel_mask1 = torch.ones(self.wshare, int(kernel_size / 2 + 1))
  50. self.kernel_mask = torch.cat((kernel_mask1, kernel_mask0), dim=-1).unsqueeze(1)
  51. def forward(self, query, key, value, mask):
  52. """Forward of 'Lightweight Convolution'.
  53. This function takes query, key and value but uses only query.
  54. This is just for compatibility with self-attention layer (attention.py)
  55. Args:
  56. query (torch.Tensor): (batch, time1, d_model) input tensor
  57. key (torch.Tensor): (batch, time2, d_model) NOT USED
  58. value (torch.Tensor): (batch, time2, d_model) NOT USED
  59. mask (torch.Tensor): (batch, time1, time2) mask
  60. Return:
  61. x (torch.Tensor): (batch, time1, d_model) output
  62. """
  63. # linear -> GLU -> lightconv -> linear
  64. x = query
  65. B, T, C = x.size()
  66. H = self.wshare
  67. # first liner layer
  68. x = self.linear1(x)
  69. # GLU activation
  70. x = self.act(x)
  71. # lightconv
  72. x = x.transpose(1, 2).contiguous().view(-1, H, T) # B x C x T
  73. weight = F.dropout(self.weight, self.dropout_rate, training=self.training)
  74. if self.use_kernel_mask:
  75. self.kernel_mask = self.kernel_mask.to(x.device)
  76. weight = weight.masked_fill(self.kernel_mask == 0.0, float("-inf"))
  77. weight = F.softmax(weight, dim=-1)
  78. x = F.conv1d(x, weight, padding=self.padding_size, groups=self.wshare).view(
  79. B, C, T
  80. )
  81. if self.use_bias:
  82. x = x + self.bias.view(1, -1, 1)
  83. x = x.transpose(1, 2) # B x T x C
  84. if mask is not None and not self.use_kernel_mask:
  85. mask = mask.transpose(-1, -2)
  86. x = x.masked_fill(mask == 0, 0.0)
  87. # second linear layer
  88. x = self.linear2(x)
  89. return x