# self_attention_encoder.py
  1. from typing import List
  2. from typing import Optional
  3. from typing import Sequence
  4. from typing import Tuple
  5. from typing import Union
  6. import logging
  7. import torch
  8. import torch.nn as nn
  9. from funasr.models.scama.chunk_utilis import overlap_chunk
  10. import numpy as np
  11. from funasr.models.transformer.utils.nets_utils import make_pad_mask
  12. from funasr.models.sond.attention import MultiHeadSelfAttention
  13. from funasr.models.transformer.embedding import SinusoidalPositionEncoder
  14. from funasr.models.transformer.layer_norm import LayerNorm
  15. from funasr.models.transformer.utils.multi_layer_conv import Conv1dLinear
  16. from funasr.models.transformer.utils.multi_layer_conv import MultiLayeredConv1d
  17. from funasr.models.transformer.positionwise_feed_forward import (
  18. PositionwiseFeedForward, # noqa: H301
  19. )
  20. from funasr.models.transformer.utils.repeat import repeat
  21. from funasr.models.transformer.utils.subsampling import Conv2dSubsampling
  22. from funasr.models.transformer.utils.subsampling import Conv2dSubsampling2
  23. from funasr.models.transformer.utils.subsampling import Conv2dSubsampling6
  24. from funasr.models.transformer.utils.subsampling import Conv2dSubsampling8
  25. from funasr.models.transformer.utils.subsampling import TooShortUttError
  26. from funasr.models.transformer.utils.subsampling import check_short_utt
  27. from funasr.models.ctc import CTC
  28. from funasr.models.encoder.abs_encoder import AbsEncoder
  29. class EncoderLayer(nn.Module):
  30. def __init__(
  31. self,
  32. in_size,
  33. size,
  34. self_attn,
  35. feed_forward,
  36. dropout_rate,
  37. normalize_before=True,
  38. concat_after=False,
  39. stochastic_depth_rate=0.0,
  40. ):
  41. """Construct an EncoderLayer object."""
  42. super(EncoderLayer, self).__init__()
  43. self.self_attn = self_attn
  44. self.feed_forward = feed_forward
  45. self.norm1 = LayerNorm(in_size)
  46. self.norm2 = LayerNorm(size)
  47. self.dropout = nn.Dropout(dropout_rate)
  48. self.in_size = in_size
  49. self.size = size
  50. self.normalize_before = normalize_before
  51. self.concat_after = concat_after
  52. if self.concat_after:
  53. self.concat_linear = nn.Linear(size + size, size)
  54. self.stochastic_depth_rate = stochastic_depth_rate
  55. self.dropout_rate = dropout_rate
  56. def forward(self, x, mask, cache=None, mask_att_chunk_encoder=None):
  57. """Compute encoded features.
  58. Args:
  59. x_input (torch.Tensor): Input tensor (#batch, time, size).
  60. mask (torch.Tensor): Mask tensor for the input (#batch, time).
  61. cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
  62. Returns:
  63. torch.Tensor: Output tensor (#batch, time, size).
  64. torch.Tensor: Mask tensor (#batch, time).
  65. """
  66. skip_layer = False
  67. # with stochastic depth, residual connection `x + f(x)` becomes
  68. # `x <- x + 1 / (1 - p) * f(x)` at training time.
  69. stoch_layer_coeff = 1.0
  70. if self.training and self.stochastic_depth_rate > 0:
  71. skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
  72. stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
  73. if skip_layer:
  74. if cache is not None:
  75. x = torch.cat([cache, x], dim=1)
  76. return x, mask
  77. residual = x
  78. if self.normalize_before:
  79. x = self.norm1(x)
  80. if self.concat_after:
  81. x_concat = torch.cat((x, self.self_attn(x, mask, mask_att_chunk_encoder=mask_att_chunk_encoder)), dim=-1)
  82. if self.in_size == self.size:
  83. x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
  84. else:
  85. x = stoch_layer_coeff * self.concat_linear(x_concat)
  86. else:
  87. if self.in_size == self.size:
  88. x = residual + stoch_layer_coeff * self.dropout(
  89. self.self_attn(x, mask, mask_att_chunk_encoder=mask_att_chunk_encoder)
  90. )
  91. else:
  92. x = stoch_layer_coeff * self.dropout(
  93. self.self_attn(x, mask, mask_att_chunk_encoder=mask_att_chunk_encoder)
  94. )
  95. if not self.normalize_before:
  96. x = self.norm1(x)
  97. residual = x
  98. if self.normalize_before:
  99. x = self.norm2(x)
  100. x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
  101. if not self.normalize_before:
  102. x = self.norm2(x)
  103. return x, mask, cache, mask_att_chunk_encoder
  104. class SelfAttentionEncoder(AbsEncoder):
  105. """
  106. Author: Speech Lab of DAMO Academy, Alibaba Group
  107. Self attention encoder in OpenNMT framework
  108. """
  109. def __init__(
  110. self,
  111. input_size: int,
  112. output_size: int = 256,
  113. attention_heads: int = 4,
  114. linear_units: int = 2048,
  115. num_blocks: int = 6,
  116. dropout_rate: float = 0.1,
  117. positional_dropout_rate: float = 0.1,
  118. attention_dropout_rate: float = 0.0,
  119. input_layer: Optional[str] = "conv2d",
  120. pos_enc_class=SinusoidalPositionEncoder,
  121. normalize_before: bool = True,
  122. concat_after: bool = False,
  123. positionwise_layer_type: str = "linear",
  124. positionwise_conv_kernel_size: int = 1,
  125. padding_idx: int = -1,
  126. interctc_layer_idx: List[int] = [],
  127. interctc_use_conditioning: bool = False,
  128. tf2torch_tensor_name_prefix_torch: str = "encoder",
  129. tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
  130. out_units=None,
  131. ):
  132. super().__init__()
  133. self._output_size = output_size
  134. if input_layer == "linear":
  135. self.embed = torch.nn.Sequential(
  136. torch.nn.Linear(input_size, output_size),
  137. torch.nn.LayerNorm(output_size),
  138. torch.nn.Dropout(dropout_rate),
  139. torch.nn.ReLU(),
  140. pos_enc_class(output_size, positional_dropout_rate),
  141. )
  142. elif input_layer == "conv2d":
  143. self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
  144. elif input_layer == "conv2d2":
  145. self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
  146. elif input_layer == "conv2d6":
  147. self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
  148. elif input_layer == "conv2d8":
  149. self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
  150. elif input_layer == "embed":
  151. self.embed = torch.nn.Sequential(
  152. torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
  153. SinusoidalPositionEncoder(),
  154. )
  155. elif input_layer is None:
  156. if input_size == output_size:
  157. self.embed = None
  158. else:
  159. self.embed = torch.nn.Linear(input_size, output_size)
  160. elif input_layer == "pe":
  161. self.embed = SinusoidalPositionEncoder()
  162. elif input_layer == "null":
  163. self.embed = None
  164. else:
  165. raise ValueError("unknown input_layer: " + input_layer)
  166. self.normalize_before = normalize_before
  167. if positionwise_layer_type == "linear":
  168. positionwise_layer = PositionwiseFeedForward
  169. positionwise_layer_args = (
  170. output_size,
  171. linear_units,
  172. dropout_rate,
  173. )
  174. elif positionwise_layer_type == "conv1d":
  175. positionwise_layer = MultiLayeredConv1d
  176. positionwise_layer_args = (
  177. output_size,
  178. linear_units,
  179. positionwise_conv_kernel_size,
  180. dropout_rate,
  181. )
  182. elif positionwise_layer_type == "conv1d-linear":
  183. positionwise_layer = Conv1dLinear
  184. positionwise_layer_args = (
  185. output_size,
  186. linear_units,
  187. positionwise_conv_kernel_size,
  188. dropout_rate,
  189. )
  190. else:
  191. raise NotImplementedError("Support only linear or conv1d.")
  192. self.encoders = repeat(
  193. num_blocks,
  194. lambda lnum: EncoderLayer(
  195. output_size,
  196. output_size,
  197. MultiHeadSelfAttention(
  198. attention_heads,
  199. output_size,
  200. output_size,
  201. attention_dropout_rate,
  202. ),
  203. positionwise_layer(*positionwise_layer_args),
  204. dropout_rate,
  205. normalize_before,
  206. concat_after,
  207. ) if lnum > 0 else EncoderLayer(
  208. input_size,
  209. output_size,
  210. MultiHeadSelfAttention(
  211. attention_heads,
  212. input_size if input_layer == "pe" or input_layer == "null" else output_size,
  213. output_size,
  214. attention_dropout_rate,
  215. ),
  216. positionwise_layer(*positionwise_layer_args),
  217. dropout_rate,
  218. normalize_before,
  219. concat_after,
  220. ),
  221. )
  222. if self.normalize_before:
  223. self.after_norm = LayerNorm(output_size)
  224. self.interctc_layer_idx = interctc_layer_idx
  225. if len(interctc_layer_idx) > 0:
  226. assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
  227. self.interctc_use_conditioning = interctc_use_conditioning
  228. self.conditioning_layer = None
  229. self.dropout = nn.Dropout(dropout_rate)
  230. self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
  231. self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
  232. self.out_units = out_units
  233. if out_units is not None:
  234. self.output_linear = nn.Linear(output_size, out_units)
  235. def output_size(self) -> int:
  236. return self._output_size
  237. def forward(
  238. self,
  239. xs_pad: torch.Tensor,
  240. ilens: torch.Tensor,
  241. prev_states: torch.Tensor = None,
  242. ctc: CTC = None,
  243. ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
  244. """Embed positions in tensor.
  245. Args:
  246. xs_pad: input tensor (B, L, D)
  247. ilens: input length (B)
  248. prev_states: Not to be used now.
  249. Returns:
  250. position embedded tensor and mask
  251. """
  252. masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
  253. xs_pad = xs_pad * self.output_size()**0.5
  254. if self.embed is None:
  255. xs_pad = xs_pad
  256. elif (
  257. isinstance(self.embed, Conv2dSubsampling)
  258. or isinstance(self.embed, Conv2dSubsampling2)
  259. or isinstance(self.embed, Conv2dSubsampling6)
  260. or isinstance(self.embed, Conv2dSubsampling8)
  261. ):
  262. short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
  263. if short_status:
  264. raise TooShortUttError(
  265. f"has {xs_pad.size(1)} frames and is too short for subsampling "
  266. + f"(it needs more than {limit_size} frames), return empty results",
  267. xs_pad.size(1),
  268. limit_size,
  269. )
  270. xs_pad, masks = self.embed(xs_pad, masks)
  271. else:
  272. xs_pad = self.embed(xs_pad)
  273. xs_pad = self.dropout(xs_pad)
  274. # encoder_outs = self.encoders0(xs_pad, masks)
  275. # xs_pad, masks = encoder_outs[0], encoder_outs[1]
  276. intermediate_outs = []
  277. if len(self.interctc_layer_idx) == 0:
  278. encoder_outs = self.encoders(xs_pad, masks)
  279. xs_pad, masks = encoder_outs[0], encoder_outs[1]
  280. else:
  281. for layer_idx, encoder_layer in enumerate(self.encoders):
  282. encoder_outs = encoder_layer(xs_pad, masks)
  283. xs_pad, masks = encoder_outs[0], encoder_outs[1]
  284. if layer_idx + 1 in self.interctc_layer_idx:
  285. encoder_out = xs_pad
  286. # intermediate outputs are also normalized
  287. if self.normalize_before:
  288. encoder_out = self.after_norm(encoder_out)
  289. intermediate_outs.append((layer_idx + 1, encoder_out))
  290. if self.interctc_use_conditioning:
  291. ctc_out = ctc.softmax(encoder_out)
  292. xs_pad = xs_pad + self.conditioning_layer(ctc_out)
  293. if self.normalize_before:
  294. xs_pad = self.after_norm(xs_pad)
  295. if self.out_units is not None:
  296. xs_pad = self.output_linear(xs_pad)
  297. olens = masks.squeeze(1).sum(1)
  298. if len(intermediate_outs) > 0:
  299. return (xs_pad, intermediate_outs), olens, None
  300. return xs_pad, olens, None
  301. def gen_tf2torch_map_dict(self):
  302. tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
  303. tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
  304. map_dict_local = {
  305. # cicd
  306. # torch: conv1d.weight in "out_channel in_channel kernel_size"
  307. # tf : conv1d.weight in "kernel_size in_channel out_channel"
  308. # torch: linear.weight in "out_channel in_channel"
  309. # tf : dense.weight in "in_channel out_channel"
  310. "{}.encoders.layeridx.norm1.weight".format(tensor_name_prefix_torch):
  311. {"name": "{}/layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf),
  312. "squeeze": None,
  313. "transpose": None,
  314. }, # (256,),(256,)
  315. "{}.encoders.layeridx.norm1.bias".format(tensor_name_prefix_torch):
  316. {"name": "{}/layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf),
  317. "squeeze": None,
  318. "transpose": None,
  319. }, # (256,),(256,)
  320. "{}.encoders.layeridx.self_attn.linear_q_k_v.weight".format(tensor_name_prefix_torch):
  321. {"name": "{}/layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf),
  322. "squeeze": 0,
  323. "transpose": (1, 0),
  324. }, # (768,256),(1,256,768)
  325. "{}.encoders.layeridx.self_attn.linear_q_k_v.bias".format(tensor_name_prefix_torch):
  326. {"name": "{}/layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf),
  327. "squeeze": None,
  328. "transpose": None,
  329. }, # (768,),(768,)
  330. "{}.encoders.layeridx.self_attn.linear_out.weight".format(tensor_name_prefix_torch):
  331. {"name": "{}/layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf),
  332. "squeeze": 0,
  333. "transpose": (1, 0),
  334. }, # (256,256),(1,256,256)
  335. "{}.encoders.layeridx.self_attn.linear_out.bias".format(tensor_name_prefix_torch):
  336. {"name": "{}/layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf),
  337. "squeeze": None,
  338. "transpose": None,
  339. }, # (256,),(256,)
  340. # ffn
  341. "{}.encoders.layeridx.norm2.weight".format(tensor_name_prefix_torch):
  342. {"name": "{}/layer_layeridx/ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
  343. "squeeze": None,
  344. "transpose": None,
  345. }, # (256,),(256,)
  346. "{}.encoders.layeridx.norm2.bias".format(tensor_name_prefix_torch):
  347. {"name": "{}/layer_layeridx/ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
  348. "squeeze": None,
  349. "transpose": None,
  350. }, # (256,),(256,)
  351. "{}.encoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
  352. {"name": "{}/layer_layeridx/ffn/conv1d/kernel".format(tensor_name_prefix_tf),
  353. "squeeze": 0,
  354. "transpose": (1, 0),
  355. }, # (1024,256),(1,256,1024)
  356. "{}.encoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
  357. {"name": "{}/layer_layeridx/ffn/conv1d/bias".format(tensor_name_prefix_tf),
  358. "squeeze": None,
  359. "transpose": None,
  360. }, # (1024,),(1024,)
  361. "{}.encoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
  362. {"name": "{}/layer_layeridx/ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
  363. "squeeze": 0,
  364. "transpose": (1, 0),
  365. }, # (256,1024),(1,1024,256)
  366. "{}.encoders.layeridx.feed_forward.w_2.bias".format(tensor_name_prefix_torch):
  367. {"name": "{}/layer_layeridx/ffn/conv1d_1/bias".format(tensor_name_prefix_tf),
  368. "squeeze": None,
  369. "transpose": None,
  370. }, # (256,),(256,)
  371. # out norm
  372. "{}.after_norm.weight".format(tensor_name_prefix_torch):
  373. {"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf),
  374. "squeeze": None,
  375. "transpose": None,
  376. }, # (256,),(256,)
  377. "{}.after_norm.bias".format(tensor_name_prefix_torch):
  378. {"name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf),
  379. "squeeze": None,
  380. "transpose": None,
  381. }, # (256,),(256,)
  382. }
  383. if self.out_units is not None:
  384. map_dict_local.update({
  385. "{}.output_linear.weight".format(tensor_name_prefix_torch):
  386. {"name": "{}/conv1d/kernel".format(tensor_name_prefix_tf),
  387. "squeeze": 0,
  388. "transpose": (1, 0),
  389. },
  390. "{}.output_linear.bias".format(tensor_name_prefix_torch):
  391. {"name": "{}/conv1d/bias".format(tensor_name_prefix_tf),
  392. "squeeze": None,
  393. "transpose": None,
  394. }, # (256,),(256,)
  395. })
  396. return map_dict_local
  397. def convert_tf2torch(self,
  398. var_dict_tf,
  399. var_dict_torch,
  400. ):
  401. map_dict = self.gen_tf2torch_map_dict()
  402. var_dict_torch_update = dict()
  403. for name in sorted(var_dict_torch.keys(), reverse=False):
  404. if name.startswith(self.tf2torch_tensor_name_prefix_torch):
  405. # process special (first and last) layers
  406. if name in map_dict:
  407. name_tf = map_dict[name]["name"]
  408. data_tf = var_dict_tf[name_tf]
  409. data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
  410. if map_dict[name]["squeeze"] is not None:
  411. data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
  412. if map_dict[name]["transpose"] is not None:
  413. data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
  414. assert var_dict_torch[name].size() == data_tf.size(), \
  415. "{}, {}, {} != {}".format(name, name_tf,
  416. var_dict_torch[name].size(), data_tf.size())
  417. var_dict_torch_update[name] = data_tf
  418. logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
  419. name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape
  420. ))
  421. # process general layers
  422. else:
  423. # self.tf2torch_tensor_name_prefix_torch may include ".", solve this case
  424. names = name.replace(self.tf2torch_tensor_name_prefix_torch, "todo").split('.')
  425. layeridx = int(names[2])
  426. name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
  427. if name_q in map_dict.keys():
  428. name_v = map_dict[name_q]["name"]
  429. name_tf = name_v.replace("layeridx", "{}".format(layeridx))
  430. data_tf = var_dict_tf[name_tf]
  431. if map_dict[name_q]["squeeze"] is not None:
  432. data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
  433. if map_dict[name_q]["transpose"] is not None:
  434. data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
  435. data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
  436. assert var_dict_torch[name].size() == data_tf.size(), \
  437. "{}, {}, {} != {}".format(name, name_tf,
  438. var_dict_torch[name].size(), data_tf.size())
  439. var_dict_torch_update[name] = data_tf
  440. logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
  441. name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape
  442. ))
  443. else:
  444. logging.warning("{} is missed from tf checkpoint".format(name))
  445. return var_dict_torch_update