lightconv2d.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. """Lightweight 2-Dimensional Convolution module."""
  2. import numpy
  3. import torch
  4. from torch import nn
  5. import torch.nn.functional as F
  6. MIN_VALUE = float(numpy.finfo(numpy.float32).min)
  7. class LightweightConvolution2D(nn.Module):
  8. """Lightweight 2-Dimensional Convolution layer.
  9. This implementation is based on
  10. https://github.com/pytorch/fairseq/tree/master/fairseq
  11. Args:
  12. wshare (int): the number of kernel of convolution
  13. n_feat (int): the number of features
  14. dropout_rate (float): dropout_rate
  15. kernel_size (int): kernel size (length)
  16. use_kernel_mask (bool): Use causal mask or not for convolution kernel
  17. use_bias (bool): Use bias term or not.
  18. """
  19. def __init__(
  20. self,
  21. wshare,
  22. n_feat,
  23. dropout_rate,
  24. kernel_size,
  25. use_kernel_mask=False,
  26. use_bias=False,
  27. ):
  28. """Construct Lightweight 2-Dimensional Convolution layer."""
  29. super(LightweightConvolution2D, self).__init__()
  30. assert n_feat % wshare == 0
  31. self.wshare = wshare
  32. self.use_kernel_mask = use_kernel_mask
  33. self.dropout_rate = dropout_rate
  34. self.kernel_size = kernel_size
  35. self.padding_size = int(kernel_size / 2)
  36. # linear -> GLU -> lightconv -> linear
  37. self.linear1 = nn.Linear(n_feat, n_feat * 2)
  38. self.linear2 = nn.Linear(n_feat * 2, n_feat)
  39. self.act = nn.GLU()
  40. # lightconv related
  41. self.weight = nn.Parameter(
  42. torch.Tensor(self.wshare, 1, kernel_size).uniform_(0, 1)
  43. )
  44. self.weight_f = nn.Parameter(torch.Tensor(1, 1, kernel_size).uniform_(0, 1))
  45. self.use_bias = use_bias
  46. if self.use_bias:
  47. self.bias = nn.Parameter(torch.Tensor(n_feat))
  48. # mask of kernel
  49. kernel_mask0 = torch.zeros(self.wshare, int(kernel_size / 2))
  50. kernel_mask1 = torch.ones(self.wshare, int(kernel_size / 2 + 1))
  51. self.kernel_mask = torch.cat((kernel_mask1, kernel_mask0), dim=-1).unsqueeze(1)
  52. def forward(self, query, key, value, mask):
  53. """Forward of 'Lightweight 2-Dimensional Convolution'.
  54. This function takes query, key and value but uses only query.
  55. This is just for compatibility with self-attention layer (attention.py)
  56. Args:
  57. query (torch.Tensor): (batch, time1, d_model) input tensor
  58. key (torch.Tensor): (batch, time2, d_model) NOT USED
  59. value (torch.Tensor): (batch, time2, d_model) NOT USED
  60. mask (torch.Tensor): (batch, time1, time2) mask
  61. Return:
  62. x (torch.Tensor): (batch, time1, d_model) output
  63. """
  64. # linear -> GLU -> lightconv -> linear
  65. x = query
  66. B, T, C = x.size()
  67. H = self.wshare
  68. # first liner layer
  69. x = self.linear1(x)
  70. # GLU activation
  71. x = self.act(x)
  72. # convolution along frequency axis
  73. weight_f = F.softmax(self.weight_f, dim=-1)
  74. weight_f = F.dropout(weight_f, self.dropout_rate, training=self.training)
  75. weight_new = torch.zeros(
  76. B * T, 1, self.kernel_size, device=x.device, dtype=x.dtype
  77. ).copy_(weight_f)
  78. xf = F.conv1d(
  79. x.view(1, B * T, C), weight_new, padding=self.padding_size, groups=B * T
  80. ).view(B, T, C)
  81. # lightconv
  82. x = x.transpose(1, 2).contiguous().view(-1, H, T) # B x C x T
  83. weight = F.dropout(self.weight, self.dropout_rate, training=self.training)
  84. if self.use_kernel_mask:
  85. self.kernel_mask = self.kernel_mask.to(x.device)
  86. weight = weight.masked_fill(self.kernel_mask == 0.0, float("-inf"))
  87. weight = F.softmax(weight, dim=-1)
  88. x = F.conv1d(x, weight, padding=self.padding_size, groups=self.wshare).view(
  89. B, C, T
  90. )
  91. if self.use_bias:
  92. x = x + self.bias.view(1, -1, 1)
  93. x = x.transpose(1, 2) # B x T x C
  94. x = torch.cat((x, xf), -1) # B x T x Cx2
  95. if mask is not None and not self.use_kernel_mask:
  96. mask = mask.transpose(-1, -2)
  97. x = x.masked_fill(mask == 0, 0.0)
  98. # second linear layer
  99. x = self.linear2(x)
  100. return x