  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. # Copyright 2019 Tomoki Hayashi
  4. # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
  5. """Layer modules for FFT block in FastSpeech (Feed-forward Transformer)."""
  6. import torch
  7. class MultiLayeredConv1d(torch.nn.Module):
  8. """Multi-layered conv1d for Transformer block.
  9. This is a module of multi-leyered conv1d designed
  10. to replace positionwise feed-forward network
  11. in Transforner block, which is introduced in
  12. `FastSpeech: Fast, Robust and Controllable Text to Speech`_.
  13. .. _`FastSpeech: Fast, Robust and Controllable Text to Speech`:
  14. https://arxiv.org/pdf/1905.09263.pdf
  15. """
  16. def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate):
  17. """Initialize MultiLayeredConv1d module.
  18. Args:
  19. in_chans (int): Number of input channels.
  20. hidden_chans (int): Number of hidden channels.
  21. kernel_size (int): Kernel size of conv1d.
  22. dropout_rate (float): Dropout rate.
  23. """
  24. super(MultiLayeredConv1d, self).__init__()
  25. self.w_1 = torch.nn.Conv1d(
  26. in_chans,
  27. hidden_chans,
  28. kernel_size,
  29. stride=1,
  30. padding=(kernel_size - 1) // 2,
  31. )
  32. self.w_2 = torch.nn.Conv1d(
  33. hidden_chans,
  34. in_chans,
  35. kernel_size,
  36. stride=1,
  37. padding=(kernel_size - 1) // 2,
  38. )
  39. self.dropout = torch.nn.Dropout(dropout_rate)
  40. def forward(self, x):
  41. """Calculate forward propagation.
  42. Args:
  43. x (torch.Tensor): Batch of input tensors (B, T, in_chans).
  44. Returns:
  45. torch.Tensor: Batch of output tensors (B, T, hidden_chans).
  46. """
  47. x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1)
  48. return self.w_2(self.dropout(x).transpose(-1, 1)).transpose(-1, 1)
  49. class FsmnFeedForward(torch.nn.Module):
  50. """Position-wise feed forward for FSMN blocks.
  51. This is a module of multi-leyered conv1d designed
  52. to replace position-wise feed-forward network
  53. in FSMN block.
  54. """
  55. def __init__(self, in_chans, hidden_chans, out_chans, kernel_size, dropout_rate):
  56. """Initialize FsmnFeedForward module.
  57. Args:
  58. in_chans (int): Number of input channels.
  59. hidden_chans (int): Number of hidden channels.
  60. out_chans (int): Number of output channels.
  61. kernel_size (int): Kernel size of conv1d.
  62. dropout_rate (float): Dropout rate.
  63. """
  64. super(FsmnFeedForward, self).__init__()
  65. self.w_1 = torch.nn.Conv1d(
  66. in_chans,
  67. hidden_chans,
  68. kernel_size,
  69. stride=1,
  70. padding=(kernel_size - 1) // 2,
  71. )
  72. self.w_2 = torch.nn.Conv1d(
  73. hidden_chans,
  74. out_chans,
  75. kernel_size,
  76. stride=1,
  77. padding=(kernel_size - 1) // 2,
  78. bias=False
  79. )
  80. self.norm = torch.nn.LayerNorm(hidden_chans)
  81. self.dropout = torch.nn.Dropout(dropout_rate)
  82. def forward(self, x, ilens=None):
  83. """Calculate forward propagation.
  84. Args:
  85. x (torch.Tensor): Batch of input tensors (B, T, in_chans).
  86. Returns:
  87. torch.Tensor: Batch of output tensors (B, T, out_chans).
  88. """
  89. x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1)
  90. return self.w_2(self.norm(self.dropout(x)).transpose(-1, 1)).transpose(-1, 1), ilens
  91. class Conv1dLinear(torch.nn.Module):
  92. """Conv1D + Linear for Transformer block.
  93. A variant of MultiLayeredConv1d, which replaces second conv-layer to linear.
  94. """
  95. def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate):
  96. """Initialize Conv1dLinear module.
  97. Args:
  98. in_chans (int): Number of input channels.
  99. hidden_chans (int): Number of hidden channels.
  100. kernel_size (int): Kernel size of conv1d.
  101. dropout_rate (float): Dropout rate.
  102. """
  103. super(Conv1dLinear, self).__init__()
  104. self.w_1 = torch.nn.Conv1d(
  105. in_chans,
  106. hidden_chans,
  107. kernel_size,
  108. stride=1,
  109. padding=(kernel_size - 1) // 2,
  110. )
  111. self.w_2 = torch.nn.Linear(hidden_chans, in_chans)
  112. self.dropout = torch.nn.Dropout(dropout_rate)
  113. def forward(self, x):
  114. """Calculate forward propagation.
  115. Args:
  116. x (torch.Tensor): Batch of input tensors (B, T, in_chans).
  117. Returns:
  118. torch.Tensor: Batch of output tensors (B, T, hidden_chans).
  119. """
  120. x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1)
  121. return self.w_2(self.dropout(x))