subsampling.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. # Copyright 2019 Shigeki Karita
  4. # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
  5. """Subsampling layer definition."""
  6. import numpy as np
  7. import torch
  8. import torch.nn.functional as F
  9. from funasr.modules.embedding import PositionalEncoding
  10. import logging
  11. from funasr.modules.streaming_utils.utils import sequence_mask
  12. from funasr.modules.nets_utils import sub_factor_to_params, pad_to_len
  13. from typing import Optional, Tuple, Union
  14. import math
  15. class TooShortUttError(Exception):
  16. """Raised when the utt is too short for subsampling.
  17. Args:
  18. message (str): Message for error catch
  19. actual_size (int): the short size that cannot pass the subsampling
  20. limit (int): the limit size for subsampling
  21. """
  22. def __init__(self, message, actual_size, limit):
  23. """Construct a TooShortUttError for error handler."""
  24. super().__init__(message)
  25. self.actual_size = actual_size
  26. self.limit = limit
  27. def check_short_utt(ins, size):
  28. """Check if the utterance is too short for subsampling."""
  29. if isinstance(ins, Conv2dSubsampling2) and size < 3:
  30. return True, 3
  31. if isinstance(ins, Conv2dSubsampling) and size < 7:
  32. return True, 7
  33. if isinstance(ins, Conv2dSubsampling6) and size < 11:
  34. return True, 11
  35. if isinstance(ins, Conv2dSubsampling8) and size < 15:
  36. return True, 15
  37. return False, -1
  38. class Conv2dSubsampling(torch.nn.Module):
  39. """Convolutional 2D subsampling (to 1/4 length).
  40. Args:
  41. idim (int): Input dimension.
  42. odim (int): Output dimension.
  43. dropout_rate (float): Dropout rate.
  44. pos_enc (torch.nn.Module): Custom position encoding layer.
  45. """
  46. def __init__(self, idim, odim, dropout_rate, pos_enc=None):
  47. """Construct an Conv2dSubsampling object."""
  48. super(Conv2dSubsampling, self).__init__()
  49. self.conv = torch.nn.Sequential(
  50. torch.nn.Conv2d(1, odim, 3, 2),
  51. torch.nn.ReLU(),
  52. torch.nn.Conv2d(odim, odim, 3, 2),
  53. torch.nn.ReLU(),
  54. )
  55. self.out = torch.nn.Sequential(
  56. torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim),
  57. pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate),
  58. )
  59. def forward(self, x, x_mask):
  60. """Subsample x.
  61. Args:
  62. x (torch.Tensor): Input tensor (#batch, time, idim).
  63. x_mask (torch.Tensor): Input mask (#batch, 1, time).
  64. Returns:
  65. torch.Tensor: Subsampled tensor (#batch, time', odim),
  66. where time' = time // 4.
  67. torch.Tensor: Subsampled mask (#batch, 1, time'),
  68. where time' = time // 4.
  69. """
  70. x = x.unsqueeze(1) # (b, c, t, f)
  71. x = self.conv(x)
  72. b, c, t, f = x.size()
  73. x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
  74. if x_mask is None:
  75. return x, None
  76. return x, x_mask[:, :, :-2:2][:, :, :-2:2]
  77. def __getitem__(self, key):
  78. """Get item.
  79. When reset_parameters() is called, if use_scaled_pos_enc is used,
  80. return the positioning encoding.
  81. """
  82. if key != -1:
  83. raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
  84. return self.out[key]
  85. class Conv2dSubsamplingPad(torch.nn.Module):
  86. """Convolutional 2D subsampling (to 1/4 length).
  87. Args:
  88. idim (int): Input dimension.
  89. odim (int): Output dimension.
  90. dropout_rate (float): Dropout rate.
  91. pos_enc (torch.nn.Module): Custom position encoding layer.
  92. """
  93. def __init__(self, idim, odim, dropout_rate, pos_enc=None):
  94. """Construct an Conv2dSubsampling object."""
  95. super(Conv2dSubsamplingPad, self).__init__()
  96. self.conv = torch.nn.Sequential(
  97. torch.nn.Conv2d(1, odim, 3, 2, padding=(0, 0)),
  98. torch.nn.ReLU(),
  99. torch.nn.Conv2d(odim, odim, 3, 2, padding=(0, 0)),
  100. torch.nn.ReLU(),
  101. )
  102. self.out = torch.nn.Sequential(
  103. torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim),
  104. pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate),
  105. )
  106. self.pad_fn = torch.nn.ConstantPad1d((0, 4), 0.0)
  107. def forward(self, x, x_mask):
  108. """Subsample x.
  109. Args:
  110. x (torch.Tensor): Input tensor (#batch, time, idim).
  111. x_mask (torch.Tensor): Input mask (#batch, 1, time).
  112. Returns:
  113. torch.Tensor: Subsampled tensor (#batch, time', odim),
  114. where time' = time // 4.
  115. torch.Tensor: Subsampled mask (#batch, 1, time'),
  116. where time' = time // 4.
  117. """
  118. x = x.transpose(1, 2)
  119. x = self.pad_fn(x)
  120. x = x.transpose(1, 2)
  121. x = x.unsqueeze(1) # (b, c, t, f)
  122. x = self.conv(x)
  123. b, c, t, f = x.size()
  124. x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
  125. if x_mask is None:
  126. return x, None
  127. x_len = torch.sum(x_mask[:, 0, :], dim=-1)
  128. x_len = (x_len - 1) // 2 + 1
  129. x_len = (x_len - 1) // 2 + 1
  130. mask = sequence_mask(x_len, None, x_len.dtype, x[0].device)
  131. return x, mask[:, None, :]
  132. def __getitem__(self, key):
  133. """Get item.
  134. When reset_parameters() is called, if use_scaled_pos_enc is used,
  135. return the positioning encoding.
  136. """
  137. if key != -1:
  138. raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
  139. return self.out[key]
  140. class Conv2dSubsampling2(torch.nn.Module):
  141. """Convolutional 2D subsampling (to 1/2 length).
  142. Args:
  143. idim (int): Input dimension.
  144. odim (int): Output dimension.
  145. dropout_rate (float): Dropout rate.
  146. pos_enc (torch.nn.Module): Custom position encoding layer.
  147. """
  148. def __init__(self, idim, odim, dropout_rate, pos_enc=None):
  149. """Construct an Conv2dSubsampling2 object."""
  150. super(Conv2dSubsampling2, self).__init__()
  151. self.conv = torch.nn.Sequential(
  152. torch.nn.Conv2d(1, odim, 3, 2),
  153. torch.nn.ReLU(),
  154. torch.nn.Conv2d(odim, odim, 3, 1),
  155. torch.nn.ReLU(),
  156. )
  157. self.out = torch.nn.Sequential(
  158. torch.nn.Linear(odim * (((idim - 1) // 2 - 2)), odim),
  159. pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate),
  160. )
  161. def forward(self, x, x_mask):
  162. """Subsample x.
  163. Args:
  164. x (torch.Tensor): Input tensor (#batch, time, idim).
  165. x_mask (torch.Tensor): Input mask (#batch, 1, time).
  166. Returns:
  167. torch.Tensor: Subsampled tensor (#batch, time', odim),
  168. where time' = time // 2.
  169. torch.Tensor: Subsampled mask (#batch, 1, time'),
  170. where time' = time // 2.
  171. """
  172. x = x.unsqueeze(1) # (b, c, t, f)
  173. x = self.conv(x)
  174. b, c, t, f = x.size()
  175. x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
  176. if x_mask is None:
  177. return x, None
  178. return x, x_mask[:, :, :-2:2][:, :, :-2:1]
  179. def __getitem__(self, key):
  180. """Get item.
  181. When reset_parameters() is called, if use_scaled_pos_enc is used,
  182. return the positioning encoding.
  183. """
  184. if key != -1:
  185. raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
  186. return self.out[key]
  187. class Conv2dSubsampling6(torch.nn.Module):
  188. """Convolutional 2D subsampling (to 1/6 length).
  189. Args:
  190. idim (int): Input dimension.
  191. odim (int): Output dimension.
  192. dropout_rate (float): Dropout rate.
  193. pos_enc (torch.nn.Module): Custom position encoding layer.
  194. """
  195. def __init__(self, idim, odim, dropout_rate, pos_enc=None):
  196. """Construct an Conv2dSubsampling6 object."""
  197. super(Conv2dSubsampling6, self).__init__()
  198. self.conv = torch.nn.Sequential(
  199. torch.nn.Conv2d(1, odim, 3, 2),
  200. torch.nn.ReLU(),
  201. torch.nn.Conv2d(odim, odim, 5, 3),
  202. torch.nn.ReLU(),
  203. )
  204. self.out = torch.nn.Sequential(
  205. torch.nn.Linear(odim * (((idim - 1) // 2 - 2) // 3), odim),
  206. pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate),
  207. )
  208. def forward(self, x, x_mask):
  209. """Subsample x.
  210. Args:
  211. x (torch.Tensor): Input tensor (#batch, time, idim).
  212. x_mask (torch.Tensor): Input mask (#batch, 1, time).
  213. Returns:
  214. torch.Tensor: Subsampled tensor (#batch, time', odim),
  215. where time' = time // 6.
  216. torch.Tensor: Subsampled mask (#batch, 1, time'),
  217. where time' = time // 6.
  218. """
  219. x = x.unsqueeze(1) # (b, c, t, f)
  220. x = self.conv(x)
  221. b, c, t, f = x.size()
  222. x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
  223. if x_mask is None:
  224. return x, None
  225. return x, x_mask[:, :, :-2:2][:, :, :-4:3]
  226. class Conv2dSubsampling8(torch.nn.Module):
  227. """Convolutional 2D subsampling (to 1/8 length).
  228. Args:
  229. idim (int): Input dimension.
  230. odim (int): Output dimension.
  231. dropout_rate (float): Dropout rate.
  232. pos_enc (torch.nn.Module): Custom position encoding layer.
  233. """
  234. def __init__(self, idim, odim, dropout_rate, pos_enc=None):
  235. """Construct an Conv2dSubsampling8 object."""
  236. super(Conv2dSubsampling8, self).__init__()
  237. self.conv = torch.nn.Sequential(
  238. torch.nn.Conv2d(1, odim, 3, 2),
  239. torch.nn.ReLU(),
  240. torch.nn.Conv2d(odim, odim, 3, 2),
  241. torch.nn.ReLU(),
  242. torch.nn.Conv2d(odim, odim, 3, 2),
  243. torch.nn.ReLU(),
  244. )
  245. self.out = torch.nn.Sequential(
  246. torch.nn.Linear(odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2), odim),
  247. pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate),
  248. )
  249. def forward(self, x, x_mask):
  250. """Subsample x.
  251. Args:
  252. x (torch.Tensor): Input tensor (#batch, time, idim).
  253. x_mask (torch.Tensor): Input mask (#batch, 1, time).
  254. Returns:
  255. torch.Tensor: Subsampled tensor (#batch, time', odim),
  256. where time' = time // 8.
  257. torch.Tensor: Subsampled mask (#batch, 1, time'),
  258. where time' = time // 8.
  259. """
  260. x = x.unsqueeze(1) # (b, c, t, f)
  261. x = self.conv(x)
  262. b, c, t, f = x.size()
  263. x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
  264. if x_mask is None:
  265. return x, None
  266. return x, x_mask[:, :, :-2:2][:, :, :-2:2][:, :, :-2:2]
  267. class Conv1dSubsampling(torch.nn.Module):
  268. """Convolutional 1D subsampling (to 1/2 length).
  269. Args:
  270. idim (int): Input dimension.
  271. odim (int): Output dimension.
  272. dropout_rate (float): Dropout rate.
  273. pos_enc (torch.nn.Module): Custom position encoding layer.
  274. """
  275. def __init__(self, idim, odim, kernel_size, stride, pad,
  276. tf2torch_tensor_name_prefix_torch: str = "stride_conv",
  277. tf2torch_tensor_name_prefix_tf: str = "seq2seq/proj_encoder/downsampling",
  278. ):
  279. super(Conv1dSubsampling, self).__init__()
  280. self.conv = torch.nn.Conv1d(idim, odim, kernel_size, stride)
  281. self.pad_fn = torch.nn.ConstantPad1d(pad, 0.0)
  282. self.stride = stride
  283. self.odim = odim
  284. self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
  285. self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
  286. def output_size(self) -> int:
  287. return self.odim
  288. def forward(self, x, x_len):
  289. """Subsample x.
  290. """
  291. x = x.transpose(1, 2) # (b, d ,t)
  292. x = self.pad_fn(x)
  293. x = F.relu(self.conv(x))
  294. x = x.transpose(1, 2) # (b, t ,d)
  295. if x_len is None:
  296. return x, None
  297. x_len = (x_len - 1) // self.stride + 1
  298. return x, x_len
  299. def gen_tf2torch_map_dict(self):
  300. tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
  301. tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
  302. map_dict_local = {
  303. ## predictor
  304. "{}.conv.weight".format(tensor_name_prefix_torch):
  305. {"name": "{}/conv1d/kernel".format(tensor_name_prefix_tf),
  306. "squeeze": None,
  307. "transpose": (2, 1, 0),
  308. }, # (256,256,3),(3,256,256)
  309. "{}.conv.bias".format(tensor_name_prefix_torch):
  310. {"name": "{}/conv1d/bias".format(tensor_name_prefix_tf),
  311. "squeeze": None,
  312. "transpose": None,
  313. }, # (256,),(256,)
  314. }
  315. return map_dict_local
  316. def convert_tf2torch(self,
  317. var_dict_tf,
  318. var_dict_torch,
  319. ):
  320. map_dict = self.gen_tf2torch_map_dict()
  321. var_dict_torch_update = dict()
  322. for name in sorted(var_dict_torch.keys(), reverse=False):
  323. names = name.split('.')
  324. if names[0] == self.tf2torch_tensor_name_prefix_torch:
  325. name_tf = map_dict[name]["name"]
  326. data_tf = var_dict_tf[name_tf]
  327. if map_dict[name]["squeeze"] is not None:
  328. data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
  329. if map_dict[name]["transpose"] is not None:
  330. data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
  331. data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
  332. var_dict_torch_update[name] = data_tf
  333. logging.info(
  334. "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
  335. var_dict_tf[name_tf].shape))
  336. return var_dict_torch_update
  337. class StreamingConvInput(torch.nn.Module):
  338. """Streaming ConvInput module definition.
  339. Args:
  340. input_size: Input size.
  341. conv_size: Convolution size.
  342. subsampling_factor: Subsampling factor.
  343. vgg_like: Whether to use a VGG-like network.
  344. output_size: Block output dimension.
  345. """
  346. def __init__(
  347. self,
  348. input_size: int,
  349. conv_size: Union[int, Tuple],
  350. subsampling_factor: int = 4,
  351. vgg_like: bool = True,
  352. conv_kernel_size: int = 3,
  353. output_size: Optional[int] = None,
  354. ) -> None:
  355. """Construct a ConvInput object."""
  356. super().__init__()
  357. if vgg_like:
  358. if subsampling_factor == 1:
  359. conv_size1, conv_size2 = conv_size
  360. self.conv = torch.nn.Sequential(
  361. torch.nn.Conv2d(1, conv_size1, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
  362. torch.nn.ReLU(),
  363. torch.nn.Conv2d(conv_size1, conv_size1, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
  364. torch.nn.ReLU(),
  365. torch.nn.MaxPool2d((1, 2)),
  366. torch.nn.Conv2d(conv_size1, conv_size2, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
  367. torch.nn.ReLU(),
  368. torch.nn.Conv2d(conv_size2, conv_size2, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
  369. torch.nn.ReLU(),
  370. torch.nn.MaxPool2d((1, 2)),
  371. )
  372. output_proj = conv_size2 * ((input_size // 2) // 2)
  373. self.subsampling_factor = 1
  374. self.stride_1 = 1
  375. self.create_new_mask = self.create_new_vgg_mask
  376. else:
  377. conv_size1, conv_size2 = conv_size
  378. kernel_1 = int(subsampling_factor / 2)
  379. self.conv = torch.nn.Sequential(
  380. torch.nn.Conv2d(1, conv_size1, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
  381. torch.nn.ReLU(),
  382. torch.nn.Conv2d(conv_size1, conv_size1, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
  383. torch.nn.ReLU(),
  384. torch.nn.MaxPool2d((kernel_1, 2)),
  385. torch.nn.Conv2d(conv_size1, conv_size2, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
  386. torch.nn.ReLU(),
  387. torch.nn.Conv2d(conv_size2, conv_size2, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
  388. torch.nn.ReLU(),
  389. torch.nn.MaxPool2d((2, 2)),
  390. )
  391. output_proj = conv_size2 * ((input_size // 2) // 2)
  392. self.subsampling_factor = subsampling_factor
  393. self.create_new_mask = self.create_new_vgg_mask
  394. self.stride_1 = kernel_1
  395. else:
  396. if subsampling_factor == 1:
  397. self.conv = torch.nn.Sequential(
  398. torch.nn.Conv2d(1, conv_size, 3, [1,2], [1,0]),
  399. torch.nn.ReLU(),
  400. torch.nn.Conv2d(conv_size, conv_size, conv_kernel_size, [1,2], [1,0]),
  401. torch.nn.ReLU(),
  402. )
  403. output_proj = conv_size * (((input_size - 1) // 2 - 1) // 2)
  404. self.subsampling_factor = subsampling_factor
  405. self.kernel_2 = conv_kernel_size
  406. self.stride_2 = 1
  407. self.create_new_mask = self.create_new_conv2d_mask
  408. else:
  409. kernel_2, stride_2, conv_2_output_size = sub_factor_to_params(
  410. subsampling_factor,
  411. input_size,
  412. )
  413. self.conv = torch.nn.Sequential(
  414. torch.nn.Conv2d(1, conv_size, 3, 2, [1,0]),
  415. torch.nn.ReLU(),
  416. torch.nn.Conv2d(conv_size, conv_size, kernel_2, stride_2, [(kernel_2-1)//2, 0]),
  417. torch.nn.ReLU(),
  418. )
  419. output_proj = conv_size * conv_2_output_size
  420. self.subsampling_factor = subsampling_factor
  421. self.kernel_2 = kernel_2
  422. self.stride_2 = stride_2
  423. self.create_new_mask = self.create_new_conv2d_mask
  424. self.vgg_like = vgg_like
  425. self.min_frame_length = 7
  426. if output_size is not None:
  427. self.output = torch.nn.Linear(output_proj, output_size)
  428. self.output_size = output_size
  429. else:
  430. self.output = None
  431. self.output_size = output_proj
  432. def forward(
  433. self, x: torch.Tensor, mask: Optional[torch.Tensor], chunk_size: Optional[torch.Tensor]
  434. ) -> Tuple[torch.Tensor, torch.Tensor]:
  435. """Encode input sequences.
  436. Args:
  437. x: ConvInput input sequences. (B, T, D_feats)
  438. mask: Mask of input sequences. (B, 1, T)
  439. Returns:
  440. x: ConvInput output sequences. (B, sub(T), D_out)
  441. mask: Mask of output sequences. (B, 1, sub(T))
  442. """
  443. if mask is not None:
  444. mask = self.create_new_mask(mask)
  445. olens = max(mask.eq(0).sum(1))
  446. b, t, f = x.size()
  447. x = x.unsqueeze(1) # (b. 1. t. f)
  448. if chunk_size is not None:
  449. max_input_length = int(
  450. chunk_size * self.subsampling_factor * (math.ceil(float(t) / (chunk_size * self.subsampling_factor) ))
  451. )
  452. x = map(lambda inputs: pad_to_len(inputs, max_input_length, 1), x)
  453. x = list(x)
  454. x = torch.stack(x, dim=0)
  455. N_chunks = max_input_length // ( chunk_size * self.subsampling_factor)
  456. x = x.view(b * N_chunks, 1, chunk_size * self.subsampling_factor, f)
  457. x = self.conv(x)
  458. _, c, _, f = x.size()
  459. if chunk_size is not None:
  460. x = x.transpose(1, 2).contiguous().view(b, -1, c * f)[:,:olens,:]
  461. else:
  462. x = x.transpose(1, 2).contiguous().view(b, -1, c * f)
  463. if self.output is not None:
  464. x = self.output(x)
  465. return x, mask[:,:olens][:,:x.size(1)]
  466. def create_new_vgg_mask(self, mask: torch.Tensor) -> torch.Tensor:
  467. """Create a new mask for VGG output sequences.
  468. Args:
  469. mask: Mask of input sequences. (B, T)
  470. Returns:
  471. mask: Mask of output sequences. (B, sub(T))
  472. """
  473. if self.subsampling_factor > 1:
  474. vgg1_t_len = mask.size(1) - (mask.size(1) % (self.subsampling_factor // 2 ))
  475. mask = mask[:, :vgg1_t_len][:, ::self.subsampling_factor // 2]
  476. vgg2_t_len = mask.size(1) - (mask.size(1) % 2)
  477. mask = mask[:, :vgg2_t_len][:, ::2]
  478. else:
  479. mask = mask
  480. return mask
  481. def create_new_conv2d_mask(self, mask: torch.Tensor) -> torch.Tensor:
  482. """Create new conformer mask for Conv2d output sequences.
  483. Args:
  484. mask: Mask of input sequences. (B, T)
  485. Returns:
  486. mask: Mask of output sequences. (B, sub(T))
  487. """
  488. if self.subsampling_factor > 1:
  489. return mask[:, ::2][:, ::self.stride_2]
  490. else:
  491. return mask
  492. def get_size_before_subsampling(self, size: int) -> int:
  493. """Return the original size before subsampling for a given size.
  494. Args:
  495. size: Number of frames after subsampling.
  496. Returns:
  497. : Number of frames before subsampling.
  498. """
  499. return size * self.subsampling_factor