# multi_layer_conv.py
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. # Copyright 2019 Tomoki Hayashi
  4. # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
  5. """Layer modules for FFT block in FastSpeech (Feed-forward Transformer)."""
  6. import torch
  7. class MultiLayeredConv1d(torch.nn.Module):
  8. """Multi-layered conv1d for Transformer block.
  9. This is a module of multi-leyered conv1d designed
  10. to replace positionwise feed-forward network
  11. in Transforner block, which is introduced in
  12. `FastSpeech: Fast, Robust and Controllable Text to Speech`_.
  13. .. _`FastSpeech: Fast, Robust and Controllable Text to Speech`:
  14. https://arxiv.org/pdf/1905.09263.pdf
  15. """
  16. def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate):
  17. """Initialize MultiLayeredConv1d module.
  18. Args:
  19. in_chans (int): Number of input channels.
  20. hidden_chans (int): Number of hidden channels.
  21. kernel_size (int): Kernel size of conv1d.
  22. dropout_rate (float): Dropout rate.
  23. """
  24. super(MultiLayeredConv1d, self).__init__()
  25. self.w_1 = torch.nn.Conv1d(
  26. in_chans,
  27. hidden_chans,
  28. kernel_size,
  29. stride=1,
  30. padding=(kernel_size - 1) // 2,
  31. )
  32. self.w_2 = torch.nn.Conv1d(
  33. hidden_chans,
  34. in_chans,
  35. kernel_size,
  36. stride=1,
  37. padding=(kernel_size - 1) // 2,
  38. )
  39. self.dropout = torch.nn.Dropout(dropout_rate)
  40. def forward(self, x):
  41. """Calculate forward propagation.
  42. Args:
  43. x (torch.Tensor): Batch of input tensors (B, T, in_chans).
  44. Returns:
  45. torch.Tensor: Batch of output tensors (B, T, hidden_chans).
  46. """
  47. x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1)
  48. return self.w_2(self.dropout(x).transpose(-1, 1)).transpose(-1, 1)
  49. class Conv1dLinear(torch.nn.Module):
  50. """Conv1D + Linear for Transformer block.
  51. A variant of MultiLayeredConv1d, which replaces second conv-layer to linear.
  52. """
  53. def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate):
  54. """Initialize Conv1dLinear module.
  55. Args:
  56. in_chans (int): Number of input channels.
  57. hidden_chans (int): Number of hidden channels.
  58. kernel_size (int): Kernel size of conv1d.
  59. dropout_rate (float): Dropout rate.
  60. """
  61. super(Conv1dLinear, self).__init__()
  62. self.w_1 = torch.nn.Conv1d(
  63. in_chans,
  64. hidden_chans,
  65. kernel_size,
  66. stride=1,
  67. padding=(kernel_size - 1) // 2,
  68. )
  69. self.w_2 = torch.nn.Linear(hidden_chans, in_chans)
  70. self.dropout = torch.nn.Dropout(dropout_rate)
  71. def forward(self, x):
  72. """Calculate forward propagation.
  73. Args:
  74. x (torch.Tensor): Batch of input tensors (B, T, in_chans).
  75. Returns:
  76. torch.Tensor: Batch of output tensors (B, T, hidden_chans).
  77. """
  78. x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1)
  79. return self.w_2(self.dropout(x))