# dynamic_conv2d.py
  1. """Dynamic 2-Dimensional Convolution module."""
  2. import numpy
  3. import torch
  4. from torch import nn
  5. import torch.nn.functional as F
  6. MIN_VALUE = float(numpy.finfo(numpy.float32).min)
  7. class DynamicConvolution2D(nn.Module):
  8. """Dynamic 2-Dimensional Convolution layer.
  9. This implementation is based on
  10. https://github.com/pytorch/fairseq/tree/master/fairseq
  11. Args:
  12. wshare (int): the number of kernel of convolution
  13. n_feat (int): the number of features
  14. dropout_rate (float): dropout_rate
  15. kernel_size (int): kernel size (length)
  16. use_kernel_mask (bool): Use causal mask or not for convolution kernel
  17. use_bias (bool): Use bias term or not.
  18. """
  19. def __init__(
  20. self,
  21. wshare,
  22. n_feat,
  23. dropout_rate,
  24. kernel_size,
  25. use_kernel_mask=False,
  26. use_bias=False,
  27. ):
  28. """Construct Dynamic 2-Dimensional Convolution layer."""
  29. super(DynamicConvolution2D, self).__init__()
  30. assert n_feat % wshare == 0
  31. self.wshare = wshare
  32. self.use_kernel_mask = use_kernel_mask
  33. self.dropout_rate = dropout_rate
  34. self.kernel_size = kernel_size
  35. self.padding_size = int(kernel_size / 2)
  36. self.attn_t = None
  37. self.attn_f = None
  38. # linear -> GLU -- -> lightconv -> linear
  39. # \ /
  40. # Linear
  41. self.linear1 = nn.Linear(n_feat, n_feat * 2)
  42. self.linear2 = nn.Linear(n_feat * 2, n_feat)
  43. self.linear_weight = nn.Linear(n_feat, self.wshare * 1 * kernel_size)
  44. nn.init.xavier_uniform(self.linear_weight.weight)
  45. self.linear_weight_f = nn.Linear(n_feat, kernel_size)
  46. nn.init.xavier_uniform(self.linear_weight_f.weight)
  47. self.act = nn.GLU()
  48. # dynamic conv related
  49. self.use_bias = use_bias
  50. if self.use_bias:
  51. self.bias = nn.Parameter(torch.Tensor(n_feat))
  52. def forward(self, query, key, value, mask):
  53. """Forward of 'Dynamic 2-Dimensional Convolution'.
  54. This function takes query, key and value but uses only query.
  55. This is just for compatibility with self-attention layer (attention.py)
  56. Args:
  57. query (torch.Tensor): (batch, time1, d_model) input tensor
  58. key (torch.Tensor): (batch, time2, d_model) NOT USED
  59. value (torch.Tensor): (batch, time2, d_model) NOT USED
  60. mask (torch.Tensor): (batch, time1, time2) mask
  61. Return:
  62. x (torch.Tensor): (batch, time1, d_model) output
  63. """
  64. # linear -> GLU -- -> lightconv -> linear
  65. # \ /
  66. # Linear
  67. x = query
  68. B, T, C = x.size()
  69. H = self.wshare
  70. k = self.kernel_size
  71. # first liner layer
  72. x = self.linear1(x)
  73. # GLU activation
  74. x = self.act(x)
  75. # convolution of frequency axis
  76. weight_f = self.linear_weight_f(x).view(B * T, 1, k) # B x T x k
  77. self.attn_f = weight_f.view(B, T, k).unsqueeze(1)
  78. xf = F.conv1d(
  79. x.view(1, B * T, C), weight_f, padding=self.padding_size, groups=B * T
  80. )
  81. xf = xf.view(B, T, C)
  82. # get kernel of convolution
  83. weight = self.linear_weight(x) # B x T x kH
  84. weight = F.dropout(weight, self.dropout_rate, training=self.training)
  85. weight = weight.view(B, T, H, k).transpose(1, 2).contiguous() # B x H x T x k
  86. weight_new = torch.zeros(B * H * T * (T + k - 1), dtype=weight.dtype)
  87. weight_new = weight_new.view(B, H, T, T + k - 1).fill_(float("-inf"))
  88. weight_new = weight_new.to(x.device) # B x H x T x T+k-1
  89. weight_new.as_strided(
  90. (B, H, T, k), ((T + k - 1) * T * H, (T + k - 1) * T, T + k, 1)
  91. ).copy_(weight)
  92. weight_new = weight_new.narrow(-1, int((k - 1) / 2), T) # B x H x T x T(k)
  93. if self.use_kernel_mask:
  94. kernel_mask = torch.tril(torch.ones(T, T, device=x.device)).unsqueeze(0)
  95. weight_new = weight_new.masked_fill(kernel_mask == 0.0, float("-inf"))
  96. weight_new = F.softmax(weight_new, dim=-1)
  97. self.attn_t = weight_new
  98. weight_new = weight_new.view(B * H, T, T)
  99. # convolution
  100. x = x.transpose(1, 2).contiguous() # B x C x T
  101. x = x.view(B * H, int(C / H), T).transpose(1, 2)
  102. x = torch.bmm(weight_new, x)
  103. x = x.transpose(1, 2).contiguous().view(B, C, T)
  104. if self.use_bias:
  105. x = x + self.bias.view(1, -1, 1)
  106. x = x.transpose(1, 2) # B x T x C
  107. x = torch.cat((x, xf), -1) # B x T x Cx2
  108. if mask is not None and not self.use_kernel_mask:
  109. mask = mask.transpose(-1, -2)
  110. x = x.masked_fill(mask == 0, 0.0)
  111. # second linear layer
  112. x = self.linear2(x)
  113. return x