# encoder.py
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)
import logging
from typing import List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from funasr.models.ctc.ctc import CTC
from funasr.models.sanm.attention import MultiHeadedAttention, MultiHeadedAttentionSANM
from funasr.models.transformer.embedding import (
    SinusoidalPositionEncoder,
    StreamSinusoidalPositionEncoder,
)
from funasr.models.transformer.layer_norm import LayerNorm
from funasr.models.transformer.positionwise_feed_forward import (
    PositionwiseFeedForward,  # noqa: H301
)
from funasr.models.transformer.utils.multi_layer_conv import (
    Conv1dLinear,
    MultiLayeredConv1d,
)
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.models.transformer.utils.repeat import repeat
from funasr.models.transformer.utils.subsampling import (
    Conv2dSubsampling,
    Conv2dSubsampling2,
    Conv2dSubsampling6,
    Conv2dSubsampling8,
    TooShortUttError,
    check_short_utt,
)
from funasr.register import tables
from funasr.train_utils.device_funcs import to_device
  34. class EncoderLayerSANM(nn.Module):
  35. def __init__(
  36. self,
  37. in_size,
  38. size,
  39. self_attn,
  40. feed_forward,
  41. dropout_rate,
  42. normalize_before=True,
  43. concat_after=False,
  44. stochastic_depth_rate=0.0,
  45. ):
  46. """Construct an EncoderLayer object."""
  47. super(EncoderLayerSANM, self).__init__()
  48. self.self_attn = self_attn
  49. self.feed_forward = feed_forward
  50. self.norm1 = LayerNorm(in_size)
  51. self.norm2 = LayerNorm(size)
  52. self.dropout = nn.Dropout(dropout_rate)
  53. self.in_size = in_size
  54. self.size = size
  55. self.normalize_before = normalize_before
  56. self.concat_after = concat_after
  57. if self.concat_after:
  58. self.concat_linear = nn.Linear(size + size, size)
  59. self.stochastic_depth_rate = stochastic_depth_rate
  60. self.dropout_rate = dropout_rate
  61. def forward(self, x, mask, cache=None, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
  62. """Compute encoded features.
  63. Args:
  64. x_input (torch.Tensor): Input tensor (#batch, time, size).
  65. mask (torch.Tensor): Mask tensor for the input (#batch, time).
  66. cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
  67. Returns:
  68. torch.Tensor: Output tensor (#batch, time, size).
  69. torch.Tensor: Mask tensor (#batch, time).
  70. """
  71. skip_layer = False
  72. # with stochastic depth, residual connection `x + f(x)` becomes
  73. # `x <- x + 1 / (1 - p) * f(x)` at training time.
  74. stoch_layer_coeff = 1.0
  75. if self.training and self.stochastic_depth_rate > 0:
  76. skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
  77. stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
  78. if skip_layer:
  79. if cache is not None:
  80. x = torch.cat([cache, x], dim=1)
  81. return x, mask
  82. residual = x
  83. if self.normalize_before:
  84. x = self.norm1(x)
  85. if self.concat_after:
  86. x_concat = torch.cat((x, self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)), dim=-1)
  87. if self.in_size == self.size:
  88. x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
  89. else:
  90. x = stoch_layer_coeff * self.concat_linear(x_concat)
  91. else:
  92. if self.in_size == self.size:
  93. x = residual + stoch_layer_coeff * self.dropout(
  94. self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
  95. )
  96. else:
  97. x = stoch_layer_coeff * self.dropout(
  98. self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
  99. )
  100. if not self.normalize_before:
  101. x = self.norm1(x)
  102. residual = x
  103. if self.normalize_before:
  104. x = self.norm2(x)
  105. x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
  106. if not self.normalize_before:
  107. x = self.norm2(x)
  108. return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder
  109. def forward_chunk(self, x, cache=None, chunk_size=None, look_back=0):
  110. """Compute encoded features.
  111. Args:
  112. x_input (torch.Tensor): Input tensor (#batch, time, size).
  113. mask (torch.Tensor): Mask tensor for the input (#batch, time).
  114. cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
  115. Returns:
  116. torch.Tensor: Output tensor (#batch, time, size).
  117. torch.Tensor: Mask tensor (#batch, time).
  118. """
  119. residual = x
  120. if self.normalize_before:
  121. x = self.norm1(x)
  122. if self.in_size == self.size:
  123. attn, cache = self.self_attn.forward_chunk(x, cache, chunk_size, look_back)
  124. x = residual + attn
  125. else:
  126. x, cache = self.self_attn.forward_chunk(x, cache, chunk_size, look_back)
  127. if not self.normalize_before:
  128. x = self.norm1(x)
  129. residual = x
  130. if self.normalize_before:
  131. x = self.norm2(x)
  132. x = residual + self.feed_forward(x)
  133. if not self.normalize_before:
  134. x = self.norm2(x)
  135. return x, cache
  136. @tables.register("encoder_classes", "SANMEncoder")
  137. class SANMEncoder(nn.Module):
  138. """
  139. Author: Zhifu Gao, Shiliang Zhang, Ming Lei, Ian McLoughlin
  140. San-m: Memory equipped self-attention for end-to-end speech recognition
  141. https://arxiv.org/abs/2006.01713
  142. """
  143. def __init__(
  144. self,
  145. input_size: int,
  146. output_size: int = 256,
  147. attention_heads: int = 4,
  148. linear_units: int = 2048,
  149. num_blocks: int = 6,
  150. dropout_rate: float = 0.1,
  151. positional_dropout_rate: float = 0.1,
  152. attention_dropout_rate: float = 0.0,
  153. input_layer: Optional[str] = "conv2d",
  154. pos_enc_class=SinusoidalPositionEncoder,
  155. normalize_before: bool = True,
  156. concat_after: bool = False,
  157. positionwise_layer_type: str = "linear",
  158. positionwise_conv_kernel_size: int = 1,
  159. padding_idx: int = -1,
  160. interctc_layer_idx: List[int] = [],
  161. interctc_use_conditioning: bool = False,
  162. kernel_size : int = 11,
  163. sanm_shfit : int = 0,
  164. lora_list: List[str] = None,
  165. lora_rank: int = 8,
  166. lora_alpha: int = 16,
  167. lora_dropout: float = 0.1,
  168. selfattention_layer_type: str = "sanm",
  169. tf2torch_tensor_name_prefix_torch: str = "encoder",
  170. tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
  171. ):
  172. super().__init__()
  173. self._output_size = output_size
  174. if input_layer == "linear":
  175. self.embed = torch.nn.Sequential(
  176. torch.nn.Linear(input_size, output_size),
  177. torch.nn.LayerNorm(output_size),
  178. torch.nn.Dropout(dropout_rate),
  179. torch.nn.ReLU(),
  180. pos_enc_class(output_size, positional_dropout_rate),
  181. )
  182. elif input_layer == "conv2d":
  183. self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
  184. elif input_layer == "conv2d2":
  185. self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
  186. elif input_layer == "conv2d6":
  187. self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
  188. elif input_layer == "conv2d8":
  189. self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
  190. elif input_layer == "embed":
  191. self.embed = torch.nn.Sequential(
  192. torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
  193. SinusoidalPositionEncoder(),
  194. )
  195. elif input_layer is None:
  196. if input_size == output_size:
  197. self.embed = None
  198. else:
  199. self.embed = torch.nn.Linear(input_size, output_size)
  200. elif input_layer == "pe":
  201. self.embed = SinusoidalPositionEncoder()
  202. elif input_layer == "pe_online":
  203. self.embed = StreamSinusoidalPositionEncoder()
  204. else:
  205. raise ValueError("unknown input_layer: " + input_layer)
  206. self.normalize_before = normalize_before
  207. if positionwise_layer_type == "linear":
  208. positionwise_layer = PositionwiseFeedForward
  209. positionwise_layer_args = (
  210. output_size,
  211. linear_units,
  212. dropout_rate,
  213. )
  214. elif positionwise_layer_type == "conv1d":
  215. positionwise_layer = MultiLayeredConv1d
  216. positionwise_layer_args = (
  217. output_size,
  218. linear_units,
  219. positionwise_conv_kernel_size,
  220. dropout_rate,
  221. )
  222. elif positionwise_layer_type == "conv1d-linear":
  223. positionwise_layer = Conv1dLinear
  224. positionwise_layer_args = (
  225. output_size,
  226. linear_units,
  227. positionwise_conv_kernel_size,
  228. dropout_rate,
  229. )
  230. else:
  231. raise NotImplementedError("Support only linear or conv1d.")
  232. if selfattention_layer_type == "selfattn":
  233. encoder_selfattn_layer = MultiHeadedAttention
  234. encoder_selfattn_layer_args = (
  235. attention_heads,
  236. output_size,
  237. attention_dropout_rate,
  238. )
  239. elif selfattention_layer_type == "sanm":
  240. encoder_selfattn_layer = MultiHeadedAttentionSANM
  241. encoder_selfattn_layer_args0 = (
  242. attention_heads,
  243. input_size,
  244. output_size,
  245. attention_dropout_rate,
  246. kernel_size,
  247. sanm_shfit,
  248. lora_list,
  249. lora_rank,
  250. lora_alpha,
  251. lora_dropout,
  252. )
  253. encoder_selfattn_layer_args = (
  254. attention_heads,
  255. output_size,
  256. output_size,
  257. attention_dropout_rate,
  258. kernel_size,
  259. sanm_shfit,
  260. lora_list,
  261. lora_rank,
  262. lora_alpha,
  263. lora_dropout,
  264. )
  265. self.encoders0 = repeat(
  266. 1,
  267. lambda lnum: EncoderLayerSANM(
  268. input_size,
  269. output_size,
  270. encoder_selfattn_layer(*encoder_selfattn_layer_args0),
  271. positionwise_layer(*positionwise_layer_args),
  272. dropout_rate,
  273. normalize_before,
  274. concat_after,
  275. ),
  276. )
  277. self.encoders = repeat(
  278. num_blocks-1,
  279. lambda lnum: EncoderLayerSANM(
  280. output_size,
  281. output_size,
  282. encoder_selfattn_layer(*encoder_selfattn_layer_args),
  283. positionwise_layer(*positionwise_layer_args),
  284. dropout_rate,
  285. normalize_before,
  286. concat_after,
  287. ),
  288. )
  289. if self.normalize_before:
  290. self.after_norm = LayerNorm(output_size)
  291. self.interctc_layer_idx = interctc_layer_idx
  292. if len(interctc_layer_idx) > 0:
  293. assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
  294. self.interctc_use_conditioning = interctc_use_conditioning
  295. self.conditioning_layer = None
  296. self.dropout = nn.Dropout(dropout_rate)
  297. self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
  298. self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
  299. def output_size(self) -> int:
  300. return self._output_size
  301. def forward(
  302. self,
  303. xs_pad: torch.Tensor,
  304. ilens: torch.Tensor,
  305. prev_states: torch.Tensor = None,
  306. ctc: CTC = None,
  307. ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
  308. """Embed positions in tensor.
  309. Args:
  310. xs_pad: input tensor (B, L, D)
  311. ilens: input length (B)
  312. prev_states: Not to be used now.
  313. Returns:
  314. position embedded tensor and mask
  315. """
  316. masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
  317. xs_pad = xs_pad * self.output_size()**0.5
  318. if self.embed is None:
  319. xs_pad = xs_pad
  320. elif (
  321. isinstance(self.embed, Conv2dSubsampling)
  322. or isinstance(self.embed, Conv2dSubsampling2)
  323. or isinstance(self.embed, Conv2dSubsampling6)
  324. or isinstance(self.embed, Conv2dSubsampling8)
  325. ):
  326. short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
  327. if short_status:
  328. raise TooShortUttError(
  329. f"has {xs_pad.size(1)} frames and is too short for subsampling "
  330. + f"(it needs more than {limit_size} frames), return empty results",
  331. xs_pad.size(1),
  332. limit_size,
  333. )
  334. xs_pad, masks = self.embed(xs_pad, masks)
  335. else:
  336. xs_pad = self.embed(xs_pad)
  337. # xs_pad = self.dropout(xs_pad)
  338. encoder_outs = self.encoders0(xs_pad, masks)
  339. xs_pad, masks = encoder_outs[0], encoder_outs[1]
  340. intermediate_outs = []
  341. if len(self.interctc_layer_idx) == 0:
  342. encoder_outs = self.encoders(xs_pad, masks)
  343. xs_pad, masks = encoder_outs[0], encoder_outs[1]
  344. else:
  345. for layer_idx, encoder_layer in enumerate(self.encoders):
  346. encoder_outs = encoder_layer(xs_pad, masks)
  347. xs_pad, masks = encoder_outs[0], encoder_outs[1]
  348. if layer_idx + 1 in self.interctc_layer_idx:
  349. encoder_out = xs_pad
  350. # intermediate outputs are also normalized
  351. if self.normalize_before:
  352. encoder_out = self.after_norm(encoder_out)
  353. intermediate_outs.append((layer_idx + 1, encoder_out))
  354. if self.interctc_use_conditioning:
  355. ctc_out = ctc.softmax(encoder_out)
  356. xs_pad = xs_pad + self.conditioning_layer(ctc_out)
  357. if self.normalize_before:
  358. xs_pad = self.after_norm(xs_pad)
  359. olens = masks.squeeze(1).sum(1)
  360. if len(intermediate_outs) > 0:
  361. return (xs_pad, intermediate_outs), olens, None
  362. return xs_pad, olens, None
  363. def _add_overlap_chunk(self, feats: np.ndarray, cache: dict = {}):
  364. if len(cache) == 0:
  365. return feats
  366. cache["feats"] = to_device(cache["feats"], device=feats.device)
  367. overlap_feats = torch.cat((cache["feats"], feats), dim=1)
  368. cache["feats"] = overlap_feats[:, -(cache["chunk_size"][0] + cache["chunk_size"][2]):, :]
  369. return overlap_feats
  370. def forward_chunk(self,
  371. xs_pad: torch.Tensor,
  372. ilens: torch.Tensor,
  373. cache: dict = None,
  374. ctc: CTC = None,
  375. ):
  376. xs_pad *= self.output_size() ** 0.5
  377. if self.embed is None:
  378. xs_pad = xs_pad
  379. else:
  380. xs_pad = self.embed(xs_pad, cache)
  381. if cache["tail_chunk"]:
  382. xs_pad = to_device(cache["feats"], device=xs_pad.device)
  383. else:
  384. xs_pad = self._add_overlap_chunk(xs_pad, cache)
  385. encoder_outs = self.encoders0(xs_pad, None, None, None, None)
  386. xs_pad, masks = encoder_outs[0], encoder_outs[1]
  387. intermediate_outs = []
  388. if len(self.interctc_layer_idx) == 0:
  389. encoder_outs = self.encoders(xs_pad, None, None, None, None)
  390. xs_pad, masks = encoder_outs[0], encoder_outs[1]
  391. else:
  392. for layer_idx, encoder_layer in enumerate(self.encoders):
  393. encoder_outs = encoder_layer(xs_pad, None, None, None, None)
  394. xs_pad, masks = encoder_outs[0], encoder_outs[1]
  395. if layer_idx + 1 in self.interctc_layer_idx:
  396. encoder_out = xs_pad
  397. # intermediate outputs are also normalized
  398. if self.normalize_before:
  399. encoder_out = self.after_norm(encoder_out)
  400. intermediate_outs.append((layer_idx + 1, encoder_out))
  401. if self.interctc_use_conditioning:
  402. ctc_out = ctc.softmax(encoder_out)
  403. xs_pad = xs_pad + self.conditioning_layer(ctc_out)
  404. if self.normalize_before:
  405. xs_pad = self.after_norm(xs_pad)
  406. if len(intermediate_outs) > 0:
  407. return (xs_pad, intermediate_outs), None, None
  408. return xs_pad, ilens, None