# mfcca_encoder.py
import logging
import math
import pdb
from typing import Optional, Tuple

import torch
from torch import nn

from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.encoder.encoder_layer_mfcca import EncoderLayer
from funasr.models.transformer.attention import (
    MultiHeadedAttention,  # noqa: H301
    RelPositionMultiHeadedAttention,  # noqa: H301
    LegacyRelPositionMultiHeadedAttention,  # noqa: H301
)
from funasr.models.transformer.embedding import (
    PositionalEncoding,  # noqa: H301
    ScaledPositionalEncoding,  # noqa: H301
    RelPositionalEncoding,  # noqa: H301
    LegacyRelPositionalEncoding,  # noqa: H301
)
from funasr.models.transformer.layer_norm import LayerNorm
from funasr.models.transformer.positionwise_feed_forward import (
    PositionwiseFeedForward,  # noqa: H301
)
from funasr.models.transformer.utils.multi_layer_conv import Conv1dLinear
from funasr.models.transformer.utils.multi_layer_conv import MultiLayeredConv1d
from funasr.models.transformer.utils.nets_utils import get_activation
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.models.transformer.utils.repeat import repeat
from funasr.models.transformer.utils.subsampling import Conv2dSubsampling
from funasr.models.transformer.utils.subsampling import Conv2dSubsampling2
from funasr.models.transformer.utils.subsampling import Conv2dSubsampling6
from funasr.models.transformer.utils.subsampling import Conv2dSubsampling8
from funasr.models.transformer.utils.subsampling import TooShortUttError
from funasr.models.transformer.utils.subsampling import check_short_utt
  36. class ConvolutionModule(nn.Module):
  37. """ConvolutionModule in Conformer model.
  38. Args:
  39. channels (int): The number of channels of conv layers.
  40. kernel_size (int): Kernerl size of conv layers.
  41. """
  42. def __init__(self, channels, kernel_size, activation=nn.ReLU(), bias=True):
  43. """Construct an ConvolutionModule object."""
  44. super(ConvolutionModule, self).__init__()
  45. # kernerl_size should be a odd number for 'SAME' padding
  46. assert (kernel_size - 1) % 2 == 0
  47. self.pointwise_conv1 = nn.Conv1d(
  48. channels,
  49. 2 * channels,
  50. kernel_size=1,
  51. stride=1,
  52. padding=0,
  53. bias=bias,
  54. )
  55. self.depthwise_conv = nn.Conv1d(
  56. channels,
  57. channels,
  58. kernel_size,
  59. stride=1,
  60. padding=(kernel_size - 1) // 2,
  61. groups=channels,
  62. bias=bias,
  63. )
  64. self.norm = nn.BatchNorm1d(channels)
  65. self.pointwise_conv2 = nn.Conv1d(
  66. channels,
  67. channels,
  68. kernel_size=1,
  69. stride=1,
  70. padding=0,
  71. bias=bias,
  72. )
  73. self.activation = activation
  74. def forward(self, x):
  75. """Compute convolution module.
  76. Args:
  77. x (torch.Tensor): Input tensor (#batch, time, channels).
  78. Returns:
  79. torch.Tensor: Output tensor (#batch, time, channels).
  80. """
  81. # exchange the temporal dimension and the feature dimension
  82. x = x.transpose(1, 2)
  83. # GLU mechanism
  84. x = self.pointwise_conv1(x) # (batch, 2*channel, dim)
  85. x = nn.functional.glu(x, dim=1) # (batch, channel, dim)
  86. # 1D Depthwise Conv
  87. x = self.depthwise_conv(x)
  88. x = self.activation(self.norm(x))
  89. x = self.pointwise_conv2(x)
  90. return x.transpose(1, 2)
  91. class MFCCAEncoder(AbsEncoder):
  92. """Conformer encoder module.
  93. Args:
  94. input_size (int): Input dimension.
  95. output_size (int): Dimention of attention.
  96. attention_heads (int): The number of heads of multi head attention.
  97. linear_units (int): The number of units of position-wise feed forward.
  98. num_blocks (int): The number of decoder blocks.
  99. dropout_rate (float): Dropout rate.
  100. attention_dropout_rate (float): Dropout rate in attention.
  101. positional_dropout_rate (float): Dropout rate after adding positional encoding.
  102. input_layer (Union[str, torch.nn.Module]): Input layer type.
  103. normalize_before (bool): Whether to use layer_norm before the first block.
  104. concat_after (bool): Whether to concat attention layer's input and output.
  105. If True, additional linear will be applied.
  106. i.e. x -> x + linear(concat(x, att(x)))
  107. If False, no additional linear will be applied. i.e. x -> x + att(x)
  108. positionwise_layer_type (str): "linear", "conv1d", or "conv1d-linear".
  109. positionwise_conv_kernel_size (int): Kernel size of positionwise conv1d layer.
  110. rel_pos_type (str): Whether to use the latest relative positional encoding or
  111. the legacy one. The legacy relative positional encoding will be deprecated
  112. in the future. More Details can be found in
  113. https://github.com/espnet/espnet/pull/2816.
  114. encoder_pos_enc_layer_type (str): Encoder positional encoding layer type.
  115. encoder_attn_layer_type (str): Encoder attention layer type.
  116. activation_type (str): Encoder activation function type.
  117. macaron_style (bool): Whether to use macaron style for positionwise layer.
  118. use_cnn_module (bool): Whether to use convolution module.
  119. zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
  120. cnn_module_kernel (int): Kernerl size of convolution module.
  121. padding_idx (int): Padding idx for input_layer=embed.
  122. """
  123. def __init__(
  124. self,
  125. input_size: int,
  126. output_size: int = 256,
  127. attention_heads: int = 4,
  128. linear_units: int = 2048,
  129. num_blocks: int = 6,
  130. dropout_rate: float = 0.1,
  131. positional_dropout_rate: float = 0.1,
  132. attention_dropout_rate: float = 0.0,
  133. input_layer: str = "conv2d",
  134. normalize_before: bool = True,
  135. concat_after: bool = False,
  136. positionwise_layer_type: str = "linear",
  137. positionwise_conv_kernel_size: int = 3,
  138. macaron_style: bool = False,
  139. rel_pos_type: str = "legacy",
  140. pos_enc_layer_type: str = "rel_pos",
  141. selfattention_layer_type: str = "rel_selfattn",
  142. activation_type: str = "swish",
  143. use_cnn_module: bool = True,
  144. zero_triu: bool = False,
  145. cnn_module_kernel: int = 31,
  146. padding_idx: int = -1,
  147. ):
  148. super().__init__()
  149. self._output_size = output_size
  150. if rel_pos_type == "legacy":
  151. if pos_enc_layer_type == "rel_pos":
  152. pos_enc_layer_type = "legacy_rel_pos"
  153. if selfattention_layer_type == "rel_selfattn":
  154. selfattention_layer_type = "legacy_rel_selfattn"
  155. elif rel_pos_type == "latest":
  156. assert selfattention_layer_type != "legacy_rel_selfattn"
  157. assert pos_enc_layer_type != "legacy_rel_pos"
  158. else:
  159. raise ValueError("unknown rel_pos_type: " + rel_pos_type)
  160. activation = get_activation(activation_type)
  161. if pos_enc_layer_type == "abs_pos":
  162. pos_enc_class = PositionalEncoding
  163. elif pos_enc_layer_type == "scaled_abs_pos":
  164. pos_enc_class = ScaledPositionalEncoding
  165. elif pos_enc_layer_type == "rel_pos":
  166. assert selfattention_layer_type == "rel_selfattn"
  167. pos_enc_class = RelPositionalEncoding
  168. elif pos_enc_layer_type == "legacy_rel_pos":
  169. assert selfattention_layer_type == "legacy_rel_selfattn"
  170. pos_enc_class = LegacyRelPositionalEncoding
  171. logging.warning(
  172. "Using legacy_rel_pos and it will be deprecated in the future."
  173. )
  174. else:
  175. raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)
  176. if input_layer == "linear":
  177. self.embed = torch.nn.Sequential(
  178. torch.nn.Linear(input_size, output_size),
  179. torch.nn.LayerNorm(output_size),
  180. torch.nn.Dropout(dropout_rate),
  181. pos_enc_class(output_size, positional_dropout_rate),
  182. )
  183. elif input_layer == "conv2d":
  184. self.embed = Conv2dSubsampling(
  185. input_size,
  186. output_size,
  187. dropout_rate,
  188. pos_enc_class(output_size, positional_dropout_rate),
  189. )
  190. elif input_layer == "conv2d6":
  191. self.embed = Conv2dSubsampling6(
  192. input_size,
  193. output_size,
  194. dropout_rate,
  195. pos_enc_class(output_size, positional_dropout_rate),
  196. )
  197. elif input_layer == "conv2d8":
  198. self.embed = Conv2dSubsampling8(
  199. input_size,
  200. output_size,
  201. dropout_rate,
  202. pos_enc_class(output_size, positional_dropout_rate),
  203. )
  204. elif input_layer == "embed":
  205. self.embed = torch.nn.Sequential(
  206. torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
  207. pos_enc_class(output_size, positional_dropout_rate),
  208. )
  209. elif isinstance(input_layer, torch.nn.Module):
  210. self.embed = torch.nn.Sequential(
  211. input_layer,
  212. pos_enc_class(output_size, positional_dropout_rate),
  213. )
  214. elif input_layer is None:
  215. self.embed = torch.nn.Sequential(
  216. pos_enc_class(output_size, positional_dropout_rate)
  217. )
  218. else:
  219. raise ValueError("unknown input_layer: " + input_layer)
  220. self.normalize_before = normalize_before
  221. if positionwise_layer_type == "linear":
  222. positionwise_layer = PositionwiseFeedForward
  223. positionwise_layer_args = (
  224. output_size,
  225. linear_units,
  226. dropout_rate,
  227. activation,
  228. )
  229. elif positionwise_layer_type == "conv1d":
  230. positionwise_layer = MultiLayeredConv1d
  231. positionwise_layer_args = (
  232. output_size,
  233. linear_units,
  234. positionwise_conv_kernel_size,
  235. dropout_rate,
  236. )
  237. elif positionwise_layer_type == "conv1d-linear":
  238. positionwise_layer = Conv1dLinear
  239. positionwise_layer_args = (
  240. output_size,
  241. linear_units,
  242. positionwise_conv_kernel_size,
  243. dropout_rate,
  244. )
  245. else:
  246. raise NotImplementedError("Support only linear or conv1d.")
  247. if selfattention_layer_type == "selfattn":
  248. encoder_selfattn_layer = MultiHeadedAttention
  249. encoder_selfattn_layer_args = (
  250. attention_heads,
  251. output_size,
  252. attention_dropout_rate,
  253. )
  254. elif selfattention_layer_type == "legacy_rel_selfattn":
  255. assert pos_enc_layer_type == "legacy_rel_pos"
  256. encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention
  257. encoder_selfattn_layer_args = (
  258. attention_heads,
  259. output_size,
  260. attention_dropout_rate,
  261. )
  262. logging.warning(
  263. "Using legacy_rel_selfattn and it will be deprecated in the future."
  264. )
  265. elif selfattention_layer_type == "rel_selfattn":
  266. assert pos_enc_layer_type == "rel_pos"
  267. encoder_selfattn_layer = RelPositionMultiHeadedAttention
  268. encoder_selfattn_layer_args = (
  269. attention_heads,
  270. output_size,
  271. attention_dropout_rate,
  272. zero_triu,
  273. )
  274. else:
  275. raise ValueError("unknown encoder_attn_layer: " + selfattention_layer_type)
  276. convolution_layer = ConvolutionModule
  277. convolution_layer_args = (output_size, cnn_module_kernel, activation)
  278. encoder_selfattn_layer_raw = MultiHeadedAttention
  279. encoder_selfattn_layer_args_raw = (
  280. attention_heads,
  281. output_size,
  282. attention_dropout_rate,
  283. )
  284. self.encoders = repeat(
  285. num_blocks,
  286. lambda lnum: EncoderLayer(
  287. output_size,
  288. encoder_selfattn_layer_raw(*encoder_selfattn_layer_args_raw),
  289. encoder_selfattn_layer(*encoder_selfattn_layer_args),
  290. positionwise_layer(*positionwise_layer_args),
  291. positionwise_layer(*positionwise_layer_args) if macaron_style else None,
  292. convolution_layer(*convolution_layer_args) if use_cnn_module else None,
  293. dropout_rate,
  294. normalize_before,
  295. concat_after,
  296. ),
  297. )
  298. if self.normalize_before:
  299. self.after_norm = LayerNorm(output_size)
  300. self.conv1 = torch.nn.Conv2d(8, 16, [5, 7], stride=[1, 1], padding=(2, 3))
  301. self.conv2 = torch.nn.Conv2d(16, 32, [5, 7], stride=[1, 1], padding=(2, 3))
  302. self.conv3 = torch.nn.Conv2d(32, 16, [5, 7], stride=[1, 1], padding=(2, 3))
  303. self.conv4 = torch.nn.Conv2d(16, 1, [5, 7], stride=[1, 1], padding=(2, 3))
  304. def output_size(self) -> int:
  305. return self._output_size
  306. def forward(
  307. self,
  308. xs_pad: torch.Tensor,
  309. ilens: torch.Tensor,
  310. channel_size: torch.Tensor,
  311. prev_states: torch.Tensor = None,
  312. ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
  313. """Calculate forward propagation.
  314. Args:
  315. xs_pad (torch.Tensor): Input tensor (#batch, L, input_size).
  316. ilens (torch.Tensor): Input length (#batch).
  317. prev_states (torch.Tensor): Not to be used now.
  318. Returns:
  319. torch.Tensor: Output tensor (#batch, L, output_size).
  320. torch.Tensor: Output length (#batch).
  321. torch.Tensor: Not to be used now.
  322. """
  323. masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
  324. if (
  325. isinstance(self.embed, Conv2dSubsampling)
  326. or isinstance(self.embed, Conv2dSubsampling6)
  327. or isinstance(self.embed, Conv2dSubsampling8)
  328. ):
  329. short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
  330. if short_status:
  331. raise TooShortUttError(
  332. f"has {xs_pad.size(1)} frames and is too short for subsampling "
  333. + f"(it needs more than {limit_size} frames), return empty results",
  334. xs_pad.size(1),
  335. limit_size,
  336. )
  337. xs_pad, masks = self.embed(xs_pad, masks)
  338. else:
  339. xs_pad = self.embed(xs_pad)
  340. xs_pad, masks, channel_size = self.encoders(xs_pad, masks, channel_size)
  341. if isinstance(xs_pad, tuple):
  342. xs_pad = xs_pad[0]
  343. t_leng = xs_pad.size(1)
  344. d_dim = xs_pad.size(2)
  345. xs_pad = xs_pad.reshape(-1, channel_size, t_leng, d_dim)
  346. # pdb.set_trace()
  347. if (channel_size < 8):
  348. repeat_num = math.ceil(8 / channel_size)
  349. xs_pad = xs_pad.repeat(1, repeat_num, 1, 1)[:, 0:8, :, :]
  350. xs_pad = self.conv1(xs_pad)
  351. xs_pad = self.conv2(xs_pad)
  352. xs_pad = self.conv3(xs_pad)
  353. xs_pad = self.conv4(xs_pad)
  354. xs_pad = xs_pad.squeeze().reshape(-1, t_leng, d_dim)
  355. mask_tmp = masks.size(1)
  356. masks = masks.reshape(-1, channel_size, mask_tmp, t_leng)[:, 0, :, :]
  357. if self.normalize_before:
  358. xs_pad = self.after_norm(xs_pad)
  359. olens = masks.squeeze(1).sum(1)
  360. return xs_pad, olens, None
  361. def forward_hidden(
  362. self,
  363. xs_pad: torch.Tensor,
  364. ilens: torch.Tensor,
  365. prev_states: torch.Tensor = None,
  366. ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
  367. """Calculate forward propagation.
  368. Args:
  369. xs_pad (torch.Tensor): Input tensor (#batch, L, input_size).
  370. ilens (torch.Tensor): Input length (#batch).
  371. prev_states (torch.Tensor): Not to be used now.
  372. Returns:
  373. torch.Tensor: Output tensor (#batch, L, output_size).
  374. torch.Tensor: Output length (#batch).
  375. torch.Tensor: Not to be used now.
  376. """
  377. masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
  378. if (
  379. isinstance(self.embed, Conv2dSubsampling)
  380. or isinstance(self.embed, Conv2dSubsampling6)
  381. or isinstance(self.embed, Conv2dSubsampling8)
  382. ):
  383. short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
  384. if short_status:
  385. raise TooShortUttError(
  386. f"has {xs_pad.size(1)} frames and is too short for subsampling "
  387. + f"(it needs more than {limit_size} frames), return empty results",
  388. xs_pad.size(1),
  389. limit_size,
  390. )
  391. xs_pad, masks = self.embed(xs_pad, masks)
  392. else:
  393. xs_pad = self.embed(xs_pad)
  394. num_layer = len(self.encoders)
  395. for idx, encoder in enumerate(self.encoders):
  396. xs_pad, masks = encoder(xs_pad, masks)
  397. if idx == num_layer // 2 - 1:
  398. hidden_feature = xs_pad
  399. if isinstance(xs_pad, tuple):
  400. xs_pad = xs_pad[0]
  401. hidden_feature = hidden_feature[0]
  402. if self.normalize_before:
  403. xs_pad = self.after_norm(xs_pad)
  404. self.hidden_feature = self.after_norm(hidden_feature)
  405. olens = masks.squeeze(1).sum(1)
  406. return xs_pad, olens, None