  1. #!/usr/bin/env python3
  2. # 2020, Technische Universität München; Ludwig Kürzinger
  3. # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
  4. """Sinc convolutions for raw audio input."""
  5. from collections import OrderedDict
  6. from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
  7. from funasr.layers.sinc_conv import LogCompression
  8. from funasr.layers.sinc_conv import SincConv
  9. import humanfriendly
  10. import torch
  11. from typing import Optional
  12. from typing import Tuple
  13. from typing import Union
  14. class LightweightSincConvs(AbsPreEncoder):
  15. """Lightweight Sinc Convolutions.
  16. Instead of using precomputed features, end-to-end speech recognition
  17. can also be done directly from raw audio using sinc convolutions, as
  18. described in "Lightweight End-to-End Speech Recognition from Raw Audio
  19. Data Using Sinc-Convolutions" by Kürzinger et al.
  20. https://arxiv.org/abs/2010.07597
  21. To use Sinc convolutions in your model instead of the default f-bank
  22. frontend, set this module as your pre-encoder with `preencoder: sinc`
  23. and use the input of the sliding window frontend with
  24. `frontend: sliding_window` in your yaml configuration file.
  25. So that the process flow is:
  26. Frontend (SlidingWindow) -> SpecAug -> Normalization ->
  27. Pre-encoder (LightweightSincConvs) -> Encoder -> Decoder
  28. Note that this method also performs data augmentation in time domain
  29. (vs. in spectral domain in the default frontend).
  30. Use `plot_sinc_filters.py` to visualize the learned Sinc filters.
  31. """
  32. def __init__(
  33. self,
  34. fs: Union[int, str, float] = 16000,
  35. in_channels: int = 1,
  36. out_channels: int = 256,
  37. activation_type: str = "leakyrelu",
  38. dropout_type: str = "dropout",
  39. windowing_type: str = "hamming",
  40. scale_type: str = "mel",
  41. ):
  42. """Initialize the module.
  43. Args:
  44. fs: Sample rate.
  45. in_channels: Number of input channels.
  46. out_channels: Number of output channels (for each input channel).
  47. activation_type: Choice of activation function.
  48. dropout_type: Choice of dropout function.
  49. windowing_type: Choice of windowing function.
  50. scale_type: Choice of filter-bank initialization scale.
  51. """
  52. super().__init__()
  53. if isinstance(fs, str):
  54. fs = humanfriendly.parse_size(fs)
  55. self.fs = fs
  56. self.in_channels = in_channels
  57. self.out_channels = out_channels
  58. self.activation_type = activation_type
  59. self.dropout_type = dropout_type
  60. self.windowing_type = windowing_type
  61. self.scale_type = scale_type
  62. self.choices_dropout = {
  63. "dropout": torch.nn.Dropout,
  64. "spatial": SpatialDropout,
  65. "dropout2d": torch.nn.Dropout2d,
  66. }
  67. if dropout_type not in self.choices_dropout:
  68. raise NotImplementedError(
  69. f"Dropout type has to be one of "
  70. f"{list(self.choices_dropout.keys())}",
  71. )
  72. self.choices_activation = {
  73. "leakyrelu": torch.nn.LeakyReLU,
  74. "relu": torch.nn.ReLU,
  75. }
  76. if activation_type not in self.choices_activation:
  77. raise NotImplementedError(
  78. f"Activation type has to be one of "
  79. f"{list(self.choices_activation.keys())}",
  80. )
  81. # initialization
  82. self._create_sinc_convs()
  83. # Sinc filters require custom initialization
  84. self.espnet_initialization_fn()
  85. def _create_sinc_convs(self):
  86. blocks = OrderedDict()
  87. # SincConvBlock
  88. out_channels = 128
  89. self.filters = SincConv(
  90. self.in_channels,
  91. out_channels,
  92. kernel_size=101,
  93. stride=1,
  94. fs=self.fs,
  95. window_func=self.windowing_type,
  96. scale_type=self.scale_type,
  97. )
  98. block = OrderedDict(
  99. [
  100. ("Filters", self.filters),
  101. ("LogCompression", LogCompression()),
  102. ("BatchNorm", torch.nn.BatchNorm1d(out_channels, affine=True)),
  103. ("AvgPool", torch.nn.AvgPool1d(2)),
  104. ]
  105. )
  106. blocks["SincConvBlock"] = torch.nn.Sequential(block)
  107. in_channels = out_channels
  108. # First convolutional block, connects the sinc output to the front-end "body"
  109. out_channels = 128
  110. blocks["DConvBlock1"] = self.gen_lsc_block(
  111. in_channels,
  112. out_channels,
  113. depthwise_kernel_size=25,
  114. depthwise_stride=2,
  115. pointwise_groups=0,
  116. avgpool=True,
  117. dropout_probability=0.1,
  118. )
  119. in_channels = out_channels
  120. # Second convolutional block, multiple convolutional layers
  121. out_channels = self.out_channels
  122. for layer in [2, 3, 4]:
  123. blocks[f"DConvBlock{layer}"] = self.gen_lsc_block(
  124. in_channels, out_channels, depthwise_kernel_size=9, depthwise_stride=1
  125. )
  126. in_channels = out_channels
  127. # Third Convolutional block, acts as coupling to encoder
  128. out_channels = self.out_channels
  129. blocks["DConvBlock5"] = self.gen_lsc_block(
  130. in_channels,
  131. out_channels,
  132. depthwise_kernel_size=7,
  133. depthwise_stride=1,
  134. pointwise_groups=0,
  135. )
  136. self.blocks = torch.nn.Sequential(blocks)
  137. def gen_lsc_block(
  138. self,
  139. in_channels: int,
  140. out_channels: int,
  141. depthwise_kernel_size: int = 9,
  142. depthwise_stride: int = 1,
  143. depthwise_groups=None,
  144. pointwise_groups=0,
  145. dropout_probability: float = 0.15,
  146. avgpool=False,
  147. ):
  148. """Generate a convolutional block for Lightweight Sinc convolutions.
  149. Each block consists of either a depthwise or a depthwise-separable
  150. convolutions together with dropout, (batch-)normalization layer, and
  151. an optional average-pooling layer.
  152. Args:
  153. in_channels: Number of input channels.
  154. out_channels: Number of output channels.
  155. depthwise_kernel_size: Kernel size of the depthwise convolution.
  156. depthwise_stride: Stride of the depthwise convolution.
  157. depthwise_groups: Number of groups of the depthwise convolution.
  158. pointwise_groups: Number of groups of the pointwise convolution.
  159. dropout_probability: Dropout probability in the block.
  160. avgpool: If True, an AvgPool layer is inserted.
  161. Returns:
  162. torch.nn.Sequential: Neural network building block.
  163. """
  164. block = OrderedDict()
  165. if not depthwise_groups:
  166. # GCD(in_channels, out_channels) to prevent size mismatches
  167. depthwise_groups, r = in_channels, out_channels
  168. while r != 0:
  169. depthwise_groups, r = depthwise_groups, depthwise_groups % r
  170. block["depthwise"] = torch.nn.Conv1d(
  171. in_channels,
  172. out_channels,
  173. depthwise_kernel_size,
  174. depthwise_stride,
  175. groups=depthwise_groups,
  176. )
  177. if pointwise_groups:
  178. block["pointwise"] = torch.nn.Conv1d(
  179. out_channels, out_channels, 1, 1, groups=pointwise_groups
  180. )
  181. block["activation"] = self.choices_activation[self.activation_type]()
  182. block["batchnorm"] = torch.nn.BatchNorm1d(out_channels, affine=True)
  183. if avgpool:
  184. block["avgpool"] = torch.nn.AvgPool1d(2)
  185. block["dropout"] = self.choices_dropout[self.dropout_type](dropout_probability)
  186. return torch.nn.Sequential(block)
  187. def espnet_initialization_fn(self):
  188. """Initialize sinc filters with filterbank values."""
  189. self.filters.init_filters()
  190. for block in self.blocks:
  191. for layer in block:
  192. if type(layer) == torch.nn.BatchNorm1d and layer.affine:
  193. layer.weight.data[:] = 1.0
  194. layer.bias.data[:] = 0.0
  195. def forward(
  196. self, input: torch.Tensor, input_lengths: torch.Tensor
  197. ) -> Tuple[torch.Tensor, torch.Tensor]:
  198. """Apply Lightweight Sinc Convolutions.
  199. The input shall be formatted as (B, T, C_in, D_in)
  200. with B as batch size, T as time dimension, C_in as channels,
  201. and D_in as feature dimension.
  202. The output will then be (B, T, C_out*D_out)
  203. with C_out and D_out as output dimensions.
  204. The current module structure only handles D_in=400, so that D_out=1.
  205. Remark for the multichannel case: C_out is the number of out_channels
  206. given at initialization multiplied with C_in.
  207. """
  208. # Transform input data:
  209. # (B, T, C_in, D_in) -> (B*T, C_in, D_in)
  210. B, T, C_in, D_in = input.size()
  211. input_frames = input.view(B * T, C_in, D_in)
  212. output_frames = self.blocks.forward(input_frames)
  213. # ---TRANSFORM: (B*T, C_out, D_out) -> (B, T, C_out*D_out)
  214. _, C_out, D_out = output_frames.size()
  215. output_frames = output_frames.view(B, T, C_out * D_out)
  216. return output_frames, input_lengths # no state in this layer
  217. def output_size(self) -> int:
  218. """Get the output size."""
  219. return self.out_channels * self.in_channels
  220. class SpatialDropout(torch.nn.Module):
  221. """Spatial dropout module.
  222. Apply dropout to full channels on tensors of input (B, C, D)
  223. """
  224. def __init__(
  225. self,
  226. dropout_probability: float = 0.15,
  227. shape: Optional[Union[tuple, list]] = None,
  228. ):
  229. """Initialize.
  230. Args:
  231. dropout_probability: Dropout probability.
  232. shape (tuple, list): Shape of input tensors.
  233. """
  234. super().__init__()
  235. if shape is None:
  236. shape = (0, 2, 1)
  237. self.dropout = torch.nn.Dropout2d(dropout_probability)
  238. self.shape = (shape,)
  239. def forward(self, x: torch.Tensor) -> torch.Tensor:
  240. """Forward of spatial dropout module."""
  241. y = x.permute(*self.shape)
  242. y = self.dropout(y)
  243. return y.permute(*self.shape)