# mfcca_encoder.py
  1. from typing import Optional
  2. from typing import Tuple
  3. import logging
  4. import torch
  5. from torch import nn
  6. from typeguard import check_argument_types
  7. from funasr.models.encoder.encoder_layer_mfcca import EncoderLayer
  8. from funasr.modules.nets_utils import get_activation
  9. from funasr.modules.nets_utils import make_pad_mask
  10. from funasr.modules.attention import (
  11. MultiHeadedAttention, # noqa: H301
  12. RelPositionMultiHeadedAttention, # noqa: H301
  13. LegacyRelPositionMultiHeadedAttention, # noqa: H301
  14. )
  15. from funasr.modules.embedding import (
  16. PositionalEncoding, # noqa: H301
  17. ScaledPositionalEncoding, # noqa: H301
  18. RelPositionalEncoding, # noqa: H301
  19. LegacyRelPositionalEncoding, # noqa: H301
  20. )
  21. from funasr.modules.layer_norm import LayerNorm
  22. from funasr.modules.multi_layer_conv import Conv1dLinear
  23. from funasr.modules.multi_layer_conv import MultiLayeredConv1d
  24. from funasr.modules.positionwise_feed_forward import (
  25. PositionwiseFeedForward, # noqa: H301
  26. )
  27. from funasr.modules.repeat import repeat
  28. from funasr.modules.subsampling import Conv2dSubsampling
  29. from funasr.modules.subsampling import Conv2dSubsampling2
  30. from funasr.modules.subsampling import Conv2dSubsampling6
  31. from funasr.modules.subsampling import Conv2dSubsampling8
  32. from funasr.modules.subsampling import TooShortUttError
  33. from funasr.modules.subsampling import check_short_utt
  34. from funasr.models.encoder.abs_encoder import AbsEncoder
  35. import pdb
  36. import math
  37. class ConvolutionModule(nn.Module):
  38. """ConvolutionModule in Conformer model.
  39. Args:
  40. channels (int): The number of channels of conv layers.
  41. kernel_size (int): Kernerl size of conv layers.
  42. """
  43. def __init__(self, channels, kernel_size, activation=nn.ReLU(), bias=True):
  44. """Construct an ConvolutionModule object."""
  45. super(ConvolutionModule, self).__init__()
  46. # kernerl_size should be a odd number for 'SAME' padding
  47. assert (kernel_size - 1) % 2 == 0
  48. self.pointwise_conv1 = nn.Conv1d(
  49. channels,
  50. 2 * channels,
  51. kernel_size=1,
  52. stride=1,
  53. padding=0,
  54. bias=bias,
  55. )
  56. self.depthwise_conv = nn.Conv1d(
  57. channels,
  58. channels,
  59. kernel_size,
  60. stride=1,
  61. padding=(kernel_size - 1) // 2,
  62. groups=channels,
  63. bias=bias,
  64. )
  65. self.norm = nn.BatchNorm1d(channels)
  66. self.pointwise_conv2 = nn.Conv1d(
  67. channels,
  68. channels,
  69. kernel_size=1,
  70. stride=1,
  71. padding=0,
  72. bias=bias,
  73. )
  74. self.activation = activation
  75. def forward(self, x):
  76. """Compute convolution module.
  77. Args:
  78. x (torch.Tensor): Input tensor (#batch, time, channels).
  79. Returns:
  80. torch.Tensor: Output tensor (#batch, time, channels).
  81. """
  82. # exchange the temporal dimension and the feature dimension
  83. x = x.transpose(1, 2)
  84. # GLU mechanism
  85. x = self.pointwise_conv1(x) # (batch, 2*channel, dim)
  86. x = nn.functional.glu(x, dim=1) # (batch, channel, dim)
  87. # 1D Depthwise Conv
  88. x = self.depthwise_conv(x)
  89. x = self.activation(self.norm(x))
  90. x = self.pointwise_conv2(x)
  91. return x.transpose(1, 2)
