windowing.py 2.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. #!/usr/bin/env python3
  2. # 2020, Technische Universität München; Ludwig Kürzinger
  3. # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
  4. """Sliding Window for raw audio input data."""
  5. from funasr.models.frontend.abs_frontend import AbsFrontend
  6. import torch
  7. from typeguard import check_argument_types
  8. from typing import Tuple
  9. class SlidingWindow(AbsFrontend):
  10. """Sliding Window.
  11. Provides a sliding window over a batched continuous raw audio tensor.
  12. Optionally, provides padding (Currently not implemented).
  13. Combine this module with a pre-encoder compatible with raw audio data,
  14. for example Sinc convolutions.
  15. Known issues:
  16. Output length is calculated incorrectly if audio shorter than win_length.
  17. WARNING: trailing values are discarded - padding not implemented yet.
  18. There is currently no additional window function applied to input values.
  19. """
  20. def __init__(
  21. self,
  22. win_length: int = 400,
  23. hop_length: int = 160,
  24. channels: int = 1,
  25. padding: int = None,
  26. fs=None,
  27. ):
  28. """Initialize.
  29. Args:
  30. win_length: Length of frame.
  31. hop_length: Relative starting point of next frame.
  32. channels: Number of input channels.
  33. padding: Padding (placeholder, currently not implemented).
  34. fs: Sampling rate (placeholder for compatibility, not used).
  35. """
  36. assert check_argument_types()
  37. super().__init__()
  38. self.fs = fs
  39. self.win_length = win_length
  40. self.hop_length = hop_length
  41. self.channels = channels
  42. self.padding = padding
  43. def forward(
  44. self, input: torch.Tensor, input_lengths: torch.Tensor
  45. ) -> Tuple[torch.Tensor, torch.Tensor]:
  46. """Apply a sliding window on the input.
  47. Args:
  48. input: Input (B, T, C*D) or (B, T*C*D), with D=C=1.
  49. input_lengths: Input lengths within batch.
  50. Returns:
  51. Tensor: Output with dimensions (B, T, C, D), with D=win_length.
  52. Tensor: Output lengths within batch.
  53. """
  54. input_size = input.size()
  55. B = input_size[0]
  56. T = input_size[1]
  57. C = self.channels
  58. D = self.win_length
  59. # (B, T, C) --> (T, B, C)
  60. continuous = input.view(B, T, C).permute(1, 0, 2)
  61. windowed = continuous.unfold(0, D, self.hop_length)
  62. # (T, B, C, D) --> (B, T, C, D)
  63. output = windowed.permute(1, 0, 2, 3).contiguous()
  64. # After unfold(), windowed lengths change:
  65. output_lengths = (input_lengths - self.win_length) // self.hop_length + 1
  66. return output, output_lengths
  67. def output_size(self) -> int:
  68. """Return output length of feature dimension D, i.e. the window length."""
  69. return self.win_length