# subsampling.py — ONNX-exportable convolutional subsampling layer definitions.
  1. """Subsampling layer definition."""
  2. import torch
  3. class OnnxConv2dSubsampling(torch.nn.Module):
  4. """Convolutional 2D subsampling (to 1/4 length).
  5. Args:
  6. idim (int): Input dimension.
  7. odim (int): Output dimension.
  8. dropout_rate (float): Dropout rate.
  9. pos_enc (torch.nn.Module): Custom position encoding layer.
  10. """
  11. def __init__(self, model):
  12. """Construct an Conv2dSubsampling object."""
  13. super().__init__()
  14. self.conv = model.conv
  15. self.out = model.out
  16. def forward(self, x, x_mask):
  17. """Subsample x.
  18. Args:
  19. x (torch.Tensor): Input tensor (#batch, time, idim).
  20. x_mask (torch.Tensor): Input mask (#batch, 1, time).
  21. Returns:
  22. torch.Tensor: Subsampled tensor (#batch, time', odim),
  23. where time' = time // 4.
  24. torch.Tensor: Subsampled mask (#batch, 1, time'),
  25. where time' = time // 4.
  26. """
  27. x = x.unsqueeze(1) # (b, c, t, f)
  28. x = self.conv(x)
  29. b, c, t, f = x.size()
  30. x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
  31. if x_mask is None:
  32. return x, None
  33. return x, x_mask[:, :-2:2][:, :-2:2]
  34. def __getitem__(self, key):
  35. """Get item.
  36. When reset_parameters() is called, if use_scaled_pos_enc is used,
  37. return the positioning encoding.
  38. """
  39. if key != -1:
  40. raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
  41. return self.out[key]
  42. class OnnxConv2dSubsampling2(torch.nn.Module):
  43. """Convolutional 2D subsampling (to 1/2 length).
  44. Args:
  45. idim (int): Input dimension.
  46. odim (int): Output dimension.
  47. dropout_rate (float): Dropout rate.
  48. pos_enc (torch.nn.Module): Custom position encoding layer.
  49. """
  50. def __init__(self, model):
  51. """Construct an Conv2dSubsampling object."""
  52. super().__init__()
  53. self.conv = model.conv
  54. self.out = model.out
  55. def forward(self, x, x_mask):
  56. """Subsample x.
  57. Args:
  58. x (torch.Tensor): Input tensor (#batch, time, idim).
  59. x_mask (torch.Tensor): Input mask (#batch, 1, time).
  60. Returns:
  61. torch.Tensor: Subsampled tensor (#batch, time', odim),
  62. where time' = time // 2.
  63. torch.Tensor: Subsampled mask (#batch, 1, time'),
  64. where time' = time // 2.
  65. """
  66. x = x.unsqueeze(1) # (b, c, t, f)
  67. x = self.conv(x)
  68. b, c, t, f = x.size()
  69. x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
  70. if x_mask is None:
  71. return x, None
  72. return x, x_mask[:, :-2:2][:, :-2:1]
  73. def __getitem__(self, key):
  74. """Get item.
  75. When reset_parameters() is called, if use_scaled_pos_enc is used,
  76. return the positioning encoding.
  77. """
  78. if key != -1:
  79. raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
  80. return self.out[key]
  81. class OnnxConv2dSubsampling6(torch.nn.Module):
  82. """Convolutional 2D subsampling (to 1/6 length).
  83. Args:
  84. idim (int): Input dimension.
  85. odim (int): Output dimension.
  86. dropout_rate (float): Dropout rate.
  87. pos_enc (torch.nn.Module): Custom position encoding layer.
  88. """
  89. def __init__(self, model):
  90. """Construct an Conv2dSubsampling object."""
  91. super().__init__()
  92. self.conv = model.conv
  93. self.out = model.out
  94. def forward(self, x, x_mask):
  95. """Subsample x.
  96. Args:
  97. x (torch.Tensor): Input tensor (#batch, time, idim).
  98. x_mask (torch.Tensor): Input mask (#batch, 1, time).
  99. Returns:
  100. torch.Tensor: Subsampled tensor (#batch, time', odim),
  101. where time' = time // 6.
  102. torch.Tensor: Subsampled mask (#batch, 1, time'),
  103. where time' = time // 6.
  104. """
  105. x = x.unsqueeze(1) # (b, c, t, f)
  106. x = self.conv(x)
  107. b, c, t, f = x.size()
  108. x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
  109. if x_mask is None:
  110. return x, None
  111. return x, x_mask[:, :-2:2][:, :-4:3]
  112. class OnnxConv2dSubsampling8(torch.nn.Module):
  113. """Convolutional 2D subsampling (to 1/8 length).
  114. Args:
  115. idim (int): Input dimension.
  116. odim (int): Output dimension.
  117. dropout_rate (float): Dropout rate.
  118. pos_enc (torch.nn.Module): Custom position encoding layer.
  119. """
  120. def __init__(self, model):
  121. """Construct an Conv2dSubsampling object."""
  122. super().__init__()
  123. self.conv = model.conv
  124. self.out = model.out
  125. def forward(self, x, x_mask):
  126. """Subsample x.
  127. Args:
  128. x (torch.Tensor): Input tensor (#batch, time, idim).
  129. x_mask (torch.Tensor): Input mask (#batch, 1, time).
  130. Returns:
  131. torch.Tensor: Subsampled tensor (#batch, time', odim),
  132. where time' = time // 8.
  133. torch.Tensor: Subsampled mask (#batch, 1, time'),
  134. where time' = time // 8.
  135. """
  136. x = x.unsqueeze(1) # (b, c, t, f)
  137. x = self.conv(x)
  138. b, c, t, f = x.size()
  139. x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
  140. if x_mask is None:
  141. return x, None
  142. return x, x_mask[:, :-2:2][:, :-2:2][:, :-2:2]