time_warp.py 2.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. """Time warp module."""
  2. import torch
  3. from funasr.modules.nets_utils import pad_list
  4. DEFAULT_TIME_WARP_MODE = "bicubic"
  5. def time_warp(x: torch.Tensor, window: int = 80, mode: str = DEFAULT_TIME_WARP_MODE):
  6. """Time warping using torch.interpolate.
  7. Args:
  8. x: (Batch, Time, Freq)
  9. window: time warp parameter
  10. mode: Interpolate mode
  11. """
  12. # bicubic supports 4D or more dimension tensor
  13. org_size = x.size()
  14. if x.dim() == 3:
  15. # x: (Batch, Time, Freq) -> (Batch, 1, Time, Freq)
  16. x = x[:, None]
  17. t = x.shape[2]
  18. if t - window <= window:
  19. return x.view(*org_size)
  20. center = torch.randint(window, t - window, (1,))[0]
  21. warped = torch.randint(center - window, center + window, (1,))[0] + 1
  22. # left: (Batch, Channel, warped, Freq)
  23. # right: (Batch, Channel, time - warped, Freq)
  24. left = torch.nn.functional.interpolate(
  25. x[:, :, :center], (warped, x.shape[3]), mode=mode, align_corners=False
  26. )
  27. right = torch.nn.functional.interpolate(
  28. x[:, :, center:], (t - warped, x.shape[3]), mode=mode, align_corners=False
  29. )
  30. if x.requires_grad:
  31. x = torch.cat([left, right], dim=-2)
  32. else:
  33. x[:, :, :warped] = left
  34. x[:, :, warped:] = right
  35. return x.view(*org_size)
  36. class TimeWarp(torch.nn.Module):
  37. """Time warping using torch.interpolate.
  38. Args:
  39. window: time warp parameter
  40. mode: Interpolate mode
  41. """
  42. def __init__(self, window: int = 80, mode: str = DEFAULT_TIME_WARP_MODE):
  43. super().__init__()
  44. self.window = window
  45. self.mode = mode
  46. def extra_repr(self):
  47. return f"window={self.window}, mode={self.mode}"
  48. def forward(self, x: torch.Tensor, x_lengths: torch.Tensor = None):
  49. """Forward function.
  50. Args:
  51. x: (Batch, Time, Freq)
  52. x_lengths: (Batch,)
  53. """
  54. if x_lengths is None or all(le == x_lengths[0] for le in x_lengths):
  55. # Note that applying same warping for each sample
  56. y = time_warp(x, window=self.window, mode=self.mode)
  57. else:
  58. # FIXME(kamo): I have no idea to batchify Timewarp
  59. ys = []
  60. for i in range(x.size(0)):
  61. _y = time_warp(
  62. x[i][None, : x_lengths[i]],
  63. window=self.window,
  64. mode=self.mode,
  65. )[0]
  66. ys.append(_y)
  67. y = pad_list(ys, 0.0)
  68. return y, x_lengths