# mask_along_axis.py — SpecAugment-style masking utilities.
import math
from typing import Sequence, Union

import torch
  5. def mask_along_axis(
  6. spec: torch.Tensor,
  7. spec_lengths: torch.Tensor,
  8. mask_width_range: Sequence[int] = (0, 30),
  9. dim: int = 1,
  10. num_mask: int = 2,
  11. replace_with_zero: bool = True,
  12. ):
  13. """Apply mask along the specified direction.
  14. Args:
  15. spec: (Batch, Length, Freq)
  16. spec_lengths: (Length): Not using lengths in this implementation
  17. mask_width_range: Select the width randomly between this range
  18. """
  19. org_size = spec.size()
  20. if spec.dim() == 4:
  21. # spec: (Batch, Channel, Length, Freq) -> (Batch * Channel, Length, Freq)
  22. spec = spec.view(-1, spec.size(2), spec.size(3))
  23. B = spec.shape[0]
  24. # D = Length or Freq
  25. D = spec.shape[dim]
  26. # mask_length: (B, num_mask, 1)
  27. mask_length = torch.randint(
  28. mask_width_range[0],
  29. mask_width_range[1],
  30. (B, num_mask),
  31. device=spec.device,
  32. ).unsqueeze(2)
  33. # mask_pos: (B, num_mask, 1)
  34. mask_pos = torch.randint(
  35. 0, max(1, D - mask_length.max()), (B, num_mask), device=spec.device
  36. ).unsqueeze(2)
  37. # aran: (1, 1, D)
  38. aran = torch.arange(D, device=spec.device)[None, None, :]
  39. # mask: (Batch, num_mask, D)
  40. mask = (mask_pos <= aran) * (aran < (mask_pos + mask_length))
  41. # Multiply masks: (Batch, num_mask, D) -> (Batch, D)
  42. mask = mask.any(dim=1)
  43. if dim == 1:
  44. # mask: (Batch, Length, 1)
  45. mask = mask.unsqueeze(2)
  46. elif dim == 2:
  47. # mask: (Batch, 1, Freq)
  48. mask = mask.unsqueeze(1)
  49. if replace_with_zero:
  50. value = 0.0
  51. else:
  52. value = spec.mean()
  53. if spec.requires_grad:
  54. spec = spec.masked_fill(mask, value)
  55. else:
  56. spec = spec.masked_fill_(mask, value)
  57. spec = spec.view(*org_size)
  58. return spec, spec_lengths
  59. def mask_along_axis_lfr(
  60. spec: torch.Tensor,
  61. spec_lengths: torch.Tensor,
  62. mask_width_range: Sequence[int] = (0, 30),
  63. dim: int = 1,
  64. num_mask: int = 2,
  65. replace_with_zero: bool = True,
  66. lfr_rate: int = 1,
  67. ):
  68. """Apply mask along the specified direction.
  69. Args:
  70. spec: (Batch, Length, Freq)
  71. spec_lengths: (Length): Not using lengths in this implementation
  72. mask_width_range: Select the width randomly between this range
  73. lfr_rate:low frame rate
  74. """
  75. org_size = spec.size()
  76. if spec.dim() == 4:
  77. # spec: (Batch, Channel, Length, Freq) -> (Batch * Channel, Length, Freq)
  78. spec = spec.view(-1, spec.size(2), spec.size(3))
  79. B = spec.shape[0]
  80. # D = Length or Freq
  81. D = spec.shape[dim] // lfr_rate
  82. # mask_length: (B, num_mask, 1)
  83. mask_length = torch.randint(
  84. mask_width_range[0],
  85. mask_width_range[1],
  86. (B, num_mask),
  87. device=spec.device,
  88. ).unsqueeze(2)
  89. if lfr_rate > 1:
  90. mask_length = mask_length.repeat(1, lfr_rate, 1)
  91. # mask_pos: (B, num_mask, 1)
  92. mask_pos = torch.randint(
  93. 0, max(1, D - mask_length.max()), (B, num_mask), device=spec.device
  94. ).unsqueeze(2)
  95. if lfr_rate > 1:
  96. mask_pos_raw = mask_pos.clone()
  97. mask_pos = torch.zeros((B, 0, 1), device=spec.device, dtype=torch.int32)
  98. for i in range(lfr_rate):
  99. mask_pos_i = mask_pos_raw + D * i
  100. mask_pos = torch.cat((mask_pos, mask_pos_i), dim=1)
  101. # aran: (1, 1, D)
  102. D = spec.shape[dim]
  103. aran = torch.arange(D, device=spec.device)[None, None, :]
  104. # mask: (Batch, num_mask, D)
  105. mask = (mask_pos <= aran) * (aran < (mask_pos + mask_length))
  106. # Multiply masks: (Batch, num_mask, D) -> (Batch, D)
  107. mask = mask.any(dim=1)
  108. if dim == 1:
  109. # mask: (Batch, Length, 1)
  110. mask = mask.unsqueeze(2)
  111. elif dim == 2:
  112. # mask: (Batch, 1, Freq)
  113. mask = mask.unsqueeze(1)
  114. if replace_with_zero:
  115. value = 0.0
  116. else:
  117. value = spec.mean()
  118. if spec.requires_grad:
  119. spec = spec.masked_fill(mask, value)
  120. else:
  121. spec = spec.masked_fill_(mask, value)
  122. spec = spec.view(*org_size)
  123. return spec, spec_lengths
  124. class MaskAlongAxis(torch.nn.Module):
  125. def __init__(
  126. self,
  127. mask_width_range: Union[int, Sequence[int]] = (0, 30),
  128. num_mask: int = 2,
  129. dim: Union[int, str] = "time",
  130. replace_with_zero: bool = True,
  131. ):
  132. if isinstance(mask_width_range, int):
  133. mask_width_range = (0, mask_width_range)
  134. if len(mask_width_range) != 2:
  135. raise TypeError(
  136. f"mask_width_range must be a tuple of int and int values: "
  137. f"{mask_width_range}",
  138. )
  139. assert mask_width_range[1] > mask_width_range[0]
  140. if isinstance(dim, str):
  141. if dim == "time":
  142. dim = 1
  143. elif dim == "freq":
  144. dim = 2
  145. else:
  146. raise ValueError("dim must be int, 'time' or 'freq'")
  147. if dim == 1:
  148. self.mask_axis = "time"
  149. elif dim == 2:
  150. self.mask_axis = "freq"
  151. else:
  152. self.mask_axis = "unknown"
  153. super().__init__()
  154. self.mask_width_range = mask_width_range
  155. self.num_mask = num_mask
  156. self.dim = dim
  157. self.replace_with_zero = replace_with_zero
  158. def extra_repr(self):
  159. return (
  160. f"mask_width_range={self.mask_width_range}, "
  161. f"num_mask={self.num_mask}, axis={self.mask_axis}"
  162. )
  163. def forward(self, spec: torch.Tensor, spec_lengths: torch.Tensor = None):
  164. """Forward function.
  165. Args:
  166. spec: (Batch, Length, Freq)
  167. """
  168. return mask_along_axis(
  169. spec,
  170. spec_lengths,
  171. mask_width_range=self.mask_width_range,
  172. dim=self.dim,
  173. num_mask=self.num_mask,
  174. replace_with_zero=self.replace_with_zero,
  175. )
  176. class MaskAlongAxisVariableMaxWidth(torch.nn.Module):
  177. """Mask input spec along a specified axis with variable maximum width.
  178. Formula:
  179. max_width = max_width_ratio * seq_len
  180. """
  181. def __init__(
  182. self,
  183. mask_width_ratio_range: Union[float, Sequence[float]] = (0.0, 0.05),
  184. num_mask: int = 2,
  185. dim: Union[int, str] = "time",
  186. replace_with_zero: bool = True,
  187. ):
  188. if isinstance(mask_width_ratio_range, float):
  189. mask_width_ratio_range = (0.0, mask_width_ratio_range)
  190. if len(mask_width_ratio_range) != 2:
  191. raise TypeError(
  192. f"mask_width_ratio_range must be a tuple of float and float values: "
  193. f"{mask_width_ratio_range}",
  194. )
  195. assert mask_width_ratio_range[1] > mask_width_ratio_range[0]
  196. if isinstance(dim, str):
  197. if dim == "time":
  198. dim = 1
  199. elif dim == "freq":
  200. dim = 2
  201. else:
  202. raise ValueError("dim must be int, 'time' or 'freq'")
  203. if dim == 1:
  204. self.mask_axis = "time"
  205. elif dim == 2:
  206. self.mask_axis = "freq"
  207. else:
  208. self.mask_axis = "unknown"
  209. super().__init__()
  210. self.mask_width_ratio_range = mask_width_ratio_range
  211. self.num_mask = num_mask
  212. self.dim = dim
  213. self.replace_with_zero = replace_with_zero
  214. def extra_repr(self):
  215. return (
  216. f"mask_width_ratio_range={self.mask_width_ratio_range}, "
  217. f"num_mask={self.num_mask}, axis={self.mask_axis}"
  218. )
  219. def forward(self, spec: torch.Tensor, spec_lengths: torch.Tensor = None):
  220. """Forward function.
  221. Args:
  222. spec: (Batch, Length, Freq)
  223. """
  224. max_seq_len = spec.shape[self.dim]
  225. min_mask_width = math.floor(max_seq_len * self.mask_width_ratio_range[0])
  226. min_mask_width = max([0, min_mask_width])
  227. max_mask_width = math.floor(max_seq_len * self.mask_width_ratio_range[1])
  228. max_mask_width = min([max_seq_len, max_mask_width])
  229. if max_mask_width > min_mask_width:
  230. return mask_along_axis(
  231. spec,
  232. spec_lengths,
  233. mask_width_range=(min_mask_width, max_mask_width),
  234. dim=self.dim,
  235. num_mask=self.num_mask,
  236. replace_with_zero=self.replace_with_zero,
  237. )
  238. return spec, spec_lengths
  239. class MaskAlongAxisLFR(torch.nn.Module):
  240. def __init__(
  241. self,
  242. mask_width_range: Union[int, Sequence[int]] = (0, 30),
  243. num_mask: int = 2,
  244. dim: Union[int, str] = "time",
  245. replace_with_zero: bool = True,
  246. lfr_rate: int = 1,
  247. ):
  248. if isinstance(mask_width_range, int):
  249. mask_width_range = (0, mask_width_range)
  250. if len(mask_width_range) != 2:
  251. raise TypeError(
  252. f"mask_width_range must be a tuple of int and int values: "
  253. f"{mask_width_range}",
  254. )
  255. assert mask_width_range[1] > mask_width_range[0]
  256. if isinstance(dim, str):
  257. if dim == "time":
  258. dim = 1
  259. lfr_rate = 1
  260. elif dim == "freq":
  261. dim = 2
  262. else:
  263. raise ValueError("dim must be int, 'time' or 'freq'")
  264. if dim == 1:
  265. self.mask_axis = "time"
  266. lfr_rate = 1
  267. elif dim == 2:
  268. self.mask_axis = "freq"
  269. else:
  270. self.mask_axis = "unknown"
  271. super().__init__()
  272. self.mask_width_range = mask_width_range
  273. self.num_mask = num_mask
  274. self.dim = dim
  275. self.replace_with_zero = replace_with_zero
  276. self.lfr_rate = lfr_rate
  277. def extra_repr(self):
  278. return (
  279. f"mask_width_range={self.mask_width_range}, "
  280. f"num_mask={self.num_mask}, axis={self.mask_axis}"
  281. )
  282. def forward(self, spec: torch.Tensor, spec_lengths: torch.Tensor = None):
  283. """Forward function.
  284. Args:
  285. spec: (Batch, Length, Freq)
  286. """
  287. return mask_along_axis_lfr(
  288. spec,
  289. spec_lengths,
  290. mask_width_range=self.mask_width_range,
  291. dim=self.dim,
  292. num_mask=self.num_mask,
  293. replace_with_zero=self.replace_with_zero,
  294. lfr_rate=self.lfr_rate,
  295. )