class MFCCAEncoder(AbsEncoder):
    """Conformer encoder module with cross-channel fusion convolutions (MFCCA).

    Args:
        input_size (int): Input dimension.
        output_size (int): Dimension of attention.
        attention_heads (int): The number of heads of multi head attention.
        linear_units (int): The number of units of position-wise feed forward.
        num_blocks (int): The number of decoder blocks.
        dropout_rate (float): Dropout rate.
        attention_dropout_rate (float): Dropout rate in attention.
        positional_dropout_rate (float): Dropout rate after adding positional encoding.
        input_layer (Union[str, torch.nn.Module]): Input layer type.
        normalize_before (bool): Whether to use layer_norm before the first block.
        concat_after (bool): Whether to concat attention layer's input and output.
            If True, additional linear will be applied.
            i.e. x -> x + linear(concat(x, att(x)))
            If False, no additional linear will be applied. i.e. x -> x + att(x)
        positionwise_layer_type (str): "linear", "conv1d", or "conv1d-linear".
        positionwise_conv_kernel_size (int): Kernel size of positionwise conv1d layer.
        rel_pos_type (str): Whether to use the latest relative positional encoding or
            the legacy one. The legacy relative positional encoding will be deprecated
            in the future. More Details can be found in
            https://github.com/espnet/espnet/pull/2816.
        encoder_pos_enc_layer_type (str): Encoder positional encoding layer type.
        encoder_attn_layer_type (str): Encoder attention layer type.
        activation_type (str): Encoder activation function type.
        macaron_style (bool): Whether to use macaron style for positionwise layer.
        use_cnn_module (bool): Whether to use convolution module.
        zero_triu (bool): Whether to zero the upper triangular part of attention matrix.
        cnn_module_kernel (int): Kernel size of convolution module.
        padding_idx (int): Padding idx for input_layer=embed.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: str = "conv2d",
        normalize_before: bool = True,
        concat_after: bool = False,
        positionwise_layer_type: str = "linear",
        positionwise_conv_kernel_size: int = 3,
        macaron_style: bool = False,
        rel_pos_type: str = "legacy",
        pos_enc_layer_type: str = "rel_pos",
        selfattention_layer_type: str = "rel_selfattn",
        activation_type: str = "swish",
        use_cnn_module: bool = True,
        zero_triu: bool = False,
        cnn_module_kernel: int = 31,
        padding_idx: int = -1,
    ):
        assert check_argument_types()
        super().__init__()
        self._output_size = output_size
        # Map the default "rel_pos"/"rel_selfattn" names onto their legacy
        # counterparts when the legacy relative positional encoding is chosen.
        if rel_pos_type == "legacy":
            if pos_enc_layer_type == "rel_pos":
                pos_enc_layer_type = "legacy_rel_pos"
            if selfattention_layer_type == "rel_selfattn":
                selfattention_layer_type = "legacy_rel_selfattn"
        elif rel_pos_type == "latest":
            assert selfattention_layer_type != "legacy_rel_selfattn"
            assert pos_enc_layer_type != "legacy_rel_pos"
        else:
            raise ValueError("unknown rel_pos_type: " + rel_pos_type)
        activation = get_activation(activation_type)
        # Select the positional-encoding class; relative variants must be
        # paired with the matching self-attention layer type (asserted below).
        if pos_enc_layer_type == "abs_pos":
            pos_enc_class = PositionalEncoding
        elif pos_enc_layer_type == "scaled_abs_pos":
            pos_enc_class = ScaledPositionalEncoding
        elif pos_enc_layer_type == "rel_pos":
            assert selfattention_layer_type == "rel_selfattn"
            pos_enc_class = RelPositionalEncoding
        elif pos_enc_layer_type == "legacy_rel_pos":
            assert selfattention_layer_type == "legacy_rel_selfattn"
            pos_enc_class = LegacyRelPositionalEncoding
            logging.warning(
                "Using legacy_rel_pos and it will be deprecated in the future."
            )
        else:
            raise ValueError("unknown pos_enc_layer: " + pos_enc_layer_type)
        # Build the input embedding / subsampling front-end.
        if input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(input_size, output_size),
                torch.nn.LayerNorm(output_size),
                torch.nn.Dropout(dropout_rate),
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d":
            self.embed = Conv2dSubsampling(
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d6":
            self.embed = Conv2dSubsampling6(
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "conv2d8":
            self.embed = Conv2dSubsampling8(
                input_size,
                output_size,
                dropout_rate,
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif isinstance(input_layer, torch.nn.Module):
            self.embed = torch.nn.Sequential(
                input_layer,
                pos_enc_class(output_size, positional_dropout_rate),
            )
        elif input_layer is None:
            self.embed = torch.nn.Sequential(
                pos_enc_class(output_size, positional_dropout_rate)
            )
        else:
            raise ValueError("unknown input_layer: " + input_layer)
        self.normalize_before = normalize_before
        # Position-wise feed-forward layer (class + ctor args, instantiated
        # once per encoder block inside repeat()).
        if positionwise_layer_type == "linear":
            positionwise_layer = PositionwiseFeedForward
            positionwise_layer_args = (
                output_size,
                linear_units,
                dropout_rate,
                activation,
            )
        elif positionwise_layer_type == "conv1d":
            positionwise_layer = MultiLayeredConv1d
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        elif positionwise_layer_type == "conv1d-linear":
            positionwise_layer = Conv1dLinear
            positionwise_layer_args = (
                output_size,
                linear_units,
                positionwise_conv_kernel_size,
                dropout_rate,
            )
        else:
            raise NotImplementedError("Support only linear or conv1d.")
        # Self-attention layer (class + ctor args) for the encoder blocks.
        if selfattention_layer_type == "selfattn":
            encoder_selfattn_layer = MultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                attention_dropout_rate,
            )
        elif selfattention_layer_type == "legacy_rel_selfattn":
            assert pos_enc_layer_type == "legacy_rel_pos"
            encoder_selfattn_layer = LegacyRelPositionMultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                attention_dropout_rate,
            )
            logging.warning(
                "Using legacy_rel_selfattn and it will be deprecated in the future."
            )
        elif selfattention_layer_type == "rel_selfattn":
            assert pos_enc_layer_type == "rel_pos"
            encoder_selfattn_layer = RelPositionMultiHeadedAttention
            encoder_selfattn_layer_args = (
                attention_heads,
                output_size,
                attention_dropout_rate,
                zero_triu,
            )
        else:
            raise ValueError("unknown encoder_attn_layer: " + selfattention_layer_type)
        convolution_layer = ConvolutionModule
        convolution_layer_args = (output_size, cnn_module_kernel, activation)
        # A second, plain multi-headed attention passed to each EncoderLayer in
        # addition to the configured self-attention — presumably the
        # cross-channel attention of MFCCA; see encoder_layer_mfcca.EncoderLayer.
        encoder_selfattn_layer_raw = MultiHeadedAttention
        encoder_selfattn_layer_args_raw = (
            attention_heads,
            output_size,
            attention_dropout_rate,
        )
        self.encoders = repeat(
            num_blocks,
            lambda lnum: EncoderLayer(
                output_size,
                encoder_selfattn_layer_raw(*encoder_selfattn_layer_args_raw),
                encoder_selfattn_layer(*encoder_selfattn_layer_args),
                positionwise_layer(*positionwise_layer_args),
                positionwise_layer(*positionwise_layer_args) if macaron_style else None,
                convolution_layer(*convolution_layer_args) if use_cnn_module else None,
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        if self.normalize_before:
            self.after_norm = LayerNorm(output_size)
        # Channel-fusion stack used in forward(): treats (up to) 8 recording
        # channels as the Conv2d channel axis and squeezes them to one channel
        # (8 -> 16 -> 32 -> 16 -> 1); [5,7] kernels with (2,3) padding keep the
        # (time, feature) spatial size unchanged.
        self.conv1 = torch.nn.Conv2d(8, 16, [5,7], stride=[1,1], padding=(2,3))
        self.conv2 = torch.nn.Conv2d(16, 32, [5,7], stride=[1,1], padding=(2,3))
        self.conv3 = torch.nn.Conv2d(32, 16, [5,7], stride=[1,1], padding=(2,3))
        self.conv4 = torch.nn.Conv2d(16, 1, [5,7], stride=[1,1], padding=(2,3))

    def output_size(self) -> int:
        """Return the encoder output dimension."""
        return self._output_size

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        channel_size: torch.Tensor,
        prev_states: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Calculate forward propagation.

        Args:
            xs_pad (torch.Tensor): Input tensor (#batch, L, input_size).
            ilens (torch.Tensor): Input length (#batch).
            channel_size: Number of recording channels folded into the batch
                axis (used as an int; #batch is assumed to be
                #utterances * channel_size) — TODO confirm against callers.
            prev_states (torch.Tensor): Not to be used now.

        Returns:
            torch.Tensor: Output tensor (#batch, L, output_size).
            torch.Tensor: Output length (#batch).
            torch.Tensor: Not to be used now.
        """
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
        if (
            isinstance(self.embed, Conv2dSubsampling)
            or isinstance(self.embed, Conv2dSubsampling6)
            or isinstance(self.embed, Conv2dSubsampling8)
        ):
            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
            if short_status:
                raise TooShortUttError(
                    f"has {xs_pad.size(1)} frames and is too short for subsampling "
                    + f"(it needs more than {limit_size} frames), return empty results",
                    xs_pad.size(1),
                    limit_size,
                )
            xs_pad, masks = self.embed(xs_pad, masks)
        else:
            xs_pad = self.embed(xs_pad)
        # NOTE(review): unlike forward_hidden(), the encoder stack here is
        # called with (xs_pad, masks, channel_size) — the MFCCA EncoderLayer
        # apparently threads channel_size through; confirm its signature.
        xs_pad, masks, channel_size = self.encoders(xs_pad, masks, channel_size)
        if isinstance(xs_pad, tuple):
            # With relative positional encoding the encoders return
            # (output, pos_emb); keep only the output.
            xs_pad = xs_pad[0]
        t_leng = xs_pad.size(1)
        d_dim = xs_pad.size(2)
        # Unfold the channel dimension out of the batch axis:
        # (#utt * channel, T, D) -> (#utt, channel, T, D).
        xs_pad = xs_pad.reshape(-1,channel_size,t_leng,d_dim)
        if(channel_size<8):
            # The fusion convs expect exactly 8 input channels; tile the
            # available channels and keep the first 8.
            repeat_num = math.ceil(8/channel_size)
            xs_pad = xs_pad.repeat(1,repeat_num,1,1)[:,0:8,:,:]
        xs_pad = self.conv1(xs_pad)
        xs_pad = self.conv2(xs_pad)
        xs_pad = self.conv3(xs_pad)
        xs_pad = self.conv4(xs_pad)
        # conv4 leaves one channel; collapse back to (#utt, T, D).
        # NOTE(review): .squeeze() with no dim also drops a batch of size 1
        # (and a T or D of 1); the following reshape restores the layout, but
        # squeeze(1) would be safer — confirm before changing.
        xs_pad = xs_pad.squeeze().reshape(-1,t_leng,d_dim)
        mask_tmp = masks.size(1)
        # Keep the first channel's mask for each utterance:
        # (#utt * channel, 1, T) -> (#utt, 1, T).
        masks = masks.reshape(-1,channel_size,mask_tmp,t_leng)[:,0,:,:]
        if self.normalize_before:
            xs_pad = self.after_norm(xs_pad)
        olens = masks.squeeze(1).sum(1)
        return xs_pad, olens, None

    def forward_hidden(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Calculate forward propagation, also capturing a mid-stack feature.

        The output of the middle encoder layer is normalized and stored on
        ``self.hidden_feature`` as a side effect.

        Args:
            xs_pad (torch.Tensor): Input tensor (#batch, L, input_size).
            ilens (torch.Tensor): Input length (#batch).
            prev_states (torch.Tensor): Not to be used now.

        Returns:
            torch.Tensor: Output tensor (#batch, L, output_size).
            torch.Tensor: Output length (#batch).
            torch.Tensor: Not to be used now.
        """
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
        if (
            isinstance(self.embed, Conv2dSubsampling)
            or isinstance(self.embed, Conv2dSubsampling6)
            or isinstance(self.embed, Conv2dSubsampling8)
        ):
            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
            if short_status:
                raise TooShortUttError(
                    f"has {xs_pad.size(1)} frames and is too short for subsampling "
                    + f"(it needs more than {limit_size} frames), return empty results",
                    xs_pad.size(1),
                    limit_size,
                )
            xs_pad, masks = self.embed(xs_pad, masks)
        else:
            xs_pad = self.embed(xs_pad)
        num_layer = len(self.encoders)
        # Run the layers one by one so the middle layer's output can be kept.
        # NOTE(review): each layer is called with only (xs_pad, masks) here,
        # while forward() passes channel_size as a third argument — confirm
        # EncoderLayer supports both call arities.
        for idx, encoder in enumerate(self.encoders):
            xs_pad, masks = encoder(xs_pad, masks)
            if idx == num_layer // 2 - 1:
                hidden_feature = xs_pad
        if isinstance(xs_pad, tuple):
            # (output, pos_emb) tuples from relative positional encoding.
            xs_pad = xs_pad[0]
            hidden_feature = hidden_feature[0]
        if self.normalize_before:
            xs_pad = self.after_norm(xs_pad)
            # Side effect: expose the normalized mid-stack feature.
            self.hidden_feature = self.after_norm(hidden_feature)
        olens = masks.squeeze(1).sum(1)
        return xs_pad, olens, None