windowing.py 2.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. #!/usr/bin/env python3
  2. # 2020, Technische Universität München; Ludwig Kürzinger
  3. # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
  4. """Sliding Window for raw audio input data."""
  5. import torch
  6. import torch.nn as nn
  7. from typing import Tuple
  8. class SlidingWindow(nn.Module):
  9. """Sliding Window.
  10. Provides a sliding window over a batched continuous raw audio tensor.
  11. Optionally, provides padding (Currently not implemented).
  12. Combine this module with a pre-encoder compatible with raw audio data,
  13. for example Sinc convolutions.
  14. Known issues:
  15. Output length is calculated incorrectly if audio shorter than win_length.
  16. WARNING: trailing values are discarded - padding not implemented yet.
  17. There is currently no additional window function applied to input values.
  18. """
  19. def __init__(
  20. self,
  21. win_length: int = 400,
  22. hop_length: int = 160,
  23. channels: int = 1,
  24. padding: int = None,
  25. fs=None,
  26. ):
  27. """Initialize.
  28. Args:
  29. win_length: Length of frame.
  30. hop_length: Relative starting point of next frame.
  31. channels: Number of input channels.
  32. padding: Padding (placeholder, currently not implemented).
  33. fs: Sampling rate (placeholder for compatibility, not used).
  34. """
  35. super().__init__()
  36. self.fs = fs
  37. self.win_length = win_length
  38. self.hop_length = hop_length
  39. self.channels = channels
  40. self.padding = padding
  41. def forward(
  42. self, input: torch.Tensor, input_lengths: torch.Tensor
  43. ) -> Tuple[torch.Tensor, torch.Tensor]:
  44. """Apply a sliding window on the input.
  45. Args:
  46. input: Input (B, T, C*D) or (B, T*C*D), with D=C=1.
  47. input_lengths: Input lengths within batch.
  48. Returns:
  49. Tensor: Output with dimensions (B, T, C, D), with D=win_length.
  50. Tensor: Output lengths within batch.
  51. """
  52. input_size = input.size()
  53. B = input_size[0]
  54. T = input_size[1]
  55. C = self.channels
  56. D = self.win_length
  57. # (B, T, C) --> (T, B, C)
  58. continuous = input.view(B, T, C).permute(1, 0, 2)
  59. windowed = continuous.unfold(0, D, self.hop_length)
  60. # (T, B, C, D) --> (B, T, C, D)
  61. output = windowed.permute(1, 0, 2, 3).contiguous()
  62. # After unfold(), windowed lengths change:
  63. output_lengths = (input_lengths - self.win_length) // self.hop_length + 1
  64. return output, output_lengths
  65. def output_size(self) -> int:
  66. """Return output length of feature dimension D, i.e. the window length."""
  67. return self.win_length