# mask_along_axis.py
  1. import math
  2. import torch
  3. from typeguard import check_argument_types
  4. from typing import Sequence
  5. from typing import Union
  6. def mask_along_axis(
  7. spec: torch.Tensor,
  8. spec_lengths: torch.Tensor,
  9. mask_width_range: Sequence[int] = (0, 30),
  10. dim: int = 1,
  11. num_mask: int = 2,
  12. replace_with_zero: bool = True,
  13. ):
  14. """Apply mask along the specified direction.
  15. Args:
  16. spec: (Batch, Length, Freq)
  17. spec_lengths: (Length): Not using lengths in this implementation
  18. mask_width_range: Select the width randomly between this range
  19. """
  20. org_size = spec.size()
  21. if spec.dim() == 4:
  22. # spec: (Batch, Channel, Length, Freq) -> (Batch * Channel, Length, Freq)
  23. spec = spec.view(-1, spec.size(2), spec.size(3))
  24. B = spec.shape[0]
  25. # D = Length or Freq
  26. D = spec.shape[dim]
  27. # mask_length: (B, num_mask, 1)
  28. mask_length = torch.randint(
  29. mask_width_range[0],
  30. mask_width_range[1],
  31. (B, num_mask),
  32. device=spec.device,
  33. ).unsqueeze(2)
  34. # mask_pos: (B, num_mask, 1)
  35. mask_pos = torch.randint(
  36. 0, max(1, D - mask_length.max()), (B, num_mask), device=spec.device
  37. ).unsqueeze(2)
  38. # aran: (1, 1, D)
  39. aran = torch.arange(D, device=spec.device)[None, None, :]
  40. # mask: (Batch, num_mask, D)
  41. mask = (mask_pos <= aran) * (aran < (mask_pos + mask_length))
  42. # Multiply masks: (Batch, num_mask, D) -> (Batch, D)
  43. mask = mask.any(dim=1)
  44. if dim == 1:
  45. # mask: (Batch, Length, 1)
  46. mask = mask.unsqueeze(2)
  47. elif dim == 2:
  48. # mask: (Batch, 1, Freq)
  49. mask = mask.unsqueeze(1)
  50. if replace_with_zero:
  51. value = 0.0
  52. else:
  53. value = spec.mean()
  54. if spec.requires_grad:
  55. spec = spec.masked_fill(mask, value)
  56. else:
  57. spec = spec.masked_fill_(mask, value)
  58. spec = spec.view(*org_size)
  59. return spec, spec_lengths
