nets_utils.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768
  1. # -*- coding: utf-8 -*-
  2. """Network related utility tools."""
  3. import logging
  4. from typing import Dict, List, Tuple
  5. import numpy as np
  6. import torch
  7. def to_device(m, x):
  8. """Send tensor into the device of the module.
  9. Args:
  10. m (torch.nn.Module): Torch module.
  11. x (Tensor): Torch tensor.
  12. Returns:
  13. Tensor: Torch tensor located in the same place as torch module.
  14. """
  15. if isinstance(m, torch.nn.Module):
  16. device = next(m.parameters()).device
  17. elif isinstance(m, torch.Tensor):
  18. device = m.device
  19. else:
  20. raise TypeError(
  21. "Expected torch.nn.Module or torch.tensor, " f"bot got: {type(m)}"
  22. )
  23. return x.to(device)
  24. def pad_list(xs, pad_value):
  25. """Perform padding for the list of tensors.
  26. Args:
  27. xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
  28. pad_value (float): Value for padding.
  29. Returns:
  30. Tensor: Padded tensor (B, Tmax, `*`).
  31. Examples:
  32. >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
  33. >>> x
  34. [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
  35. >>> pad_list(x, 0)
  36. tensor([[1., 1., 1., 1.],
  37. [1., 1., 0., 0.],
  38. [1., 0., 0., 0.]])
  39. """
  40. n_batch = len(xs)
  41. max_len = max(x.size(0) for x in xs)
  42. pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value)
  43. for i in range(n_batch):
  44. pad[i, : xs[i].size(0)] = xs[i]
  45. return pad
  46. def pad_list_all_dim(xs, pad_value):
  47. """Perform padding for the list of tensors.
  48. Args:
  49. xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
  50. pad_value (float): Value for padding.
  51. Returns:
  52. Tensor: Padded tensor (B, Tmax, `*`).
  53. Examples:
  54. >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
  55. >>> x
  56. [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
  57. >>> pad_list(x, 0)
  58. tensor([[1., 1., 1., 1.],
  59. [1., 1., 0., 0.],
  60. [1., 0., 0., 0.]])
  61. """
  62. n_batch = len(xs)
  63. num_dim = len(xs[0].shape)
  64. max_len_all_dim = []
  65. for i in range(num_dim):
  66. max_len_all_dim.append(max(x.size(i) for x in xs))
  67. pad = xs[0].new(n_batch, *max_len_all_dim).fill_(pad_value)
  68. for i in range(n_batch):
  69. if num_dim == 1:
  70. pad[i, : xs[i].size(0)] = xs[i]
  71. elif num_dim == 2:
  72. pad[i, : xs[i].size(0), : xs[i].size(1)] = xs[i]
  73. elif num_dim == 3:
  74. pad[i, : xs[i].size(0), : xs[i].size(1), : xs[i].size(2)] = xs[i]
  75. else:
  76. raise ValueError(
  77. "pad_list_all_dim only support 1-D, 2-D and 3-D tensors, not {}-D.".format(num_dim)
  78. )
  79. return pad
  80. def make_pad_mask(lengths, xs=None, length_dim=-1, maxlen=None):
  81. """Make mask tensor containing indices of padded part.
  82. Args:
  83. lengths (LongTensor or List): Batch of lengths (B,).
  84. xs (Tensor, optional): The reference tensor.
  85. If set, masks will be the same shape as this tensor.
  86. length_dim (int, optional): Dimension indicator of the above tensor.
  87. See the example.
  88. Returns:
  89. Tensor: Mask tensor containing indices of padded part.
  90. dtype=torch.uint8 in PyTorch 1.2-
  91. dtype=torch.bool in PyTorch 1.2+ (including 1.2)
  92. Examples:
  93. With only lengths.
  94. >>> lengths = [5, 3, 2]
  95. >>> make_pad_mask(lengths)
  96. masks = [[0, 0, 0, 0 ,0],
  97. [0, 0, 0, 1, 1],
  98. [0, 0, 1, 1, 1]]
  99. With the reference tensor.
  100. >>> xs = torch.zeros((3, 2, 4))
  101. >>> make_pad_mask(lengths, xs)
  102. tensor([[[0, 0, 0, 0],
  103. [0, 0, 0, 0]],
  104. [[0, 0, 0, 1],
  105. [0, 0, 0, 1]],
  106. [[0, 0, 1, 1],
  107. [0, 0, 1, 1]]], dtype=torch.uint8)
  108. >>> xs = torch.zeros((3, 2, 6))
  109. >>> make_pad_mask(lengths, xs)
  110. tensor([[[0, 0, 0, 0, 0, 1],
  111. [0, 0, 0, 0, 0, 1]],
  112. [[0, 0, 0, 1, 1, 1],
  113. [0, 0, 0, 1, 1, 1]],
  114. [[0, 0, 1, 1, 1, 1],
  115. [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)
  116. With the reference tensor and dimension indicator.
  117. >>> xs = torch.zeros((3, 6, 6))
  118. >>> make_pad_mask(lengths, xs, 1)
  119. tensor([[[0, 0, 0, 0, 0, 0],
  120. [0, 0, 0, 0, 0, 0],
  121. [0, 0, 0, 0, 0, 0],
  122. [0, 0, 0, 0, 0, 0],
  123. [0, 0, 0, 0, 0, 0],
  124. [1, 1, 1, 1, 1, 1]],
  125. [[0, 0, 0, 0, 0, 0],
  126. [0, 0, 0, 0, 0, 0],
  127. [0, 0, 0, 0, 0, 0],
  128. [1, 1, 1, 1, 1, 1],
  129. [1, 1, 1, 1, 1, 1],
  130. [1, 1, 1, 1, 1, 1]],
  131. [[0, 0, 0, 0, 0, 0],
  132. [0, 0, 0, 0, 0, 0],
  133. [1, 1, 1, 1, 1, 1],
  134. [1, 1, 1, 1, 1, 1],
  135. [1, 1, 1, 1, 1, 1],
  136. [1, 1, 1, 1, 1, 1]]], dtype=torch.uint8)
  137. >>> make_pad_mask(lengths, xs, 2)
  138. tensor([[[0, 0, 0, 0, 0, 1],
  139. [0, 0, 0, 0, 0, 1],
  140. [0, 0, 0, 0, 0, 1],
  141. [0, 0, 0, 0, 0, 1],
  142. [0, 0, 0, 0, 0, 1],
  143. [0, 0, 0, 0, 0, 1]],
  144. [[0, 0, 0, 1, 1, 1],
  145. [0, 0, 0, 1, 1, 1],
  146. [0, 0, 0, 1, 1, 1],
  147. [0, 0, 0, 1, 1, 1],
  148. [0, 0, 0, 1, 1, 1],
  149. [0, 0, 0, 1, 1, 1]],
  150. [[0, 0, 1, 1, 1, 1],
  151. [0, 0, 1, 1, 1, 1],
  152. [0, 0, 1, 1, 1, 1],
  153. [0, 0, 1, 1, 1, 1],
  154. [0, 0, 1, 1, 1, 1],
  155. [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)
  156. """
  157. if length_dim == 0:
  158. raise ValueError("length_dim cannot be 0: {}".format(length_dim))
  159. if not isinstance(lengths, list):
  160. lengths = lengths.tolist()
  161. bs = int(len(lengths))
  162. if maxlen is None:
  163. if xs is None:
  164. maxlen = int(max(lengths))
  165. else:
  166. maxlen = xs.size(length_dim)
  167. else:
  168. assert xs is None
  169. assert maxlen >= int(max(lengths))
  170. seq_range = torch.arange(0, maxlen, dtype=torch.int64)
  171. seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen)
  172. seq_length_expand = seq_range_expand.new(lengths).unsqueeze(-1)
  173. mask = seq_range_expand >= seq_length_expand
  174. if xs is not None:
  175. assert xs.size(0) == bs, (xs.size(0), bs)
  176. if length_dim < 0:
  177. length_dim = xs.dim() + length_dim
  178. # ind = (:, None, ..., None, :, , None, ..., None)
  179. ind = tuple(
  180. slice(None) if i in (0, length_dim) else None for i in range(xs.dim())
  181. )
  182. mask = mask[ind].expand_as(xs).to(xs.device)
  183. return mask
  184. def make_non_pad_mask(lengths, xs=None, length_dim=-1):
  185. """Make mask tensor containing indices of non-padded part.
  186. Args:
  187. lengths (LongTensor or List): Batch of lengths (B,).
  188. xs (Tensor, optional): The reference tensor.
  189. If set, masks will be the same shape as this tensor.
  190. length_dim (int, optional): Dimension indicator of the above tensor.
  191. See the example.
  192. Returns:
  193. ByteTensor: mask tensor containing indices of padded part.
  194. dtype=torch.uint8 in PyTorch 1.2-
  195. dtype=torch.bool in PyTorch 1.2+ (including 1.2)
  196. Examples:
  197. With only lengths.
  198. >>> lengths = [5, 3, 2]
  199. >>> make_non_pad_mask(lengths)
  200. masks = [[1, 1, 1, 1 ,1],
  201. [1, 1, 1, 0, 0],
  202. [1, 1, 0, 0, 0]]
  203. With the reference tensor.
  204. >>> xs = torch.zeros((3, 2, 4))
  205. >>> make_non_pad_mask(lengths, xs)
  206. tensor([[[1, 1, 1, 1],
  207. [1, 1, 1, 1]],
  208. [[1, 1, 1, 0],
  209. [1, 1, 1, 0]],
  210. [[1, 1, 0, 0],
  211. [1, 1, 0, 0]]], dtype=torch.uint8)
  212. >>> xs = torch.zeros((3, 2, 6))
  213. >>> make_non_pad_mask(lengths, xs)
  214. tensor([[[1, 1, 1, 1, 1, 0],
  215. [1, 1, 1, 1, 1, 0]],
  216. [[1, 1, 1, 0, 0, 0],
  217. [1, 1, 1, 0, 0, 0]],
  218. [[1, 1, 0, 0, 0, 0],
  219. [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)
  220. With the reference tensor and dimension indicator.
  221. >>> xs = torch.zeros((3, 6, 6))
  222. >>> make_non_pad_mask(lengths, xs, 1)
  223. tensor([[[1, 1, 1, 1, 1, 1],
  224. [1, 1, 1, 1, 1, 1],
  225. [1, 1, 1, 1, 1, 1],
  226. [1, 1, 1, 1, 1, 1],
  227. [1, 1, 1, 1, 1, 1],
  228. [0, 0, 0, 0, 0, 0]],
  229. [[1, 1, 1, 1, 1, 1],
  230. [1, 1, 1, 1, 1, 1],
  231. [1, 1, 1, 1, 1, 1],
  232. [0, 0, 0, 0, 0, 0],
  233. [0, 0, 0, 0, 0, 0],
  234. [0, 0, 0, 0, 0, 0]],
  235. [[1, 1, 1, 1, 1, 1],
  236. [1, 1, 1, 1, 1, 1],
  237. [0, 0, 0, 0, 0, 0],
  238. [0, 0, 0, 0, 0, 0],
  239. [0, 0, 0, 0, 0, 0],
  240. [0, 0, 0, 0, 0, 0]]], dtype=torch.uint8)
  241. >>> make_non_pad_mask(lengths, xs, 2)
  242. tensor([[[1, 1, 1, 1, 1, 0],
  243. [1, 1, 1, 1, 1, 0],
  244. [1, 1, 1, 1, 1, 0],
  245. [1, 1, 1, 1, 1, 0],
  246. [1, 1, 1, 1, 1, 0],
  247. [1, 1, 1, 1, 1, 0]],
  248. [[1, 1, 1, 0, 0, 0],
  249. [1, 1, 1, 0, 0, 0],
  250. [1, 1, 1, 0, 0, 0],
  251. [1, 1, 1, 0, 0, 0],
  252. [1, 1, 1, 0, 0, 0],
  253. [1, 1, 1, 0, 0, 0]],
  254. [[1, 1, 0, 0, 0, 0],
  255. [1, 1, 0, 0, 0, 0],
  256. [1, 1, 0, 0, 0, 0],
  257. [1, 1, 0, 0, 0, 0],
  258. [1, 1, 0, 0, 0, 0],
  259. [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)
  260. """
  261. return ~make_pad_mask(lengths, xs, length_dim)
  262. def mask_by_length(xs, lengths, fill=0):
  263. """Mask tensor according to length.
  264. Args:
  265. xs (Tensor): Batch of input tensor (B, `*`).
  266. lengths (LongTensor or List): Batch of lengths (B,).
  267. fill (int or float): Value to fill masked part.
  268. Returns:
  269. Tensor: Batch of masked input tensor (B, `*`).
  270. Examples:
  271. >>> x = torch.arange(5).repeat(3, 1) + 1
  272. >>> x
  273. tensor([[1, 2, 3, 4, 5],
  274. [1, 2, 3, 4, 5],
  275. [1, 2, 3, 4, 5]])
  276. >>> lengths = [5, 3, 2]
  277. >>> mask_by_length(x, lengths)
  278. tensor([[1, 2, 3, 4, 5],
  279. [1, 2, 3, 0, 0],
  280. [1, 2, 0, 0, 0]])
  281. """
  282. assert xs.size(0) == len(lengths)
  283. ret = xs.data.new(*xs.size()).fill_(fill)
  284. for i, l in enumerate(lengths):
  285. ret[i, :l] = xs[i, :l]
  286. return ret
  287. def th_accuracy(pad_outputs, pad_targets, ignore_label):
  288. """Calculate accuracy.
  289. Args:
  290. pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
  291. pad_targets (LongTensor): Target label tensors (B, Lmax, D).
  292. ignore_label (int): Ignore label id.
  293. Returns:
  294. float: Accuracy value (0.0 - 1.0).
  295. """
  296. pad_pred = pad_outputs.view(
  297. pad_targets.size(0), pad_targets.size(1), pad_outputs.size(1)
  298. ).argmax(2)
  299. mask = pad_targets != ignore_label
  300. numerator = torch.sum(
  301. pad_pred.masked_select(mask) == pad_targets.masked_select(mask)
  302. )
  303. denominator = torch.sum(mask)
  304. return float(numerator) / float(denominator)
  305. def to_torch_tensor(x):
  306. """Change to torch.Tensor or ComplexTensor from numpy.ndarray.
  307. Args:
  308. x: Inputs. It should be one of numpy.ndarray, Tensor, ComplexTensor, and dict.
  309. Returns:
  310. Tensor or ComplexTensor: Type converted inputs.
  311. Examples:
  312. >>> xs = np.ones(3, dtype=np.float32)
  313. >>> xs = to_torch_tensor(xs)
  314. tensor([1., 1., 1.])
  315. >>> xs = torch.ones(3, 4, 5)
  316. >>> assert to_torch_tensor(xs) is xs
  317. >>> xs = {'real': xs, 'imag': xs}
  318. >>> to_torch_tensor(xs)
  319. ComplexTensor(
  320. Real:
  321. tensor([1., 1., 1.])
  322. Imag;
  323. tensor([1., 1., 1.])
  324. )
  325. """
  326. # If numpy, change to torch tensor
  327. if isinstance(x, np.ndarray):
  328. if x.dtype.kind == "c":
  329. # Dynamically importing because torch_complex requires python3
  330. from torch_complex.tensor import ComplexTensor
  331. return ComplexTensor(x)
  332. else:
  333. return torch.from_numpy(x)
  334. # If {'real': ..., 'imag': ...}, convert to ComplexTensor
  335. elif isinstance(x, dict):
  336. # Dynamically importing because torch_complex requires python3
  337. from torch_complex.tensor import ComplexTensor
  338. if "real" not in x or "imag" not in x:
  339. raise ValueError("has 'real' and 'imag' keys: {}".format(list(x)))
  340. # Relative importing because of using python3 syntax
  341. return ComplexTensor(x["real"], x["imag"])
  342. # If torch.Tensor, as it is
  343. elif isinstance(x, torch.Tensor):
  344. return x
  345. else:
  346. error = (
  347. "x must be numpy.ndarray, torch.Tensor or a dict like "
  348. "{{'real': torch.Tensor, 'imag': torch.Tensor}}, "
  349. "but got {}".format(type(x))
  350. )
  351. try:
  352. from torch_complex.tensor import ComplexTensor
  353. except Exception:
  354. # If PY2
  355. raise ValueError(error)
  356. else:
  357. # If PY3
  358. if isinstance(x, ComplexTensor):
  359. return x
  360. else:
  361. raise ValueError(error)
  362. def get_subsample(train_args, mode, arch):
  363. """Parse the subsampling factors from the args for the specified `mode` and `arch`.
  364. Args:
  365. train_args: argument Namespace containing options.
  366. mode: one of ('asr', 'mt', 'st')
  367. arch: one of ('rnn', 'rnn-t', 'rnn_mix', 'rnn_mulenc', 'transformer')
  368. Returns:
  369. np.ndarray / List[np.ndarray]: subsampling factors.
  370. """
  371. if arch == "transformer":
  372. return np.array([1])
  373. elif mode == "mt" and arch == "rnn":
  374. # +1 means input (+1) and layers outputs (train_args.elayer)
  375. subsample = np.ones(train_args.elayers + 1, dtype=np.int32)
  376. logging.warning("Subsampling is not performed for machine translation.")
  377. logging.info("subsample: " + " ".join([str(x) for x in subsample]))
  378. return subsample
  379. elif (
  380. (mode == "asr" and arch in ("rnn", "rnn-t"))
  381. or (mode == "mt" and arch == "rnn")
  382. or (mode == "st" and arch == "rnn")
  383. ):
  384. subsample = np.ones(train_args.elayers + 1, dtype=np.int32)
  385. if train_args.etype.endswith("p") and not train_args.etype.startswith("vgg"):
  386. ss = train_args.subsample.split("_")
  387. for j in range(min(train_args.elayers + 1, len(ss))):
  388. subsample[j] = int(ss[j])
  389. else:
  390. logging.warning(
  391. "Subsampling is not performed for vgg*. "
  392. "It is performed in max pooling layers at CNN."
  393. )
  394. logging.info("subsample: " + " ".join([str(x) for x in subsample]))
  395. return subsample
  396. elif mode == "asr" and arch == "rnn_mix":
  397. subsample = np.ones(
  398. train_args.elayers_sd + train_args.elayers + 1, dtype=np.int32
  399. )
  400. if train_args.etype.endswith("p") and not train_args.etype.startswith("vgg"):
  401. ss = train_args.subsample.split("_")
  402. for j in range(
  403. min(train_args.elayers_sd + train_args.elayers + 1, len(ss))
  404. ):
  405. subsample[j] = int(ss[j])
  406. else:
  407. logging.warning(
  408. "Subsampling is not performed for vgg*. "
  409. "It is performed in max pooling layers at CNN."
  410. )
  411. logging.info("subsample: " + " ".join([str(x) for x in subsample]))
  412. return subsample
  413. elif mode == "asr" and arch == "rnn_mulenc":
  414. subsample_list = []
  415. for idx in range(train_args.num_encs):
  416. subsample = np.ones(train_args.elayers[idx] + 1, dtype=np.int32)
  417. if train_args.etype[idx].endswith("p") and not train_args.etype[
  418. idx
  419. ].startswith("vgg"):
  420. ss = train_args.subsample[idx].split("_")
  421. for j in range(min(train_args.elayers[idx] + 1, len(ss))):
  422. subsample[j] = int(ss[j])
  423. else:
  424. logging.warning(
  425. "Encoder %d: Subsampling is not performed for vgg*. "
  426. "It is performed in max pooling layers at CNN.",
  427. idx + 1,
  428. )
  429. logging.info("subsample: " + " ".join([str(x) for x in subsample]))
  430. subsample_list.append(subsample)
  431. return subsample_list
  432. else:
  433. raise ValueError("Invalid options: mode={}, arch={}".format(mode, arch))
  434. def rename_state_dict(
  435. old_prefix: str, new_prefix: str, state_dict: Dict[str, torch.Tensor]
  436. ):
  437. """Replace keys of old prefix with new prefix in state dict."""
  438. # need this list not to break the dict iterator
  439. old_keys = [k for k in state_dict if k.startswith(old_prefix)]
  440. if len(old_keys) > 0:
  441. logging.warning(f"Rename: {old_prefix} -> {new_prefix}")
  442. for k in old_keys:
  443. v = state_dict.pop(k)
  444. new_k = k.replace(old_prefix, new_prefix)
  445. state_dict[new_k] = v
  446. class Swish(torch.nn.Module):
  447. """Swish activation definition.
  448. Swish(x) = (beta * x) * sigmoid(x)
  449. where beta = 1 defines standard Swish activation.
  450. References:
  451. https://arxiv.org/abs/2108.12943 / https://arxiv.org/abs/1710.05941v1.
  452. E-swish variant: https://arxiv.org/abs/1801.07145.
  453. Args:
  454. beta: Beta parameter for E-Swish.
  455. (beta >= 1. If beta < 1, use standard Swish).
  456. use_builtin: Whether to use PyTorch function if available.
  457. """
  458. def __init__(self, beta: float = 1.0, use_builtin: bool = False) -> None:
  459. super().__init__()
  460. self.beta = beta
  461. if beta > 1:
  462. self.swish = lambda x: (self.beta * x) * torch.sigmoid(x)
  463. else:
  464. if use_builtin:
  465. self.swish = torch.nn.SiLU()
  466. else:
  467. self.swish = lambda x: x * torch.sigmoid(x)
  468. def forward(self, x: torch.Tensor) -> torch.Tensor:
  469. """Forward computation."""
  470. return self.swish(x)
  471. def get_activation(act):
  472. """Return activation function."""
  473. activation_funcs = {
  474. "hardtanh": torch.nn.Hardtanh,
  475. "tanh": torch.nn.Tanh,
  476. "relu": torch.nn.ReLU,
  477. "selu": torch.nn.SELU,
  478. "swish": Swish,
  479. }
  480. return activation_funcs[act]()
  481. class TooShortUttError(Exception):
  482. """Raised when the utt is too short for subsampling.
  483. Args:
  484. message: Error message to display.
  485. actual_size: The size that cannot pass the subsampling.
  486. limit: The size limit for subsampling.
  487. """
  488. def __init__(self, message: str, actual_size: int, limit: int) -> None:
  489. """Construct a TooShortUttError module."""
  490. super().__init__(message)
  491. self.actual_size = actual_size
  492. self.limit = limit
  493. def check_short_utt(sub_factor: int, size: int) -> Tuple[bool, int]:
  494. """Check if the input is too short for subsampling.
  495. Args:
  496. sub_factor: Subsampling factor for Conv2DSubsampling.
  497. size: Input size.
  498. Returns:
  499. : Whether an error should be sent.
  500. : Size limit for specified subsampling factor.
  501. """
  502. if sub_factor == 2 and size < 3:
  503. return True, 7
  504. elif sub_factor == 4 and size < 7:
  505. return True, 7
  506. elif sub_factor == 6 and size < 11:
  507. return True, 11
  508. return False, -1
  509. def sub_factor_to_params(sub_factor: int, input_size: int) -> Tuple[int, int, int]:
  510. """Get conv2D second layer parameters for given subsampling factor.
  511. Args:
  512. sub_factor: Subsampling factor (1/X).
  513. input_size: Input size.
  514. Returns:
  515. : Kernel size for second convolution.
  516. : Stride for second convolution.
  517. : Conv2DSubsampling output size.
  518. """
  519. if sub_factor == 2:
  520. return 3, 1, (((input_size - 1) // 2 - 2))
  521. elif sub_factor == 4:
  522. return 3, 2, (((input_size - 1) // 2 - 1) // 2)
  523. elif sub_factor == 6:
  524. return 5, 3, (((input_size - 1) // 2 - 2) // 3)
  525. else:
  526. raise ValueError(
  527. "subsampling_factor parameter should be set to either 2, 4 or 6."
  528. )
  529. def make_chunk_mask(
  530. size: int,
  531. chunk_size: int,
  532. left_chunk_size: int = 0,
  533. device: torch.device = None,
  534. ) -> torch.Tensor:
  535. """Create chunk mask for the subsequent steps (size, size).
  536. Reference: https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py
  537. Args:
  538. size: Size of the source mask.
  539. chunk_size: Number of frames in chunk.
  540. left_chunk_size: Size of the left context in chunks (0 means full context).
  541. device: Device for the mask tensor.
  542. Returns:
  543. mask: Chunk mask. (size, size)
  544. """
  545. mask = torch.zeros(size, size, device=device, dtype=torch.bool)
  546. for i in range(size):
  547. if left_chunk_size < 0:
  548. start = 0
  549. else:
  550. start = max((i // chunk_size - left_chunk_size) * chunk_size, 0)
  551. end = min((i // chunk_size + 1) * chunk_size, size)
  552. mask[i, start:end] = True
  553. return ~mask
  554. def make_source_mask(lengths: torch.Tensor) -> torch.Tensor:
  555. """Create source mask for given lengths.
  556. Reference: https://github.com/k2-fsa/icefall/blob/master/icefall/utils.py
  557. Args:
  558. lengths: Sequence lengths. (B,)
  559. Returns:
  560. : Mask for the sequence lengths. (B, max_len)
  561. """
  562. max_len = lengths.max()
  563. batch_size = lengths.size(0)
  564. expanded_lengths = torch.arange(max_len).expand(batch_size, max_len).to(lengths)
  565. return expanded_lengths >= lengths.unsqueeze(1)
  566. def get_transducer_task_io(
  567. labels: torch.Tensor,
  568. encoder_out_lens: torch.Tensor,
  569. ignore_id: int = -1,
  570. blank_id: int = 0,
  571. ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
  572. """Get Transducer loss I/O.
  573. Args:
  574. labels: Label ID sequences. (B, L)
  575. encoder_out_lens: Encoder output lengths. (B,)
  576. ignore_id: Padding symbol ID.
  577. blank_id: Blank symbol ID.
  578. Returns:
  579. decoder_in: Decoder inputs. (B, U)
  580. target: Target label ID sequences. (B, U)
  581. t_len: Time lengths. (B,)
  582. u_len: Label lengths. (B,)
  583. """
  584. def pad_list(labels: List[torch.Tensor], padding_value: int = 0):
  585. """Create padded batch of labels from a list of labels sequences.
  586. Args:
  587. labels: Labels sequences. [B x (?)]
  588. padding_value: Padding value.
  589. Returns:
  590. labels: Batch of padded labels sequences. (B,)
  591. """
  592. batch_size = len(labels)
  593. padded = (
  594. labels[0]
  595. .new(batch_size, max(x.size(0) for x in labels), *labels[0].size()[1:])
  596. .fill_(padding_value)
  597. )
  598. for i in range(batch_size):
  599. padded[i, : labels[i].size(0)] = labels[i]
  600. return padded
  601. device = labels.device
  602. labels_unpad = [y[y != ignore_id] for y in labels]
  603. blank = labels[0].new([blank_id])
  604. decoder_in = pad_list(
  605. [torch.cat([blank, label], dim=0) for label in labels_unpad], blank_id
  606. ).to(device)
  607. target = pad_list(labels_unpad, blank_id).type(torch.int32).to(device)
  608. encoder_out_lens = list(map(int, encoder_out_lens))
  609. t_len = torch.IntTensor(encoder_out_lens).to(device)
  610. u_len = torch.IntTensor([y.size(0) for y in labels_unpad]).to(device)
  611. return decoder_in, target, t_len, u_len
  612. def pad_to_len(t: torch.Tensor, pad_len: int, dim: int):
  613. """Pad the tensor `t` at `dim` to the length `pad_len` with right padding zeros."""
  614. if t.size(dim) == pad_len:
  615. return t
  616. else:
  617. pad_size = list(t.shape)
  618. pad_size[dim] = pad_len - t.size(dim)
  619. return torch.cat(
  620. [t, torch.zeros(*pad_size, dtype=t.dtype, device=t.device)], dim=dim
  621. )