specaug.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184
  1. """SpecAugment module."""
  2. from typing import Optional
  3. from typing import Sequence
  4. from typing import Union
  5. from funasr.models.specaug.abs_specaug import AbsSpecAug
  6. from funasr.layers.mask_along_axis import MaskAlongAxis
  7. from funasr.layers.mask_along_axis import MaskAlongAxisVariableMaxWidth
  8. from funasr.layers.mask_along_axis import MaskAlongAxisLFR
  9. from funasr.layers.time_warp import TimeWarp
  10. class SpecAug(AbsSpecAug):
  11. """Implementation of SpecAug.
  12. Reference:
  13. Daniel S. Park et al.
  14. "SpecAugment: A Simple Data
  15. Augmentation Method for Automatic Speech Recognition"
  16. .. warning::
  17. When using cuda mode, time_warp doesn't have reproducibility
  18. due to `torch.nn.functional.interpolate`.
  19. """
  20. def __init__(
  21. self,
  22. apply_time_warp: bool = True,
  23. time_warp_window: int = 5,
  24. time_warp_mode: str = "bicubic",
  25. apply_freq_mask: bool = True,
  26. freq_mask_width_range: Union[int, Sequence[int]] = (0, 20),
  27. num_freq_mask: int = 2,
  28. apply_time_mask: bool = True,
  29. time_mask_width_range: Optional[Union[int, Sequence[int]]] = None,
  30. time_mask_width_ratio_range: Optional[Union[float, Sequence[float]]] = None,
  31. num_time_mask: int = 2,
  32. ):
  33. if not apply_time_warp and not apply_time_mask and not apply_freq_mask:
  34. raise ValueError(
  35. "Either one of time_warp, time_mask, or freq_mask should be applied"
  36. )
  37. if (
  38. apply_time_mask
  39. and (time_mask_width_range is not None)
  40. and (time_mask_width_ratio_range is not None)
  41. ):
  42. raise ValueError(
  43. 'Either one of "time_mask_width_range" or '
  44. '"time_mask_width_ratio_range" can be used'
  45. )
  46. super().__init__()
  47. self.apply_time_warp = apply_time_warp
  48. self.apply_freq_mask = apply_freq_mask
  49. self.apply_time_mask = apply_time_mask
  50. if apply_time_warp:
  51. self.time_warp = TimeWarp(window=time_warp_window, mode=time_warp_mode)
  52. else:
  53. self.time_warp = None
  54. if apply_freq_mask:
  55. self.freq_mask = MaskAlongAxis(
  56. dim="freq",
  57. mask_width_range=freq_mask_width_range,
  58. num_mask=num_freq_mask,
  59. )
  60. else:
  61. self.freq_mask = None
  62. if apply_time_mask:
  63. if time_mask_width_range is not None:
  64. self.time_mask = MaskAlongAxis(
  65. dim="time",
  66. mask_width_range=time_mask_width_range,
  67. num_mask=num_time_mask,
  68. )
  69. elif time_mask_width_ratio_range is not None:
  70. self.time_mask = MaskAlongAxisVariableMaxWidth(
  71. dim="time",
  72. mask_width_ratio_range=time_mask_width_ratio_range,
  73. num_mask=num_time_mask,
  74. )
  75. else:
  76. raise ValueError(
  77. 'Either one of "time_mask_width_range" or '
  78. '"time_mask_width_ratio_range" should be used.'
  79. )
  80. else:
  81. self.time_mask = None
  82. def forward(self, x, x_lengths=None):
  83. if self.time_warp is not None:
  84. x, x_lengths = self.time_warp(x, x_lengths)
  85. if self.freq_mask is not None:
  86. x, x_lengths = self.freq_mask(x, x_lengths)
  87. if self.time_mask is not None:
  88. x, x_lengths = self.time_mask(x, x_lengths)
  89. return x, x_lengths
  90. class SpecAugLFR(AbsSpecAug):
  91. """Implementation of SpecAug.
  92. lfr_rate:low frame rate
  93. """
  94. def __init__(
  95. self,
  96. apply_time_warp: bool = True,
  97. time_warp_window: int = 5,
  98. time_warp_mode: str = "bicubic",
  99. apply_freq_mask: bool = True,
  100. freq_mask_width_range: Union[int, Sequence[int]] = (0, 20),
  101. num_freq_mask: int = 2,
  102. lfr_rate: int = 0,
  103. apply_time_mask: bool = True,
  104. time_mask_width_range: Optional[Union[int, Sequence[int]]] = None,
  105. time_mask_width_ratio_range: Optional[Union[float, Sequence[float]]] = None,
  106. num_time_mask: int = 2,
  107. ):
  108. if not apply_time_warp and not apply_time_mask and not apply_freq_mask:
  109. raise ValueError(
  110. "Either one of time_warp, time_mask, or freq_mask should be applied"
  111. )
  112. if (
  113. apply_time_mask
  114. and (time_mask_width_range is not None)
  115. and (time_mask_width_ratio_range is not None)
  116. ):
  117. raise ValueError(
  118. 'Either one of "time_mask_width_range" or '
  119. '"time_mask_width_ratio_range" can be used'
  120. )
  121. super().__init__()
  122. self.apply_time_warp = apply_time_warp
  123. self.apply_freq_mask = apply_freq_mask
  124. self.apply_time_mask = apply_time_mask
  125. if apply_time_warp:
  126. self.time_warp = TimeWarp(window=time_warp_window, mode=time_warp_mode)
  127. else:
  128. self.time_warp = None
  129. if apply_freq_mask:
  130. self.freq_mask = MaskAlongAxisLFR(
  131. dim="freq",
  132. mask_width_range=freq_mask_width_range,
  133. num_mask=num_freq_mask,
  134. lfr_rate=lfr_rate+1,
  135. )
  136. else:
  137. self.freq_mask = None
  138. if apply_time_mask:
  139. if time_mask_width_range is not None:
  140. self.time_mask = MaskAlongAxisLFR(
  141. dim="time",
  142. mask_width_range=time_mask_width_range,
  143. num_mask=num_time_mask,
  144. lfr_rate=lfr_rate + 1,
  145. )
  146. elif time_mask_width_ratio_range is not None:
  147. self.time_mask = MaskAlongAxisVariableMaxWidth(
  148. dim="time",
  149. mask_width_ratio_range=time_mask_width_ratio_range,
  150. num_mask=num_time_mask,
  151. )
  152. else:
  153. raise ValueError(
  154. 'Either one of "time_mask_width_range" or '
  155. '"time_mask_width_ratio_range" should be used.'
  156. )
  157. else:
  158. self.time_mask = None
  159. def forward(self, x, x_lengths=None):
  160. if self.time_warp is not None:
  161. x, x_lengths = self.time_warp(x, x_lengths)
  162. if self.freq_mask is not None:
  163. x, x_lengths = self.freq_mask(x, x_lengths)
  164. if self.time_mask is not None:
  165. x, x_lengths = self.time_mask(x, x_lengths)
  166. return x, x_lengths