#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Subsampling layer definition."""
import torch
import torch.nn.functional as F

from funasr.modules.embedding import PositionalEncoding
  9. class TooShortUttError(Exception):
  10. """Raised when the utt is too short for subsampling.
  11. Args:
  12. message (str): Message for error catch
  13. actual_size (int): the short size that cannot pass the subsampling
  14. limit (int): the limit size for subsampling
  15. """
  16. def __init__(self, message, actual_size, limit):
  17. """Construct a TooShortUttError for error handler."""
  18. super().__init__(message)
  19. self.actual_size = actual_size
  20. self.limit = limit
  21. def check_short_utt(ins, size):
  22. """Check if the utterance is too short for subsampling."""
  23. if isinstance(ins, Conv2dSubsampling2) and size < 3:
  24. return True, 3
  25. if isinstance(ins, Conv2dSubsampling) and size < 7:
  26. return True, 7
  27. if isinstance(ins, Conv2dSubsampling6) and size < 11:
  28. return True, 11
  29. if isinstance(ins, Conv2dSubsampling8) and size < 15:
  30. return True, 15
  31. return False, -1
  32. class Conv2dSubsampling(torch.nn.Module):
  33. """Convolutional 2D subsampling (to 1/4 length).
  34. Args:
  35. idim (int): Input dimension.
  36. odim (int): Output dimension.
  37. dropout_rate (float): Dropout rate.
  38. pos_enc (torch.nn.Module): Custom position encoding layer.
  39. """
  40. def __init__(self, idim, odim, dropout_rate, pos_enc=None):
  41. """Construct an Conv2dSubsampling object."""
  42. super(Conv2dSubsampling, self).__init__()
  43. self.conv = torch.nn.Sequential(
  44. torch.nn.Conv2d(1, odim, 3, 2),
  45. torch.nn.ReLU(),
  46. torch.nn.Conv2d(odim, odim, 3, 2),
  47. torch.nn.ReLU(),
  48. )
  49. self.out = torch.nn.Sequential(
  50. torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim),
  51. pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate),
  52. )
  53. def forward(self, x, x_mask):
  54. """Subsample x.
  55. Args:
  56. x (torch.Tensor): Input tensor (#batch, time, idim).
  57. x_mask (torch.Tensor): Input mask (#batch, 1, time).
  58. Returns:
  59. torch.Tensor: Subsampled tensor (#batch, time', odim),
  60. where time' = time // 4.
  61. torch.Tensor: Subsampled mask (#batch, 1, time'),
  62. where time' = time // 4.
  63. """
  64. x = x.unsqueeze(1) # (b, c, t, f)
  65. x = self.conv(x)
  66. b, c, t, f = x.size()
  67. x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
  68. if x_mask is None:
  69. return x, None
  70. return x, x_mask[:, :, :-2:2][:, :, :-2:2]
  71. def __getitem__(self, key):
  72. """Get item.
  73. When reset_parameters() is called, if use_scaled_pos_enc is used,
  74. return the positioning encoding.
  75. """
  76. if key != -1:
  77. raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
  78. return self.out[key]
  79. class Conv2dSubsampling2(torch.nn.Module):
  80. """Convolutional 2D subsampling (to 1/2 length).
  81. Args:
  82. idim (int): Input dimension.
  83. odim (int): Output dimension.
  84. dropout_rate (float): Dropout rate.
  85. pos_enc (torch.nn.Module): Custom position encoding layer.
  86. """
  87. def __init__(self, idim, odim, dropout_rate, pos_enc=None):
  88. """Construct an Conv2dSubsampling2 object."""
  89. super(Conv2dSubsampling2, self).__init__()
  90. self.conv = torch.nn.Sequential(
  91. torch.nn.Conv2d(1, odim, 3, 2),
  92. torch.nn.ReLU(),
  93. torch.nn.Conv2d(odim, odim, 3, 1),
  94. torch.nn.ReLU(),
  95. )
  96. self.out = torch.nn.Sequential(
  97. torch.nn.Linear(odim * (((idim - 1) // 2 - 2)), odim),
  98. pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate),
  99. )
  100. def forward(self, x, x_mask):
  101. """Subsample x.
  102. Args:
  103. x (torch.Tensor): Input tensor (#batch, time, idim).
  104. x_mask (torch.Tensor): Input mask (#batch, 1, time).
  105. Returns:
  106. torch.Tensor: Subsampled tensor (#batch, time', odim),
  107. where time' = time // 2.
  108. torch.Tensor: Subsampled mask (#batch, 1, time'),
  109. where time' = time // 2.
  110. """
  111. x = x.unsqueeze(1) # (b, c, t, f)
  112. x = self.conv(x)
  113. b, c, t, f = x.size()
  114. x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
  115. if x_mask is None:
  116. return x, None
  117. return x, x_mask[:, :, :-2:2][:, :, :-2:1]
  118. def __getitem__(self, key):
  119. """Get item.
  120. When reset_parameters() is called, if use_scaled_pos_enc is used,
  121. return the positioning encoding.
  122. """
  123. if key != -1:
  124. raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
  125. return self.out[key]
  126. class Conv2dSubsampling6(torch.nn.Module):
  127. """Convolutional 2D subsampling (to 1/6 length).
  128. Args:
  129. idim (int): Input dimension.
  130. odim (int): Output dimension.
  131. dropout_rate (float): Dropout rate.
  132. pos_enc (torch.nn.Module): Custom position encoding layer.
  133. """
  134. def __init__(self, idim, odim, dropout_rate, pos_enc=None):
  135. """Construct an Conv2dSubsampling6 object."""
  136. super(Conv2dSubsampling6, self).__init__()
  137. self.conv = torch.nn.Sequential(
  138. torch.nn.Conv2d(1, odim, 3, 2),
  139. torch.nn.ReLU(),
  140. torch.nn.Conv2d(odim, odim, 5, 3),
  141. torch.nn.ReLU(),
  142. )
  143. self.out = torch.nn.Sequential(
  144. torch.nn.Linear(odim * (((idim - 1) // 2 - 2) // 3), odim),
  145. pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate),
  146. )
  147. def forward(self, x, x_mask):
  148. """Subsample x.
  149. Args:
  150. x (torch.Tensor): Input tensor (#batch, time, idim).
  151. x_mask (torch.Tensor): Input mask (#batch, 1, time).
  152. Returns:
  153. torch.Tensor: Subsampled tensor (#batch, time', odim),
  154. where time' = time // 6.
  155. torch.Tensor: Subsampled mask (#batch, 1, time'),
  156. where time' = time // 6.
  157. """
  158. x = x.unsqueeze(1) # (b, c, t, f)
  159. x = self.conv(x)
  160. b, c, t, f = x.size()
  161. x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
  162. if x_mask is None:
  163. return x, None
  164. return x, x_mask[:, :, :-2:2][:, :, :-4:3]
  165. class Conv2dSubsampling8(torch.nn.Module):
  166. """Convolutional 2D subsampling (to 1/8 length).
  167. Args:
  168. idim (int): Input dimension.
  169. odim (int): Output dimension.
  170. dropout_rate (float): Dropout rate.
  171. pos_enc (torch.nn.Module): Custom position encoding layer.
  172. """
  173. def __init__(self, idim, odim, dropout_rate, pos_enc=None):
  174. """Construct an Conv2dSubsampling8 object."""
  175. super(Conv2dSubsampling8, self).__init__()
  176. self.conv = torch.nn.Sequential(
  177. torch.nn.Conv2d(1, odim, 3, 2),
  178. torch.nn.ReLU(),
  179. torch.nn.Conv2d(odim, odim, 3, 2),
  180. torch.nn.ReLU(),
  181. torch.nn.Conv2d(odim, odim, 3, 2),
  182. torch.nn.ReLU(),
  183. )
  184. self.out = torch.nn.Sequential(
  185. torch.nn.Linear(odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2), odim),
  186. pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate),
  187. )
  188. def forward(self, x, x_mask):
  189. """Subsample x.
  190. Args:
  191. x (torch.Tensor): Input tensor (#batch, time, idim).
  192. x_mask (torch.Tensor): Input mask (#batch, 1, time).
  193. Returns:
  194. torch.Tensor: Subsampled tensor (#batch, time', odim),
  195. where time' = time // 8.
  196. torch.Tensor: Subsampled mask (#batch, 1, time'),
  197. where time' = time // 8.
  198. """
  199. x = x.unsqueeze(1) # (b, c, t, f)
  200. x = self.conv(x)
  201. b, c, t, f = x.size()
  202. x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
  203. if x_mask is None:
  204. return x, None
  205. return x, x_mask[:, :, :-2:2][:, :, :-2:2][:, :, :-2:2]
  206. class Conv1dSubsampling(torch.nn.Module):
  207. """Convolutional 1D subsampling (to 1/2 length).
  208. Args:
  209. idim (int): Input dimension.
  210. odim (int): Output dimension.
  211. dropout_rate (float): Dropout rate.
  212. pos_enc (torch.nn.Module): Custom position encoding layer.
  213. """
  214. def __init__(self, idim, odim, kernel_size, stride, pad):
  215. super(Conv1dSubsampling, self).__init__()
  216. self.conv = torch.nn.Conv1d(idim, odim, kernel_size, stride)
  217. self.pad_fn = torch.nn.ConstantPad1d(pad, 0.0)
  218. self.stride = stride
  219. self.odim = odim
  220. def output_size(self) -> int:
  221. return self.odim
  222. def forward(self, x, x_len):
  223. """Subsample x.
  224. """
  225. x = x.transpose(1, 2) # (b, d ,t)
  226. x = self.pad_fn(x)
  227. x = F.relu(self.conv(x))
  228. x = x.transpose(1, 2) # (b, t ,d)
  229. if x_len is None:
  230. return x, None
  231. x_len = (x_len - 1) // self.stride + 1
  232. return x, x_len
  233. def __getitem__(self, key):
  234. """Get item.
  235. When reset_parameters() is called, if use_scaled_pos_enc is used,
  236. return the positioning encoding.
  237. """
  238. if key != -1:
  239. raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
  240. return self.out[key]