#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Subsampling layer definition."""
import logging

import numpy as np
import torch
import torch.nn.functional as F

from funasr.modules.embedding import PositionalEncoding
from funasr.modules.streaming_utils.utils import sequence_mask
  12. class TooShortUttError(Exception):
  13. """Raised when the utt is too short for subsampling.
  14. Args:
  15. message (str): Message for error catch
  16. actual_size (int): the short size that cannot pass the subsampling
  17. limit (int): the limit size for subsampling
  18. """
  19. def __init__(self, message, actual_size, limit):
  20. """Construct a TooShortUttError for error handler."""
  21. super().__init__(message)
  22. self.actual_size = actual_size
  23. self.limit = limit
  24. def check_short_utt(ins, size):
  25. """Check if the utterance is too short for subsampling."""
  26. if isinstance(ins, Conv2dSubsampling2) and size < 3:
  27. return True, 3
  28. if isinstance(ins, Conv2dSubsampling) and size < 7:
  29. return True, 7
  30. if isinstance(ins, Conv2dSubsampling6) and size < 11:
  31. return True, 11
  32. if isinstance(ins, Conv2dSubsampling8) and size < 15:
  33. return True, 15
  34. return False, -1
  35. class Conv2dSubsampling(torch.nn.Module):
  36. """Convolutional 2D subsampling (to 1/4 length).
  37. Args:
  38. idim (int): Input dimension.
  39. odim (int): Output dimension.
  40. dropout_rate (float): Dropout rate.
  41. pos_enc (torch.nn.Module): Custom position encoding layer.
  42. """
  43. def __init__(self, idim, odim, dropout_rate, pos_enc=None):
  44. """Construct an Conv2dSubsampling object."""
  45. super(Conv2dSubsampling, self).__init__()
  46. self.conv = torch.nn.Sequential(
  47. torch.nn.Conv2d(1, odim, 3, 2),
  48. torch.nn.ReLU(),
  49. torch.nn.Conv2d(odim, odim, 3, 2),
  50. torch.nn.ReLU(),
  51. )
  52. self.out = torch.nn.Sequential(
  53. torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim),
  54. pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate),
  55. )
  56. def forward(self, x, x_mask):
  57. """Subsample x.
  58. Args:
  59. x (torch.Tensor): Input tensor (#batch, time, idim).
  60. x_mask (torch.Tensor): Input mask (#batch, 1, time).
  61. Returns:
  62. torch.Tensor: Subsampled tensor (#batch, time', odim),
  63. where time' = time // 4.
  64. torch.Tensor: Subsampled mask (#batch, 1, time'),
  65. where time' = time // 4.
  66. """
  67. x = x.unsqueeze(1) # (b, c, t, f)
  68. x = self.conv(x)
  69. b, c, t, f = x.size()
  70. x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
  71. if x_mask is None:
  72. return x, None
  73. return x, x_mask[:, :, :-2:2][:, :, :-2:2]
  74. def __getitem__(self, key):
  75. """Get item.
  76. When reset_parameters() is called, if use_scaled_pos_enc is used,
  77. return the positioning encoding.
  78. """
  79. if key != -1:
  80. raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
  81. return self.out[key]
  82. class Conv2dSubsamplingPad(torch.nn.Module):
  83. """Convolutional 2D subsampling (to 1/4 length).
  84. Args:
  85. idim (int): Input dimension.
  86. odim (int): Output dimension.
  87. dropout_rate (float): Dropout rate.
  88. pos_enc (torch.nn.Module): Custom position encoding layer.
  89. """
  90. def __init__(self, idim, odim, dropout_rate, pos_enc=None):
  91. """Construct an Conv2dSubsampling object."""
  92. super(Conv2dSubsamplingPad, self).__init__()
  93. self.conv = torch.nn.Sequential(
  94. torch.nn.Conv2d(1, odim, 3, 2, padding=(0, 0)),
  95. torch.nn.ReLU(),
  96. torch.nn.Conv2d(odim, odim, 3, 2, padding=(0, 0)),
  97. torch.nn.ReLU(),
  98. )
  99. self.out = torch.nn.Sequential(
  100. torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim),
  101. pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate),
  102. )
  103. self.pad_fn = torch.nn.ConstantPad1d((0, 4), 0.0)
  104. def forward(self, x, x_mask):
  105. """Subsample x.
  106. Args:
  107. x (torch.Tensor): Input tensor (#batch, time, idim).
  108. x_mask (torch.Tensor): Input mask (#batch, 1, time).
  109. Returns:
  110. torch.Tensor: Subsampled tensor (#batch, time', odim),
  111. where time' = time // 4.
  112. torch.Tensor: Subsampled mask (#batch, 1, time'),
  113. where time' = time // 4.
  114. """
  115. x = x.transpose(1, 2)
  116. x = self.pad_fn(x)
  117. x = x.transpose(1, 2)
  118. x = x.unsqueeze(1) # (b, c, t, f)
  119. x = self.conv(x)
  120. b, c, t, f = x.size()
  121. x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
  122. if x_mask is None:
  123. return x, None
  124. x_len = torch.sum(x_mask[:, 0, :], dim=-1)
  125. x_len = (x_len - 1) // 2 + 1
  126. x_len = (x_len - 1) // 2 + 1
  127. mask = sequence_mask(x_len, None, x_len.dtype, x[0].device)
  128. return x, mask[:, None, :]
  129. def __getitem__(self, key):
  130. """Get item.
  131. When reset_parameters() is called, if use_scaled_pos_enc is used,
  132. return the positioning encoding.
  133. """
  134. if key != -1:
  135. raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
  136. return self.out[key]
  137. class Conv2dSubsampling2(torch.nn.Module):
  138. """Convolutional 2D subsampling (to 1/2 length).
  139. Args:
  140. idim (int): Input dimension.
  141. odim (int): Output dimension.
  142. dropout_rate (float): Dropout rate.
  143. pos_enc (torch.nn.Module): Custom position encoding layer.
  144. """
  145. def __init__(self, idim, odim, dropout_rate, pos_enc=None):
  146. """Construct an Conv2dSubsampling2 object."""
  147. super(Conv2dSubsampling2, self).__init__()
  148. self.conv = torch.nn.Sequential(
  149. torch.nn.Conv2d(1, odim, 3, 2),
  150. torch.nn.ReLU(),
  151. torch.nn.Conv2d(odim, odim, 3, 1),
  152. torch.nn.ReLU(),
  153. )
  154. self.out = torch.nn.Sequential(
  155. torch.nn.Linear(odim * (((idim - 1) // 2 - 2)), odim),
  156. pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate),
  157. )
  158. def forward(self, x, x_mask):
  159. """Subsample x.
  160. Args:
  161. x (torch.Tensor): Input tensor (#batch, time, idim).
  162. x_mask (torch.Tensor): Input mask (#batch, 1, time).
  163. Returns:
  164. torch.Tensor: Subsampled tensor (#batch, time', odim),
  165. where time' = time // 2.
  166. torch.Tensor: Subsampled mask (#batch, 1, time'),
  167. where time' = time // 2.
  168. """
  169. x = x.unsqueeze(1) # (b, c, t, f)
  170. x = self.conv(x)
  171. b, c, t, f = x.size()
  172. x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
  173. if x_mask is None:
  174. return x, None
  175. return x, x_mask[:, :, :-2:2][:, :, :-2:1]
  176. def __getitem__(self, key):
  177. """Get item.
  178. When reset_parameters() is called, if use_scaled_pos_enc is used,
  179. return the positioning encoding.
  180. """
  181. if key != -1:
  182. raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
  183. return self.out[key]
  184. class Conv2dSubsampling6(torch.nn.Module):
  185. """Convolutional 2D subsampling (to 1/6 length).
  186. Args:
  187. idim (int): Input dimension.
  188. odim (int): Output dimension.
  189. dropout_rate (float): Dropout rate.
  190. pos_enc (torch.nn.Module): Custom position encoding layer.
  191. """
  192. def __init__(self, idim, odim, dropout_rate, pos_enc=None):
  193. """Construct an Conv2dSubsampling6 object."""
  194. super(Conv2dSubsampling6, self).__init__()
  195. self.conv = torch.nn.Sequential(
  196. torch.nn.Conv2d(1, odim, 3, 2),
  197. torch.nn.ReLU(),
  198. torch.nn.Conv2d(odim, odim, 5, 3),
  199. torch.nn.ReLU(),
  200. )
  201. self.out = torch.nn.Sequential(
  202. torch.nn.Linear(odim * (((idim - 1) // 2 - 2) // 3), odim),
  203. pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate),
  204. )
  205. def forward(self, x, x_mask):
  206. """Subsample x.
  207. Args:
  208. x (torch.Tensor): Input tensor (#batch, time, idim).
  209. x_mask (torch.Tensor): Input mask (#batch, 1, time).
  210. Returns:
  211. torch.Tensor: Subsampled tensor (#batch, time', odim),
  212. where time' = time // 6.
  213. torch.Tensor: Subsampled mask (#batch, 1, time'),
  214. where time' = time // 6.
  215. """
  216. x = x.unsqueeze(1) # (b, c, t, f)
  217. x = self.conv(x)
  218. b, c, t, f = x.size()
  219. x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
  220. if x_mask is None:
  221. return x, None
  222. return x, x_mask[:, :, :-2:2][:, :, :-4:3]
  223. class Conv2dSubsampling8(torch.nn.Module):
  224. """Convolutional 2D subsampling (to 1/8 length).
  225. Args:
  226. idim (int): Input dimension.
  227. odim (int): Output dimension.
  228. dropout_rate (float): Dropout rate.
  229. pos_enc (torch.nn.Module): Custom position encoding layer.
  230. """
  231. def __init__(self, idim, odim, dropout_rate, pos_enc=None):
  232. """Construct an Conv2dSubsampling8 object."""
  233. super(Conv2dSubsampling8, self).__init__()
  234. self.conv = torch.nn.Sequential(
  235. torch.nn.Conv2d(1, odim, 3, 2),
  236. torch.nn.ReLU(),
  237. torch.nn.Conv2d(odim, odim, 3, 2),
  238. torch.nn.ReLU(),
  239. torch.nn.Conv2d(odim, odim, 3, 2),
  240. torch.nn.ReLU(),
  241. )
  242. self.out = torch.nn.Sequential(
  243. torch.nn.Linear(odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2), odim),
  244. pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate),
  245. )
  246. def forward(self, x, x_mask):
  247. """Subsample x.
  248. Args:
  249. x (torch.Tensor): Input tensor (#batch, time, idim).
  250. x_mask (torch.Tensor): Input mask (#batch, 1, time).
  251. Returns:
  252. torch.Tensor: Subsampled tensor (#batch, time', odim),
  253. where time' = time // 8.
  254. torch.Tensor: Subsampled mask (#batch, 1, time'),
  255. where time' = time // 8.
  256. """
  257. x = x.unsqueeze(1) # (b, c, t, f)
  258. x = self.conv(x)
  259. b, c, t, f = x.size()
  260. x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
  261. if x_mask is None:
  262. return x, None
  263. return x, x_mask[:, :, :-2:2][:, :, :-2:2][:, :, :-2:2]
  264. class Conv1dSubsampling(torch.nn.Module):
  265. """Convolutional 1D subsampling (to 1/2 length).
  266. Args:
  267. idim (int): Input dimension.
  268. odim (int): Output dimension.
  269. dropout_rate (float): Dropout rate.
  270. pos_enc (torch.nn.Module): Custom position encoding layer.
  271. """
  272. def __init__(self, idim, odim, kernel_size, stride, pad,
  273. tf2torch_tensor_name_prefix_torch: str = "stride_conv",
  274. tf2torch_tensor_name_prefix_tf: str = "seq2seq/proj_encoder/downsampling",
  275. ):
  276. super(Conv1dSubsampling, self).__init__()
  277. self.conv = torch.nn.Conv1d(idim, odim, kernel_size, stride)
  278. self.pad_fn = torch.nn.ConstantPad1d(pad, 0.0)
  279. self.stride = stride
  280. self.odim = odim
  281. self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
  282. self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
  283. def output_size(self) -> int:
  284. return self.odim
  285. def forward(self, x, x_len):
  286. """Subsample x.
  287. """
  288. x = x.transpose(1, 2) # (b, d ,t)
  289. x = self.pad_fn(x)
  290. x = F.relu(self.conv(x))
  291. x = x.transpose(1, 2) # (b, t ,d)
  292. if x_len is None:
  293. return x, None
  294. x_len = (x_len - 1) // self.stride + 1
  295. return x, x_len
  296. def gen_tf2torch_map_dict(self):
  297. tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
  298. tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
  299. map_dict_local = {
  300. ## predictor
  301. "{}.conv.weight".format(tensor_name_prefix_torch):
  302. {"name": "{}/conv1d/kernel".format(tensor_name_prefix_tf),
  303. "squeeze": None,
  304. "transpose": (2, 1, 0),
  305. }, # (256,256,3),(3,256,256)
  306. "{}.conv.bias".format(tensor_name_prefix_torch):
  307. {"name": "{}/conv1d/bias".format(tensor_name_prefix_tf),
  308. "squeeze": None,
  309. "transpose": None,
  310. }, # (256,),(256,)
  311. }
  312. return map_dict_local
  313. def convert_tf2torch(self,
  314. var_dict_tf,
  315. var_dict_torch,
  316. ):
  317. map_dict = self.gen_tf2torch_map_dict()
  318. var_dict_torch_update = dict()
  319. for name in sorted(var_dict_torch.keys(), reverse=False):
  320. names = name.split('.')
  321. if names[0] == self.tf2torch_tensor_name_prefix_torch:
  322. name_tf = map_dict[name]["name"]
  323. data_tf = var_dict_tf[name_tf]
  324. if map_dict[name]["squeeze"] is not None:
  325. data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
  326. if map_dict[name]["transpose"] is not None:
  327. data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
  328. data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
  329. var_dict_torch_update[name] = data_tf
  330. logging.info(
  331. "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
  332. var_dict_tf[name_tf].shape))
  333. return var_dict_torch_update