# sinc.py
  1. #!/usr/bin/env python3
  2. # 2020, Technische Universität München; Ludwig Kürzinger
  3. # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
  4. """Sinc convolutions for raw audio input."""
  5. from collections import OrderedDict
  6. from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
  7. from funasr.layers.sinc_conv import LogCompression
  8. from funasr.layers.sinc_conv import SincConv
  9. import humanfriendly
  10. import torch
  11. from typeguard import check_argument_types
  12. from typing import Optional
  13. from typing import Tuple
  14. from typing import Union
  15. class LightweightSincConvs(AbsPreEncoder):
  16. """Lightweight Sinc Convolutions.
  17. Instead of using precomputed features, end-to-end speech recognition
  18. can also be done directly from raw audio using sinc convolutions, as
  19. described in "Lightweight End-to-End Speech Recognition from Raw Audio
  20. Data Using Sinc-Convolutions" by Kürzinger et al.
  21. https://arxiv.org/abs/2010.07597
  22. To use Sinc convolutions in your model instead of the default f-bank
  23. frontend, set this module as your pre-encoder with `preencoder: sinc`
  24. and use the input of the sliding window frontend with
  25. `frontend: sliding_window` in your yaml configuration file.
  26. So that the process flow is:
  27. Frontend (SlidingWindow) -> SpecAug -> Normalization ->
  28. Pre-encoder (LightweightSincConvs) -> Encoder -> Decoder
  29. Note that this method also performs data augmentation in time domain
  30. (vs. in spectral domain in the default frontend).
  31. Use `plot_sinc_filters.py` to visualize the learned Sinc filters.
  32. """
  33. def __init__(
  34. self,
  35. fs: Union[int, str, float] = 16000,
  36. in_channels: int = 1,
  37. out_channels: int = 256,
  38. activation_type: str = "leakyrelu",
  39. dropout_type: str = "dropout",
  40. windowing_type: str = "hamming",
  41. scale_type: str = "mel",
  42. ):
  43. """Initialize the module.
  44. Args:
  45. fs: Sample rate.
  46. in_channels: Number of input channels.
  47. out_channels: Number of output channels (for each input channel).
  48. activation_type: Choice of activation function.
  49. dropout_type: Choice of dropout function.
  50. windowing_type: Choice of windowing function.
  51. scale_type: Choice of filter-bank initialization scale.
  52. """
  53. assert check_argument_types()
  54. super().__init__()
  55. if isinstance(fs, str):
  56. fs = humanfriendly.parse_size(fs)
  57. self.fs = fs
  58. self.in_channels = in_channels
  59. self.out_channels = out_channels
  60. self.activation_type = activation_type
  61. self.dropout_type = dropout_type
  62. self.windowing_type = windowing_type
  63. self.scale_type = scale_type
  64. self.choices_dropout = {
  65. "dropout": torch.nn.Dropout,
  66. "spatial": SpatialDropout,
  67. "dropout2d": torch.nn.Dropout2d,
  68. }
  69. if dropout_type not in self.choices_dropout:
  70. raise NotImplementedError(
  71. f"Dropout type has to be one of "
  72. f"{list(self.choices_dropout.keys())}",
  73. )
  74. self.choices_activation = {
  75. "leakyrelu": torch.nn.LeakyReLU,
  76. "relu": torch.nn.ReLU,
  77. }
  78. if activation_type not in self.choices_activation:
  79. raise NotImplementedError(
  80. f"Activation type has to be one of "
  81. f"{list(self.choices_activation.keys())}",
  82. )
  83. # initialization
  84. self._create_sinc_convs()
  85. # Sinc filters require custom initialization
  86. self.espnet_initialization_fn()
  87. def _create_sinc_convs(self):
  88. blocks = OrderedDict()
  89. # SincConvBlock
  90. out_channels = 128
  91. self.filters = SincConv(
  92. self.in_channels,
  93. out_channels,
  94. kernel_size=101,
  95. stride=1,
  96. fs=self.fs,
  97. window_func=self.windowing_type,
  98. scale_type=self.scale_type,
  99. )
  100. block = OrderedDict(
  101. [
  102. ("Filters", self.filters),
  103. ("LogCompression", LogCompression()),
  104. ("BatchNorm", torch.nn.BatchNorm1d(out_channels, affine=True)),
  105. ("AvgPool", torch.nn.AvgPool1d(2)),
  106. ]
  107. )
  108. blocks["SincConvBlock"] = torch.nn.Sequential(block)
  109. in_channels = out_channels
  110. # First convolutional block, connects the sinc output to the front-end "body"
  111. out_channels = 128
  112. blocks["DConvBlock1"] = self.gen_lsc_block(
  113. in_channels,
  114. out_channels,
  115. depthwise_kernel_size=25,
  116. depthwise_stride=2,
  117. pointwise_groups=0,
  118. avgpool=True,
  119. dropout_probability=0.1,
  120. )
  121. in_channels = out_channels
  122. # Second convolutional block, multiple convolutional layers
  123. out_channels = self.out_channels
  124. for layer in [2, 3, 4]:
  125. blocks[f"DConvBlock{layer}"] = self.gen_lsc_block(
  126. in_channels, out_channels, depthwise_kernel_size=9, depthwise_stride=1
  127. )
  128. in_channels = out_channels
  129. # Third Convolutional block, acts as coupling to encoder
  130. out_channels = self.out_channels
  131. blocks["DConvBlock5"] = self.gen_lsc_block(
  132. in_channels,
  133. out_channels,
  134. depthwise_kernel_size=7,
  135. depthwise_stride=1,
  136. pointwise_groups=0,
  137. )
  138. self.blocks = torch.nn.Sequential(blocks)
  139. def gen_lsc_block(
  140. self,
  141. in_channels: int,
  142. out_channels: int,
  143. depthwise_kernel_size: int = 9,
  144. depthwise_stride: int = 1,
  145. depthwise_groups=None,
  146. pointwise_groups=0,
  147. dropout_probability: float = 0.15,
  148. avgpool=False,
  149. ):
  150. """Generate a convolutional block for Lightweight Sinc convolutions.
  151. Each block consists of either a depthwise or a depthwise-separable
  152. convolutions together with dropout, (batch-)normalization layer, and
  153. an optional average-pooling layer.
  154. Args:
  155. in_channels: Number of input channels.
  156. out_channels: Number of output channels.
  157. depthwise_kernel_size: Kernel size of the depthwise convolution.
  158. depthwise_stride: Stride of the depthwise convolution.
  159. depthwise_groups: Number of groups of the depthwise convolution.
  160. pointwise_groups: Number of groups of the pointwise convolution.
  161. dropout_probability: Dropout probability in the block.
  162. avgpool: If True, an AvgPool layer is inserted.
  163. Returns:
  164. torch.nn.Sequential: Neural network building block.
  165. """
  166. block = OrderedDict()
  167. if not depthwise_groups:
  168. # GCD(in_channels, out_channels) to prevent size mismatches
  169. depthwise_groups, r = in_channels, out_channels
  170. while r != 0:
  171. depthwise_groups, r = depthwise_groups, depthwise_groups % r
  172. block["depthwise"] = torch.nn.Conv1d(
  173. in_channels,
  174. out_channels,
  175. depthwise_kernel_size,
  176. depthwise_stride,
  177. groups=depthwise_groups,
  178. )
  179. if pointwise_groups:
  180. block["pointwise"] = torch.nn.Conv1d(
  181. out_channels, out_channels, 1, 1, groups=pointwise_groups
  182. )
  183. block["activation"] = self.choices_activation[self.activation_type]()
  184. block["batchnorm"] = torch.nn.BatchNorm1d(out_channels, affine=True)
  185. if avgpool:
  186. block["avgpool"] = torch.nn.AvgPool1d(2)
  187. block["dropout"] = self.choices_dropout[self.dropout_type](dropout_probability)
  188. return torch.nn.Sequential(block)
  189. def espnet_initialization_fn(self):
  190. """Initialize sinc filters with filterbank values."""
  191. self.filters.init_filters()
  192. for block in self.blocks:
  193. for layer in block:
  194. if type(layer) == torch.nn.BatchNorm1d and layer.affine:
  195. layer.weight.data[:] = 1.0
  196. layer.bias.data[:] = 0.0
  197. def forward(
  198. self, input: torch.Tensor, input_lengths: torch.Tensor
  199. ) -> Tuple[torch.Tensor, torch.Tensor]:
  200. """Apply Lightweight Sinc Convolutions.
  201. The input shall be formatted as (B, T, C_in, D_in)
  202. with B as batch size, T as time dimension, C_in as channels,
  203. and D_in as feature dimension.
  204. The output will then be (B, T, C_out*D_out)
  205. with C_out and D_out as output dimensions.
  206. The current module structure only handles D_in=400, so that D_out=1.
  207. Remark for the multichannel case: C_out is the number of out_channels
  208. given at initialization multiplied with C_in.
  209. """
  210. # Transform input data:
  211. # (B, T, C_in, D_in) -> (B*T, C_in, D_in)
  212. B, T, C_in, D_in = input.size()
  213. input_frames = input.view(B * T, C_in, D_in)
  214. output_frames = self.blocks.forward(input_frames)
  215. # ---TRANSFORM: (B*T, C_out, D_out) -> (B, T, C_out*D_out)
  216. _, C_out, D_out = output_frames.size()
  217. output_frames = output_frames.view(B, T, C_out * D_out)
  218. return output_frames, input_lengths # no state in this layer
  219. def output_size(self) -> int:
  220. """Get the output size."""
  221. return self.out_channels * self.in_channels
  222. class SpatialDropout(torch.nn.Module):
  223. """Spatial dropout module.
  224. Apply dropout to full channels on tensors of input (B, C, D)
  225. """
  226. def __init__(
  227. self,
  228. dropout_probability: float = 0.15,
  229. shape: Optional[Union[tuple, list]] = None,
  230. ):
  231. """Initialize.
  232. Args:
  233. dropout_probability: Dropout probability.
  234. shape (tuple, list): Shape of input tensors.
  235. """
  236. assert check_argument_types()
  237. super().__init__()
  238. if shape is None:
  239. shape = (0, 2, 1)
  240. self.dropout = torch.nn.Dropout2d(dropout_probability)
  241. self.shape = (shape,)
  242. def forward(self, x: torch.Tensor) -> torch.Tensor:
  243. """Forward of spatial dropout module."""
  244. y = x.permute(*self.shape)
  245. y = self.dropout(y)
  246. return y.permute(*self.shape)