def mask_along_axis_lfr(
    spec: torch.Tensor,
    spec_lengths: torch.Tensor,
    mask_width_range: Sequence[int] = (0, 30),
    dim: int = 1,
    num_mask: int = 2,
    replace_with_zero: bool = True,
    lfr_rate: int = 1,
):
    """Apply mask along the specified direction for low-frame-rate (LFR) features.

    When lfr_rate > 1 the masked axis is treated as lfr_rate concatenated
    chunks of equal size; each sampled mask is replicated into every chunk at
    the same relative position (see the D * i offset below).

    Args:
        spec: (Batch, Length, Freq) or (Batch, Channel, Length, Freq)
        spec_lengths: (Length): Not using lengths in this implementation
        mask_width_range: Select the width randomly between this range
        lfr_rate: low frame rate — number of stacked chunks along `dim`
    """
    org_size = spec.size()
    if spec.dim() == 4:
        # spec: (Batch, Channel, Length, Freq) -> (Batch * Channel, Length, Freq)
        spec = spec.view(-1, spec.size(2), spec.size(3))
    B = spec.shape[0]
    # D = per-chunk size of the masked axis (Length or Freq divided by lfr_rate)
    D = spec.shape[dim] // lfr_rate
    # mask_length: (B, num_mask, 1)
    mask_length = torch.randint(
        mask_width_range[0],
        mask_width_range[1],
        (B, num_mask),
        device=spec.device,
    ).unsqueeze(2)
    if lfr_rate > 1:
        # Replicate the widths once per chunk: (B, num_mask * lfr_rate, 1).
        mask_length = mask_length.repeat(1, lfr_rate, 1)
    # mask_pos: (B, num_mask, 1) — start positions within a single chunk
    mask_pos = torch.randint(
        0, max(1, D - mask_length.max()), (B, num_mask), device=spec.device
    ).unsqueeze(2)
    if lfr_rate > 1:
        # Shift each chunk's copy of the positions by its chunk offset D * i,
        # producing (B, num_mask * lfr_rate, 1) to match mask_length above.
        mask_pos_raw = mask_pos.clone()
        # NOTE(review): the empty accumulator is int32 while mask_pos_i is
        # int64; this relies on torch.cat's type promotion — confirm on the
        # minimum supported torch version.
        mask_pos = torch.zeros((B, 0, 1), device=spec.device, dtype=torch.int32)
        for i in range(lfr_rate):
            mask_pos_i = mask_pos_raw + D * i
            mask_pos = torch.cat((mask_pos, mask_pos_i), dim=1)
    # aran: (1, 1, D) — D is re-bound here to the FULL axis size, since the
    # replicated positions above already index into the full axis.
    D = spec.shape[dim]
    aran = torch.arange(D, device=spec.device)[None, None, :]
    # mask: (Batch, num_mask, D)
    mask = (mask_pos <= aran) * (aran < (mask_pos + mask_length))
    # Combine masks: (Batch, num_mask, D) -> (Batch, D)
    mask = mask.any(dim=1)
    if dim == 1:
        # mask: (Batch, Length, 1)
        mask = mask.unsqueeze(2)
    elif dim == 2:
        # mask: (Batch, 1, Freq)
        mask = mask.unsqueeze(1)
    if replace_with_zero:
        value = 0.0
    else:
        value = spec.mean()
    # In-place fill only when no autograd graph is attached.
    if spec.requires_grad:
        spec = spec.masked_fill(mask, value)
    else:
        spec = spec.masked_fill_(mask, value)
    spec = spec.view(*org_size)
    return spec, spec_lengths
  125. class MaskAlongAxis(torch.nn.Module):
  126. def __init__(
  127. self,
  128. mask_width_range: Union[int, Sequence[int]] = (0, 30),
  129. num_mask: int = 2,
  130. dim: Union[int, str] = "time",
  131. replace_with_zero: bool = True,
  132. ):
  133. assert check_argument_types()
  134. if isinstance(mask_width_range, int):
  135. mask_width_range = (0, mask_width_range)
  136. if len(mask_width_range) != 2:
  137. raise TypeError(
  138. f"mask_width_range must be a tuple of int and int values: "
  139. f"{mask_width_range}",
  140. )
  141. assert mask_width_range[1] > mask_width_range[0]
  142. if isinstance(dim, str):
  143. if dim == "time":
  144. dim = 1
  145. elif dim == "freq":
  146. dim = 2
  147. else:
  148. raise ValueError("dim must be int, 'time' or 'freq'")
  149. if dim == 1:
  150. self.mask_axis = "time"
  151. elif dim == 2:
  152. self.mask_axis = "freq"
  153. else:
  154. self.mask_axis = "unknown"
  155. super().__init__()
  156. self.mask_width_range = mask_width_range
  157. self.num_mask = num_mask
  158. self.dim = dim
  159. self.replace_with_zero = replace_with_zero
  160. def extra_repr(self):
  161. return (
  162. f"mask_width_range={self.mask_width_range}, "
  163. f"num_mask={self.num_mask}, axis={self.mask_axis}"
  164. )
  165. def forward(self, spec: torch.Tensor, spec_lengths: torch.Tensor = None):
  166. """Forward function.
  167. Args:
  168. spec: (Batch, Length, Freq)
  169. """
  170. return mask_along_axis(
  171. spec,
  172. spec_lengths,
  173. mask_width_range=self.mask_width_range,
  174. dim=self.dim,
  175. num_mask=self.num_mask,
  176. replace_with_zero=self.replace_with_zero,
  177. )
  178. class MaskAlongAxisVariableMaxWidth(torch.nn.Module):
  179. """Mask input spec along a specified axis with variable maximum width.
  180. Formula:
  181. max_width = max_width_ratio * seq_len
  182. """
  183. def __init__(
  184. self,
  185. mask_width_ratio_range: Union[float, Sequence[float]] = (0.0, 0.05),
  186. num_mask: int = 2,
  187. dim: Union[int, str] = "time",
  188. replace_with_zero: bool = True,
  189. ):
  190. assert check_argument_types()
  191. if isinstance(mask_width_ratio_range, float):
  192. mask_width_ratio_range = (0.0, mask_width_ratio_range)
  193. if len(mask_width_ratio_range) != 2:
  194. raise TypeError(
  195. f"mask_width_ratio_range must be a tuple of float and float values: "
  196. f"{mask_width_ratio_range}",
  197. )
  198. assert mask_width_ratio_range[1] > mask_width_ratio_range[0]
  199. if isinstance(dim, str):
  200. if dim == "time":
  201. dim = 1
  202. elif dim == "freq":
  203. dim = 2
  204. else:
  205. raise ValueError("dim must be int, 'time' or 'freq'")
  206. if dim == 1:
  207. self.mask_axis = "time"
  208. elif dim == 2:
  209. self.mask_axis = "freq"
  210. else:
  211. self.mask_axis = "unknown"
  212. super().__init__()
  213. self.mask_width_ratio_range = mask_width_ratio_range
  214. self.num_mask = num_mask
  215. self.dim = dim
  216. self.replace_with_zero = replace_with_zero
  217. def extra_repr(self):
  218. return (
  219. f"mask_width_ratio_range={self.mask_width_ratio_range}, "
  220. f"num_mask={self.num_mask}, axis={self.mask_axis}"
  221. )
  222. def forward(self, spec: torch.Tensor, spec_lengths: torch.Tensor = None):
  223. """Forward function.
  224. Args:
  225. spec: (Batch, Length, Freq)
  226. """
  227. max_seq_len = spec.shape[self.dim]
  228. min_mask_width = math.floor(max_seq_len * self.mask_width_ratio_range[0])
  229. min_mask_width = max([0, min_mask_width])
  230. max_mask_width = math.floor(max_seq_len * self.mask_width_ratio_range[1])
  231. max_mask_width = min([max_seq_len, max_mask_width])
  232. if max_mask_width > min_mask_width:
  233. return mask_along_axis(
  234. spec,
  235. spec_lengths,
  236. mask_width_range=(min_mask_width, max_mask_width),
  237. dim=self.dim,
  238. num_mask=self.num_mask,
  239. replace_with_zero=self.replace_with_zero,
  240. )
  241. return spec, spec_lengths
  242. class MaskAlongAxisLFR(torch.nn.Module):
  243. def __init__(
  244. self,
  245. mask_width_range: Union[int, Sequence[int]] = (0, 30),
  246. num_mask: int = 2,
  247. dim: Union[int, str] = "time",
  248. replace_with_zero: bool = True,
  249. lfr_rate: int = 1,
  250. ):
  251. assert check_argument_types()
  252. if isinstance(mask_width_range, int):
  253. mask_width_range = (0, mask_width_range)
  254. if len(mask_width_range) != 2:
  255. raise TypeError(
  256. f"mask_width_range must be a tuple of int and int values: "
  257. f"{mask_width_range}",
  258. )
  259. assert mask_width_range[1] > mask_width_range[0]
  260. if isinstance(dim, str):
  261. if dim == "time":
  262. dim = 1
  263. lfr_rate = 1
  264. elif dim == "freq":
  265. dim = 2
  266. else:
  267. raise ValueError("dim must be int, 'time' or 'freq'")
  268. if dim == 1:
  269. self.mask_axis = "time"
  270. lfr_rate = 1
  271. elif dim == 2:
  272. self.mask_axis = "freq"
  273. else:
  274. self.mask_axis = "unknown"
  275. super().__init__()
  276. self.mask_width_range = mask_width_range
  277. self.num_mask = num_mask
  278. self.dim = dim
  279. self.replace_with_zero = replace_with_zero
  280. self.lfr_rate = lfr_rate
  281. def extra_repr(self):
  282. return (
  283. f"mask_width_range={self.mask_width_range}, "
  284. f"num_mask={self.num_mask}, axis={self.mask_axis}"
  285. )
  286. def forward(self, spec: torch.Tensor, spec_lengths: torch.Tensor = None):
  287. """Forward function.
  288. Args:
  289. spec: (Batch, Length, Freq)
  290. """
  291. return mask_along_axis_lfr(
  292. spec,
  293. spec_lengths,
  294. mask_width_range=self.mask_width_range,
  295. dim=self.dim,
  296. num_mask=self.num_mask,
  297. replace_with_zero=self.replace_with_zero,
  298. lfr_rate=self.lfr_rate,
  299. )