# nets_utils.py
  1. # -*- coding: utf-8 -*-
  2. """Network related utility tools."""
  3. import logging
  4. from typing import Dict
  5. import numpy as np
  6. import torch
  7. def to_device(m, x):
  8. """Send tensor into the device of the module.
  9. Args:
  10. m (torch.nn.Module): Torch module.
  11. x (Tensor): Torch tensor.
  12. Returns:
  13. Tensor: Torch tensor located in the same place as torch module.
  14. """
  15. if isinstance(m, torch.nn.Module):
  16. device = next(m.parameters()).device
  17. elif isinstance(m, torch.Tensor):
  18. device = m.device
  19. else:
  20. raise TypeError(
  21. "Expected torch.nn.Module or torch.tensor, " f"bot got: {type(m)}"
  22. )
  23. return x.to(device)
  24. def pad_list(xs, pad_value):
  25. """Perform padding for the list of tensors.
  26. Args:
  27. xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
  28. pad_value (float): Value for padding.
  29. Returns:
  30. Tensor: Padded tensor (B, Tmax, `*`).
  31. Examples:
  32. >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
  33. >>> x
  34. [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
  35. >>> pad_list(x, 0)
  36. tensor([[1., 1., 1., 1.],
  37. [1., 1., 0., 0.],
  38. [1., 0., 0., 0.]])
  39. """
  40. n_batch = len(xs)
  41. max_len = max(x.size(0) for x in xs)
  42. pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value)
  43. for i in range(n_batch):
  44. pad[i, : xs[i].size(0)] = xs[i]
  45. return pad
  46. def make_pad_mask(lengths, xs=None, length_dim=-1, maxlen=None):
  47. """Make mask tensor containing indices of padded part.
  48. Args:
  49. lengths (LongTensor or List): Batch of lengths (B,).
  50. xs (Tensor, optional): The reference tensor.
  51. If set, masks will be the same shape as this tensor.
  52. length_dim (int, optional): Dimension indicator of the above tensor.
  53. See the example.
  54. Returns:
  55. Tensor: Mask tensor containing indices of padded part.
  56. dtype=torch.uint8 in PyTorch 1.2-
  57. dtype=torch.bool in PyTorch 1.2+ (including 1.2)
  58. Examples:
  59. With only lengths.
  60. >>> lengths = [5, 3, 2]
  61. >>> make_pad_mask(lengths)
  62. masks = [[0, 0, 0, 0 ,0],
  63. [0, 0, 0, 1, 1],
  64. [0, 0, 1, 1, 1]]
  65. With the reference tensor.
  66. >>> xs = torch.zeros((3, 2, 4))
  67. >>> make_pad_mask(lengths, xs)
  68. tensor([[[0, 0, 0, 0],
  69. [0, 0, 0, 0]],
  70. [[0, 0, 0, 1],
  71. [0, 0, 0, 1]],
  72. [[0, 0, 1, 1],
  73. [0, 0, 1, 1]]], dtype=torch.uint8)
  74. >>> xs = torch.zeros((3, 2, 6))
  75. >>> make_pad_mask(lengths, xs)
  76. tensor([[[0, 0, 0, 0, 0, 1],
  77. [0, 0, 0, 0, 0, 1]],
  78. [[0, 0, 0, 1, 1, 1],
  79. [0, 0, 0, 1, 1, 1]],
  80. [[0, 0, 1, 1, 1, 1],
  81. [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)
  82. With the reference tensor and dimension indicator.
  83. >>> xs = torch.zeros((3, 6, 6))
  84. >>> make_pad_mask(lengths, xs, 1)
  85. tensor([[[0, 0, 0, 0, 0, 0],
  86. [0, 0, 0, 0, 0, 0],
  87. [0, 0, 0, 0, 0, 0],
  88. [0, 0, 0, 0, 0, 0],
  89. [0, 0, 0, 0, 0, 0],
  90. [1, 1, 1, 1, 1, 1]],
  91. [[0, 0, 0, 0, 0, 0],
  92. [0, 0, 0, 0, 0, 0],
  93. [0, 0, 0, 0, 0, 0],
  94. [1, 1, 1, 1, 1, 1],
  95. [1, 1, 1, 1, 1, 1],
  96. [1, 1, 1, 1, 1, 1]],
  97. [[0, 0, 0, 0, 0, 0],
  98. [0, 0, 0, 0, 0, 0],
  99. [1, 1, 1, 1, 1, 1],
  100. [1, 1, 1, 1, 1, 1],
  101. [1, 1, 1, 1, 1, 1],
  102. [1, 1, 1, 1, 1, 1]]], dtype=torch.uint8)
  103. >>> make_pad_mask(lengths, xs, 2)
  104. tensor([[[0, 0, 0, 0, 0, 1],
  105. [0, 0, 0, 0, 0, 1],
  106. [0, 0, 0, 0, 0, 1],
  107. [0, 0, 0, 0, 0, 1],
  108. [0, 0, 0, 0, 0, 1],
  109. [0, 0, 0, 0, 0, 1]],
  110. [[0, 0, 0, 1, 1, 1],
  111. [0, 0, 0, 1, 1, 1],
  112. [0, 0, 0, 1, 1, 1],
  113. [0, 0, 0, 1, 1, 1],
  114. [0, 0, 0, 1, 1, 1],
  115. [0, 0, 0, 1, 1, 1]],
  116. [[0, 0, 1, 1, 1, 1],
  117. [0, 0, 1, 1, 1, 1],
  118. [0, 0, 1, 1, 1, 1],
  119. [0, 0, 1, 1, 1, 1],
  120. [0, 0, 1, 1, 1, 1],
  121. [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)
  122. """
  123. if length_dim == 0:
  124. raise ValueError("length_dim cannot be 0: {}".format(length_dim))
  125. if not isinstance(lengths, list):
  126. lengths = lengths.tolist()
  127. bs = int(len(lengths))
  128. if maxlen is None:
  129. if xs is None:
  130. maxlen = int(max(lengths))
  131. else:
  132. maxlen = xs.size(length_dim)
  133. else:
  134. assert xs is None
  135. assert maxlen >= int(max(lengths))
  136. seq_range = torch.arange(0, maxlen, dtype=torch.int64)
  137. seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen)
  138. seq_length_expand = seq_range_expand.new(lengths).unsqueeze(-1)
  139. mask = seq_range_expand >= seq_length_expand
  140. if xs is not None:
  141. assert xs.size(0) == bs, (xs.size(0), bs)
  142. if length_dim < 0:
  143. length_dim = xs.dim() + length_dim
  144. # ind = (:, None, ..., None, :, , None, ..., None)
  145. ind = tuple(
  146. slice(None) if i in (0, length_dim) else None for i in range(xs.dim())
  147. )
  148. mask = mask[ind].expand_as(xs).to(xs.device)
  149. return mask
  150. def make_non_pad_mask(lengths, xs=None, length_dim=-1):
  151. """Make mask tensor containing indices of non-padded part.
  152. Args:
  153. lengths (LongTensor or List): Batch of lengths (B,).
  154. xs (Tensor, optional): The reference tensor.
  155. If set, masks will be the same shape as this tensor.
  156. length_dim (int, optional): Dimension indicator of the above tensor.
  157. See the example.
  158. Returns:
  159. ByteTensor: mask tensor containing indices of padded part.
  160. dtype=torch.uint8 in PyTorch 1.2-
  161. dtype=torch.bool in PyTorch 1.2+ (including 1.2)
  162. Examples:
  163. With only lengths.
  164. >>> lengths = [5, 3, 2]
  165. >>> make_non_pad_mask(lengths)
  166. masks = [[1, 1, 1, 1 ,1],
  167. [1, 1, 1, 0, 0],
  168. [1, 1, 0, 0, 0]]
  169. With the reference tensor.
  170. >>> xs = torch.zeros((3, 2, 4))
  171. >>> make_non_pad_mask(lengths, xs)
  172. tensor([[[1, 1, 1, 1],
  173. [1, 1, 1, 1]],
  174. [[1, 1, 1, 0],
  175. [1, 1, 1, 0]],
  176. [[1, 1, 0, 0],
  177. [1, 1, 0, 0]]], dtype=torch.uint8)
  178. >>> xs = torch.zeros((3, 2, 6))
  179. >>> make_non_pad_mask(lengths, xs)
  180. tensor([[[1, 1, 1, 1, 1, 0],
  181. [1, 1, 1, 1, 1, 0]],
  182. [[1, 1, 1, 0, 0, 0],
  183. [1, 1, 1, 0, 0, 0]],
  184. [[1, 1, 0, 0, 0, 0],
  185. [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)
  186. With the reference tensor and dimension indicator.
  187. >>> xs = torch.zeros((3, 6, 6))
  188. >>> make_non_pad_mask(lengths, xs, 1)
  189. tensor([[[1, 1, 1, 1, 1, 1],
  190. [1, 1, 1, 1, 1, 1],
  191. [1, 1, 1, 1, 1, 1],
  192. [1, 1, 1, 1, 1, 1],
  193. [1, 1, 1, 1, 1, 1],
  194. [0, 0, 0, 0, 0, 0]],
  195. [[1, 1, 1, 1, 1, 1],
  196. [1, 1, 1, 1, 1, 1],
  197. [1, 1, 1, 1, 1, 1],
  198. [0, 0, 0, 0, 0, 0],
  199. [0, 0, 0, 0, 0, 0],
  200. [0, 0, 0, 0, 0, 0]],
  201. [[1, 1, 1, 1, 1, 1],
  202. [1, 1, 1, 1, 1, 1],
  203. [0, 0, 0, 0, 0, 0],
  204. [0, 0, 0, 0, 0, 0],
  205. [0, 0, 0, 0, 0, 0],
  206. [0, 0, 0, 0, 0, 0]]], dtype=torch.uint8)
  207. >>> make_non_pad_mask(lengths, xs, 2)
  208. tensor([[[1, 1, 1, 1, 1, 0],
  209. [1, 1, 1, 1, 1, 0],
  210. [1, 1, 1, 1, 1, 0],
  211. [1, 1, 1, 1, 1, 0],
  212. [1, 1, 1, 1, 1, 0],
  213. [1, 1, 1, 1, 1, 0]],
  214. [[1, 1, 1, 0, 0, 0],
  215. [1, 1, 1, 0, 0, 0],
  216. [1, 1, 1, 0, 0, 0],
  217. [1, 1, 1, 0, 0, 0],
  218. [1, 1, 1, 0, 0, 0],
  219. [1, 1, 1, 0, 0, 0]],
  220. [[1, 1, 0, 0, 0, 0],
  221. [1, 1, 0, 0, 0, 0],
  222. [1, 1, 0, 0, 0, 0],
  223. [1, 1, 0, 0, 0, 0],
  224. [1, 1, 0, 0, 0, 0],
  225. [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)
  226. """
  227. return ~make_pad_mask(lengths, xs, length_dim)
  228. def mask_by_length(xs, lengths, fill=0):
  229. """Mask tensor according to length.
  230. Args:
  231. xs (Tensor): Batch of input tensor (B, `*`).
  232. lengths (LongTensor or List): Batch of lengths (B,).
  233. fill (int or float): Value to fill masked part.
  234. Returns:
  235. Tensor: Batch of masked input tensor (B, `*`).
  236. Examples:
  237. >>> x = torch.arange(5).repeat(3, 1) + 1
  238. >>> x
  239. tensor([[1, 2, 3, 4, 5],
  240. [1, 2, 3, 4, 5],
  241. [1, 2, 3, 4, 5]])
  242. >>> lengths = [5, 3, 2]
  243. >>> mask_by_length(x, lengths)
  244. tensor([[1, 2, 3, 4, 5],
  245. [1, 2, 3, 0, 0],
  246. [1, 2, 0, 0, 0]])
  247. """
  248. assert xs.size(0) == len(lengths)
  249. ret = xs.data.new(*xs.size()).fill_(fill)
  250. for i, l in enumerate(lengths):
  251. ret[i, :l] = xs[i, :l]
  252. return ret
  253. def th_accuracy(pad_outputs, pad_targets, ignore_label):
  254. """Calculate accuracy.
  255. Args:
  256. pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
  257. pad_targets (LongTensor): Target label tensors (B, Lmax, D).
  258. ignore_label (int): Ignore label id.
  259. Returns:
  260. float: Accuracy value (0.0 - 1.0).
  261. """
  262. pad_pred = pad_outputs.view(
  263. pad_targets.size(0), pad_targets.size(1), pad_outputs.size(1)
  264. ).argmax(2)
  265. mask = pad_targets != ignore_label
  266. numerator = torch.sum(
  267. pad_pred.masked_select(mask) == pad_targets.masked_select(mask)
  268. )
  269. denominator = torch.sum(mask)
  270. return float(numerator) / float(denominator)
  271. def to_torch_tensor(x):
  272. """Change to torch.Tensor or ComplexTensor from numpy.ndarray.
  273. Args:
  274. x: Inputs. It should be one of numpy.ndarray, Tensor, ComplexTensor, and dict.
  275. Returns:
  276. Tensor or ComplexTensor: Type converted inputs.
  277. Examples:
  278. >>> xs = np.ones(3, dtype=np.float32)
  279. >>> xs = to_torch_tensor(xs)
  280. tensor([1., 1., 1.])
  281. >>> xs = torch.ones(3, 4, 5)
  282. >>> assert to_torch_tensor(xs) is xs
  283. >>> xs = {'real': xs, 'imag': xs}
  284. >>> to_torch_tensor(xs)
  285. ComplexTensor(
  286. Real:
  287. tensor([1., 1., 1.])
  288. Imag;
  289. tensor([1., 1., 1.])
  290. )
  291. """
  292. # If numpy, change to torch tensor
  293. if isinstance(x, np.ndarray):
  294. if x.dtype.kind == "c":
  295. # Dynamically importing because torch_complex requires python3
  296. from torch_complex.tensor import ComplexTensor
  297. return ComplexTensor(x)
  298. else:
  299. return torch.from_numpy(x)
  300. # If {'real': ..., 'imag': ...}, convert to ComplexTensor
  301. elif isinstance(x, dict):
  302. # Dynamically importing because torch_complex requires python3
  303. from torch_complex.tensor import ComplexTensor
  304. if "real" not in x or "imag" not in x:
  305. raise ValueError("has 'real' and 'imag' keys: {}".format(list(x)))
  306. # Relative importing because of using python3 syntax
  307. return ComplexTensor(x["real"], x["imag"])
  308. # If torch.Tensor, as it is
  309. elif isinstance(x, torch.Tensor):
  310. return x
  311. else:
  312. error = (
  313. "x must be numpy.ndarray, torch.Tensor or a dict like "
  314. "{{'real': torch.Tensor, 'imag': torch.Tensor}}, "
  315. "but got {}".format(type(x))
  316. )
  317. try:
  318. from torch_complex.tensor import ComplexTensor
  319. except Exception:
  320. # If PY2
  321. raise ValueError(error)
  322. else:
  323. # If PY3
  324. if isinstance(x, ComplexTensor):
  325. return x
  326. else:
  327. raise ValueError(error)
  328. def get_subsample(train_args, mode, arch):
  329. """Parse the subsampling factors from the args for the specified `mode` and `arch`.
  330. Args:
  331. train_args: argument Namespace containing options.
  332. mode: one of ('asr', 'mt', 'st')
  333. arch: one of ('rnn', 'rnn-t', 'rnn_mix', 'rnn_mulenc', 'transformer')
  334. Returns:
  335. np.ndarray / List[np.ndarray]: subsampling factors.
  336. """
  337. if arch == "transformer":
  338. return np.array([1])
  339. elif mode == "mt" and arch == "rnn":
  340. # +1 means input (+1) and layers outputs (train_args.elayer)
  341. subsample = np.ones(train_args.elayers + 1, dtype=np.int)
  342. logging.warning("Subsampling is not performed for machine translation.")
  343. logging.info("subsample: " + " ".join([str(x) for x in subsample]))
  344. return subsample
  345. elif (
  346. (mode == "asr" and arch in ("rnn", "rnn-t"))
  347. or (mode == "mt" and arch == "rnn")
  348. or (mode == "st" and arch == "rnn")
  349. ):
  350. subsample = np.ones(train_args.elayers + 1, dtype=np.int)
  351. if train_args.etype.endswith("p") and not train_args.etype.startswith("vgg"):
  352. ss = train_args.subsample.split("_")
  353. for j in range(min(train_args.elayers + 1, len(ss))):
  354. subsample[j] = int(ss[j])
  355. else:
  356. logging.warning(
  357. "Subsampling is not performed for vgg*. "
  358. "It is performed in max pooling layers at CNN."
  359. )
  360. logging.info("subsample: " + " ".join([str(x) for x in subsample]))
  361. return subsample
  362. elif mode == "asr" and arch == "rnn_mix":
  363. subsample = np.ones(
  364. train_args.elayers_sd + train_args.elayers + 1, dtype=np.int
  365. )
  366. if train_args.etype.endswith("p") and not train_args.etype.startswith("vgg"):
  367. ss = train_args.subsample.split("_")
  368. for j in range(
  369. min(train_args.elayers_sd + train_args.elayers + 1, len(ss))
  370. ):
  371. subsample[j] = int(ss[j])
  372. else:
  373. logging.warning(
  374. "Subsampling is not performed for vgg*. "
  375. "It is performed in max pooling layers at CNN."
  376. )
  377. logging.info("subsample: " + " ".join([str(x) for x in subsample]))
  378. return subsample
  379. elif mode == "asr" and arch == "rnn_mulenc":
  380. subsample_list = []
  381. for idx in range(train_args.num_encs):
  382. subsample = np.ones(train_args.elayers[idx] + 1, dtype=np.int)
  383. if train_args.etype[idx].endswith("p") and not train_args.etype[
  384. idx
  385. ].startswith("vgg"):
  386. ss = train_args.subsample[idx].split("_")
  387. for j in range(min(train_args.elayers[idx] + 1, len(ss))):
  388. subsample[j] = int(ss[j])
  389. else:
  390. logging.warning(
  391. "Encoder %d: Subsampling is not performed for vgg*. "
  392. "It is performed in max pooling layers at CNN.",
  393. idx + 1,
  394. )
  395. logging.info("subsample: " + " ".join([str(x) for x in subsample]))
  396. subsample_list.append(subsample)
  397. return subsample_list
  398. else:
  399. raise ValueError("Invalid options: mode={}, arch={}".format(mode, arch))
  400. def rename_state_dict(
  401. old_prefix: str, new_prefix: str, state_dict: Dict[str, torch.Tensor]
  402. ):
  403. """Replace keys of old prefix with new prefix in state dict."""
  404. # need this list not to break the dict iterator
  405. old_keys = [k for k in state_dict if k.startswith(old_prefix)]
  406. if len(old_keys) > 0:
  407. logging.warning(f"Rename: {old_prefix} -> {new_prefix}")
  408. for k in old_keys:
  409. v = state_dict.pop(k)
  410. new_k = k.replace(old_prefix, new_prefix)
  411. state_dict[new_k] = v
  412. class Swish(torch.nn.Module):
  413. """Construct an Swish object."""
  414. def forward(self, x):
  415. """Return Swich activation function."""
  416. return x * torch.sigmoid(x)
  417. def get_activation(act):
  418. """Return activation function."""
  419. activation_funcs = {
  420. "hardtanh": torch.nn.Hardtanh,
  421. "tanh": torch.nn.Tanh,
  422. "relu": torch.nn.ReLU,
  423. "selu": torch.nn.SELU,
  424. "swish": Swish,
  425. }
  426. return activation_funcs[act]()