sinc_conv.py 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273
  1. #!/usr/bin/env python3
  2. # 2020, Technische Universität München; Ludwig Kürzinger
  3. # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
  4. """Sinc convolutions."""
  5. import math
  6. import torch
  7. from typeguard import check_argument_types
  8. from typing import Union
  9. class LogCompression(torch.nn.Module):
  10. """Log Compression Activation.
  11. Activation function `log(abs(x) + 1)`.
  12. """
  13. def __init__(self):
  14. """Initialize."""
  15. super().__init__()
  16. def forward(self, x: torch.Tensor) -> torch.Tensor:
  17. """Forward.
  18. Applies the Log Compression function elementwise on tensor x.
  19. """
  20. return torch.log(torch.abs(x) + 1)
class SincConv(torch.nn.Module):
    """Sinc Convolution.

    This module performs a convolution using Sinc filters in time domain as kernel.
    Sinc filters function as band passes in spectral domain.
    The filtering is done as a convolution in time domain, and no transformation
    to spectral domain is necessary.

    This implementation of the Sinc convolution is heavily inspired
    by Ravanelli et al. https://github.com/mravanelli/SincNet,
    and adapted for the ESPnet toolkit.
    Combine Sinc convolutions with a log compression activation function, as in:
    https://arxiv.org/abs/2010.07597

    Notes:
        Currently, the same filters are applied to all input channels.
        The windowing function is applied on the kernel to obtain a smoother filter,
        and not on the input values, which is different to traditional ASR.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        dilation: int = 1,
        window_func: str = "hamming",
        scale_type: str = "mel",
        fs: Union[int, float] = 16000,
    ):
        """Initialize Sinc convolutions.

        Args:
            in_channels: Number of input channels.
            out_channels: Number of output channels.
            kernel_size: Sinc filter kernel size (needs to be an odd number).
            stride: See torch.nn.functional.conv1d.
            padding: See torch.nn.functional.conv1d.
            dilation: See torch.nn.functional.conv1d.
            window_func: Window function on the filter, one of ["hamming", "none"].
            scale_type: Initialization scale for the filter bank,
                one of ["mel", "bark"].
            fs: Sample rate of the input data (in Hz).

        Raises:
            NotImplementedError: If window_func or scale_type is unknown.
            ValueError: If kernel_size is even.
        """
        assert check_argument_types()
        super().__init__()
        window_funcs = {
            "none": self.none_window,
            "hamming": self.hamming_window,
        }
        if window_func not in window_funcs:
            raise NotImplementedError(
                f"Window function has to be one of {list(window_funcs.keys())}",
            )
        self.window_func = window_funcs[window_func]
        scale_choices = {
            "mel": MelScale,
            "bark": BarkScale,
        }
        if scale_type not in scale_choices:
            raise NotImplementedError(
                f"Scale has to be one of {list(scale_choices.keys())}",
            )
        self.scale = scale_choices[scale_type]
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.padding = padding
        self.dilation = dilation
        self.stride = stride
        self.fs = float(fs)
        # An odd kernel length is required so the filter has a single
        # well-defined center tap (see _create_filters).
        if self.kernel_size % 2 == 0:
            raise ValueError("SincConv: Kernel size must be odd.")
        # Learnable band-edge frequencies; set in init_filters().
        self.f = None
        N = self.kernel_size // 2
        # Precompute 2*pi*n for positions n = 1..N of the right kernel half,
        # and the window values for those positions; the left half is
        # mirrored later in _create_filters.
        self._x = 2 * math.pi * torch.linspace(1, N, N)
        self._window = self.window_func(torch.linspace(1, N, N))
        # init may get overwritten by E2E network,
        # but is still required to calculate output dim
        self.init_filters()

    @staticmethod
    def sinc(x: torch.Tensor) -> torch.Tensor:
        """Sinc function.

        A small epsilon is added to avoid division by zero at x == 0.
        Note: kept as a utility; _create_filters computes the band-pass
        kernel directly without calling this.
        """
        x2 = x + 1e-6
        return torch.sin(x2) / x2

    @staticmethod
    def none_window(x: torch.Tensor) -> torch.Tensor:
        """Identity-like windowing function (all-ones, no smoothing)."""
        return torch.ones_like(x)

    @staticmethod
    def hamming_window(x: torch.Tensor) -> torch.Tensor:
        """Hamming Windowing function.

        Args:
            x: Positions 1..N of the right half of the (odd, symmetric)
                kernel; L below is the full kernel length.
        """
        L = 2 * x.size(0) + 1
        x = x.flip(0)
        return 0.54 - 0.46 * torch.cos(2.0 * math.pi * x / L)

    def init_filters(self):
        """Initialize filters with filterbank values."""
        f = self.scale.bank(self.out_channels, self.fs)
        # Store band edges normalized by the sample rate.
        f = torch.div(f, self.fs)
        self.f = torch.nn.Parameter(f, requires_grad=True)

    def _create_filters(self, device: torch.device):
        """Calculate coefficients.

        This function (re-)calculates the filter convolutions coefficients.

        Args:
            device: Device the filters are created on (taken from the input).
        """
        # abs() keeps the band edges valid (non-negative, f_max >= f_min)
        # even after unconstrained gradient updates of self.f.
        f_mins = torch.abs(self.f[:, 0])
        f_maxs = torch.abs(self.f[:, 0]) + torch.abs(self.f[:, 1] - self.f[:, 0])
        self._x = self._x.to(device)
        self._window = self._window.to(device)
        f_mins_x = torch.matmul(f_mins.view(-1, 1), self._x.view(1, -1))
        f_maxs_x = torch.matmul(f_maxs.view(-1, 1), self._x.view(1, -1))
        # Band pass in time domain: difference of two low-pass sinc filters,
        # smoothed by the precomputed window.
        kernel = (torch.sin(f_maxs_x) - torch.sin(f_mins_x)) / (0.5 * self._x)
        kernel = kernel * self._window
        # The kernel is symmetric: mirror the right half and insert the
        # center tap (proportional to the bandwidth) in between.
        kernel_left = kernel.flip(1)
        kernel_center = (2 * f_maxs - 2 * f_mins).unsqueeze(1)
        filters = torch.cat([kernel_left, kernel_center, kernel], dim=1)
        # Shape (out_channels, 1, kernel_size) as expected by grouped conv1d.
        filters = filters.view(filters.size(0), 1, filters.size(1))
        self.sinc_filters = filters

    def forward(self, xs: torch.Tensor) -> torch.Tensor:
        """Sinc convolution forward function.

        Args:
            xs: Batch in form of torch.Tensor (B, C_in, D_in).

        Returns:
            xs: Batch in form of torch.Tensor (B, C_out, D_out).
        """
        # Filters are recomputed on every forward pass so that gradient
        # updates to self.f take effect.
        self._create_filters(xs.device)
        xs = torch.nn.functional.conv1d(
            xs,
            self.sinc_filters,
            padding=self.padding,
            stride=self.stride,
            dilation=self.dilation,
            groups=self.in_channels,
        )
        return xs

    def get_odim(self, idim: int) -> int:
        """Obtain the output dimension of the filter.

        Standard conv1d output-length formula for the configured
        padding, dilation, kernel size, and stride.
        """
        D_out = idim + 2 * self.padding - self.dilation * (self.kernel_size - 1) - 1
        D_out = (D_out // self.stride) + 1
        return D_out
  155. class MelScale:
  156. """Mel frequency scale."""
  157. @staticmethod
  158. def convert(f):
  159. """Convert Hz to mel."""
  160. return 1125.0 * torch.log(torch.div(f, 700.0) + 1.0)
  161. @staticmethod
  162. def invert(x):
  163. """Convert mel to Hz."""
  164. return 700.0 * (torch.exp(torch.div(x, 1125.0)) - 1.0)
  165. @classmethod
  166. def bank(cls, channels: int, fs: float) -> torch.Tensor:
  167. """Obtain initialization values for the mel scale.
  168. Args:
  169. channels: Number of channels.
  170. fs: Sample rate.
  171. Returns:
  172. torch.Tensor: Filter start frequencíes.
  173. torch.Tensor: Filter stop frequencies.
  174. """
  175. assert check_argument_types()
  176. # min and max bandpass edge frequencies
  177. min_frequency = torch.tensor(30.0)
  178. max_frequency = torch.tensor(fs * 0.5)
  179. frequencies = torch.linspace(
  180. cls.convert(min_frequency), cls.convert(max_frequency), channels + 2
  181. )
  182. frequencies = cls.invert(frequencies)
  183. f1, f2 = frequencies[:-2], frequencies[2:]
  184. return torch.stack([f1, f2], dim=1)
  185. class BarkScale:
  186. """Bark frequency scale.
  187. Has wider bandwidths at lower frequencies, see:
  188. Critical bandwidth: BARK
  189. Zwicker and Terhardt, 1980
  190. """
  191. @staticmethod
  192. def convert(f):
  193. """Convert Hz to Bark."""
  194. b = torch.div(f, 1000.0)
  195. b = torch.pow(b, 2.0) * 1.4
  196. b = torch.pow(b + 1.0, 0.69)
  197. return b * 75.0 + 25.0
  198. @staticmethod
  199. def invert(x):
  200. """Convert Bark to Hz."""
  201. f = torch.div(x - 25.0, 75.0)
  202. f = torch.pow(f, (1.0 / 0.69))
  203. f = torch.div(f - 1.0, 1.4)
  204. f = torch.pow(f, 0.5)
  205. return f * 1000.0
  206. @classmethod
  207. def bank(cls, channels: int, fs: float) -> torch.Tensor:
  208. """Obtain initialization values for the Bark scale.
  209. Args:
  210. channels: Number of channels.
  211. fs: Sample rate.
  212. Returns:
  213. torch.Tensor: Filter start frequencíes.
  214. torch.Tensor: Filter stop frequencíes.
  215. """
  216. assert check_argument_types()
  217. # min and max BARK center frequencies by approximation
  218. min_center_frequency = torch.tensor(70.0)
  219. max_center_frequency = torch.tensor(fs * 0.45)
  220. center_frequencies = torch.linspace(
  221. cls.convert(min_center_frequency),
  222. cls.convert(max_center_frequency),
  223. channels,
  224. )
  225. center_frequencies = cls.invert(center_frequencies)
  226. f1 = center_frequencies - torch.div(cls.convert(center_frequencies), 2)
  227. f2 = center_frequencies + torch.div(cls.convert(center_frequencies), 2)
  228. return torch.stack([f1, f2], dim=1)