| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788 |
- """Time warp module."""
- import torch
- from funasr.models.transformer.utils.nets_utils import pad_list
DEFAULT_TIME_WARP_MODE = "bicubic"


def time_warp(x: torch.Tensor, window: int = 80, mode: str = DEFAULT_TIME_WARP_MODE):
    """Randomly warp the time axis using torch.nn.functional.interpolate.

    A random split point ``center`` is drawn along the time axis, then a random
    target position ``warped`` within ``window`` frames of it. The frames before
    ``center`` are resampled to ``warped`` frames and the frames from ``center``
    onward are resampled to the remaining ``time - warped`` frames, so the
    output has the same shape as the input.

    Args:
        x: (Batch, Time, Freq) or (Batch, Channel, Time, Freq)
        window: time warp parameter (maximum shift of the split point, in
            frames). ``0`` disables warping.
        mode: Interpolate mode

    Returns:
        Tensor with the same shape as ``x``. The input is returned unchanged
        when ``window`` is 0 or when the time axis is not longer than
        ``2 * window``.

    Raises:
        ValueError: if ``window`` is negative.

    Note:
        When ``x`` does not require grad, the warped frames are written back
        into ``x`` in place, so the caller's tensor is overwritten. When ``x``
        requires grad, a new tensor is built and the input is left untouched.
    """
    if window < 0:
        raise ValueError(f"window must be non-negative, got {window}")

    # bicubic supports 4D or more dimension tensor
    org_size = x.size()
    if x.dim() == 3:
        # x: (Batch, Time, Freq) -> (Batch, 1, Time, Freq)
        x = x[:, None]

    t = x.shape[2]
    # window == 0 would make the second randint range empty; the second
    # condition means there is no valid split point with `window` frames of
    # room on both sides.
    if window == 0 or t - window <= window:
        return x.view(*org_size)

    # Draw plain Python ints so slicing and interpolate sizes do not depend on
    # 0-dim tensors being implicitly converted.
    center = int(torch.randint(window, t - window, (1,)).item())
    # randint's upper bound is exclusive, so the +1 yields
    # warped in [center - window + 1, center + window], always within [1, t - 1].
    warped = int(torch.randint(center - window, center + window, (1,)).item()) + 1

    freq = x.shape[3]
    # left: (Batch, Channel, warped, Freq)
    # right: (Batch, Channel, time - warped, Freq)
    left = torch.nn.functional.interpolate(
        x[:, :, :center], (warped, freq), mode=mode, align_corners=False
    )
    right = torch.nn.functional.interpolate(
        x[:, :, center:], (t - warped, freq), mode=mode, align_corners=False
    )

    if x.requires_grad:
        # In-place writes are avoided on tensors that are part of an autograd
        # graph; build a fresh tensor instead.
        x = torch.cat([left, right], dim=-2)
    else:
        x[:, :, :warped] = left
        x[:, :, warped:] = right

    return x.view(*org_size)
class TimeWarp(torch.nn.Module):
    """Module wrapper that applies ``time_warp`` to a batch of features.

    Args:
        window: time warp parameter
        mode: Interpolate mode
    """

    def __init__(self, window: int = 80, mode: str = DEFAULT_TIME_WARP_MODE):
        super().__init__()
        self.window = window
        self.mode = mode

    def extra_repr(self):
        """Return the configuration string shown in the module's repr."""
        return f"window={self.window}, mode={self.mode}"

    def forward(self, x: torch.Tensor, x_lengths: torch.Tensor = None):
        """Warp the time axis of ``x``.

        Args:
            x: (Batch, Time, Freq)
            x_lengths: (Batch,)

        Returns:
            Tuple of the warped features and ``x_lengths`` (passed through
            unchanged).
        """
        same_length = x_lengths is None or all(
            length == x_lengths[0] for length in x_lengths
        )
        if same_length:
            # A single call warps the whole batch, so every sample receives
            # the same warp.
            return time_warp(x, window=self.window, mode=self.mode), x_lengths

        # FIXME(kamo): I have no idea to batchify Timewarp
        # Each sample is cut to its own length, warped on its own, and the
        # results are re-assembled with pad_list using 0.0 as the fill value.
        warped_samples = [
            time_warp(
                x[idx][None, : x_lengths[idx]],
                window=self.window,
                mode=self.mode,
            )[0]
            for idx in range(x.size(0))
        ]
        return pad_list(warped_samples, 0.0), x_lengths
|