sanm_encoder.py 42 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918
  1. from typing import List
  2. from typing import Optional
  3. from typing import Sequence
  4. from typing import Tuple
  5. from typing import Union
  6. import logging
  7. import torch
  8. import torch.nn as nn
  9. from funasr.modules.streaming_utils.chunk_utilis import overlap_chunk
  10. from typeguard import check_argument_types
  11. import numpy as np
  12. from funasr.modules.nets_utils import make_pad_mask
  13. from funasr.modules.attention import MultiHeadedAttention, MultiHeadedAttentionSANM
  14. from funasr.modules.embedding import SinusoidalPositionEncoder
  15. from funasr.modules.layer_norm import LayerNorm
  16. from funasr.modules.multi_layer_conv import Conv1dLinear
  17. from funasr.modules.multi_layer_conv import MultiLayeredConv1d
  18. from funasr.modules.positionwise_feed_forward import (
  19. PositionwiseFeedForward, # noqa: H301
  20. )
  21. from funasr.modules.repeat import repeat
  22. from funasr.modules.subsampling import Conv2dSubsampling
  23. from funasr.modules.subsampling import Conv2dSubsampling2
  24. from funasr.modules.subsampling import Conv2dSubsampling6
  25. from funasr.modules.subsampling import Conv2dSubsampling8
  26. from funasr.modules.subsampling import TooShortUttError
  27. from funasr.modules.subsampling import check_short_utt
  28. from funasr.models.ctc import CTC
  29. from funasr.models.encoder.abs_encoder import AbsEncoder
  30. class EncoderLayerSANM(nn.Module):
  31. def __init__(
  32. self,
  33. in_size,
  34. size,
  35. self_attn,
  36. feed_forward,
  37. dropout_rate,
  38. normalize_before=True,
  39. concat_after=False,
  40. stochastic_depth_rate=0.0,
  41. ):
  42. """Construct an EncoderLayer object."""
  43. super(EncoderLayerSANM, self).__init__()
  44. self.self_attn = self_attn
  45. self.feed_forward = feed_forward
  46. self.norm1 = LayerNorm(in_size)
  47. self.norm2 = LayerNorm(size)
  48. self.dropout = nn.Dropout(dropout_rate)
  49. self.in_size = in_size
  50. self.size = size
  51. self.normalize_before = normalize_before
  52. self.concat_after = concat_after
  53. if self.concat_after:
  54. self.concat_linear = nn.Linear(size + size, size)
  55. self.stochastic_depth_rate = stochastic_depth_rate
  56. self.dropout_rate = dropout_rate
  57. def forward(self, x, mask, cache=None, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
  58. """Compute encoded features.
  59. Args:
  60. x_input (torch.Tensor): Input tensor (#batch, time, size).
  61. mask (torch.Tensor): Mask tensor for the input (#batch, time).
  62. cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
  63. Returns:
  64. torch.Tensor: Output tensor (#batch, time, size).
  65. torch.Tensor: Mask tensor (#batch, time).
  66. """
  67. skip_layer = False
  68. # with stochastic depth, residual connection `x + f(x)` becomes
  69. # `x <- x + 1 / (1 - p) * f(x)` at training time.
  70. stoch_layer_coeff = 1.0
  71. if self.training and self.stochastic_depth_rate > 0:
  72. skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
  73. stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
  74. if skip_layer:
  75. if cache is not None:
  76. x = torch.cat([cache, x], dim=1)
  77. return x, mask
  78. residual = x
  79. if self.normalize_before:
  80. x = self.norm1(x)
  81. if self.concat_after:
  82. x_concat = torch.cat((x, self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)), dim=-1)
  83. if self.in_size == self.size:
  84. x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
  85. else:
  86. x = stoch_layer_coeff * self.concat_linear(x_concat)
  87. else:
  88. if self.in_size == self.size:
  89. x = residual + stoch_layer_coeff * self.dropout(
  90. self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
  91. )
  92. else:
  93. x = stoch_layer_coeff * self.dropout(
  94. self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
  95. )
  96. if not self.normalize_before:
  97. x = self.norm1(x)
  98. residual = x
  99. if self.normalize_before:
  100. x = self.norm2(x)
  101. x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
  102. if not self.normalize_before:
  103. x = self.norm2(x)
  104. return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder
  105. class SANMEncoder(AbsEncoder):
  106. """
  107. author: Speech Lab, Alibaba Group, China
  108. San-m: Memory equipped self-attention for end-to-end speech recognition
  109. https://arxiv.org/abs/2006.01713
  110. """
  111. def __init__(
  112. self,
  113. input_size: int,
  114. output_size: int = 256,
  115. attention_heads: int = 4,
  116. linear_units: int = 2048,
  117. num_blocks: int = 6,
  118. dropout_rate: float = 0.1,
  119. positional_dropout_rate: float = 0.1,
  120. attention_dropout_rate: float = 0.0,
  121. input_layer: Optional[str] = "conv2d",
  122. pos_enc_class=SinusoidalPositionEncoder,
  123. normalize_before: bool = True,
  124. concat_after: bool = False,
  125. positionwise_layer_type: str = "linear",
  126. positionwise_conv_kernel_size: int = 1,
  127. padding_idx: int = -1,
  128. interctc_layer_idx: List[int] = [],
  129. interctc_use_conditioning: bool = False,
  130. kernel_size : int = 11,
  131. sanm_shfit : int = 0,
  132. selfattention_layer_type: str = "sanm",
  133. tf2torch_tensor_name_prefix_torch: str = "encoder",
  134. tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
  135. ):
  136. assert check_argument_types()
  137. super().__init__()
  138. self._output_size = output_size
  139. if input_layer == "linear":
  140. self.embed = torch.nn.Sequential(
  141. torch.nn.Linear(input_size, output_size),
  142. torch.nn.LayerNorm(output_size),
  143. torch.nn.Dropout(dropout_rate),
  144. torch.nn.ReLU(),
  145. pos_enc_class(output_size, positional_dropout_rate),
  146. )
  147. elif input_layer == "conv2d":
  148. self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
  149. elif input_layer == "conv2d2":
  150. self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
  151. elif input_layer == "conv2d6":
  152. self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
  153. elif input_layer == "conv2d8":
  154. self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
  155. elif input_layer == "embed":
  156. self.embed = torch.nn.Sequential(
  157. torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
  158. SinusoidalPositionEncoder(),
  159. )
  160. elif input_layer is None:
  161. if input_size == output_size:
  162. self.embed = None
  163. else:
  164. self.embed = torch.nn.Linear(input_size, output_size)
  165. elif input_layer == "pe":
  166. self.embed = SinusoidalPositionEncoder()
  167. else:
  168. raise ValueError("unknown input_layer: " + input_layer)
  169. self.normalize_before = normalize_before
  170. if positionwise_layer_type == "linear":
  171. positionwise_layer = PositionwiseFeedForward
  172. positionwise_layer_args = (
  173. output_size,
  174. linear_units,
  175. dropout_rate,
  176. )
  177. elif positionwise_layer_type == "conv1d":
  178. positionwise_layer = MultiLayeredConv1d
  179. positionwise_layer_args = (
  180. output_size,
  181. linear_units,
  182. positionwise_conv_kernel_size,
  183. dropout_rate,
  184. )
  185. elif positionwise_layer_type == "conv1d-linear":
  186. positionwise_layer = Conv1dLinear
  187. positionwise_layer_args = (
  188. output_size,
  189. linear_units,
  190. positionwise_conv_kernel_size,
  191. dropout_rate,
  192. )
  193. else:
  194. raise NotImplementedError("Support only linear or conv1d.")
  195. if selfattention_layer_type == "selfattn":
  196. encoder_selfattn_layer = MultiHeadedAttention
  197. encoder_selfattn_layer_args = (
  198. attention_heads,
  199. output_size,
  200. attention_dropout_rate,
  201. )
  202. elif selfattention_layer_type == "sanm":
  203. encoder_selfattn_layer = MultiHeadedAttentionSANM
  204. encoder_selfattn_layer_args0 = (
  205. attention_heads,
  206. input_size,
  207. output_size,
  208. attention_dropout_rate,
  209. kernel_size,
  210. sanm_shfit,
  211. )
  212. encoder_selfattn_layer_args = (
  213. attention_heads,
  214. output_size,
  215. output_size,
  216. attention_dropout_rate,
  217. kernel_size,
  218. sanm_shfit,
  219. )
  220. self.encoders0 = repeat(
  221. 1,
  222. lambda lnum: EncoderLayerSANM(
  223. input_size,
  224. output_size,
  225. encoder_selfattn_layer(*encoder_selfattn_layer_args0),
  226. positionwise_layer(*positionwise_layer_args),
  227. dropout_rate,
  228. normalize_before,
  229. concat_after,
  230. ),
  231. )
  232. self.encoders = repeat(
  233. num_blocks-1,
  234. lambda lnum: EncoderLayerSANM(
  235. output_size,
  236. output_size,
  237. encoder_selfattn_layer(*encoder_selfattn_layer_args),
  238. positionwise_layer(*positionwise_layer_args),
  239. dropout_rate,
  240. normalize_before,
  241. concat_after,
  242. ),
  243. )
  244. if self.normalize_before:
  245. self.after_norm = LayerNorm(output_size)
  246. self.interctc_layer_idx = interctc_layer_idx
  247. if len(interctc_layer_idx) > 0:
  248. assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
  249. self.interctc_use_conditioning = interctc_use_conditioning
  250. self.conditioning_layer = None
  251. self.dropout = nn.Dropout(dropout_rate)
  252. self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
  253. self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
  254. def output_size(self) -> int:
  255. return self._output_size
  256. def forward(
  257. self,
  258. xs_pad: torch.Tensor,
  259. ilens: torch.Tensor,
  260. prev_states: torch.Tensor = None,
  261. ctc: CTC = None,
  262. ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
  263. """Embed positions in tensor.
  264. Args:
  265. xs_pad: input tensor (B, L, D)
  266. ilens: input length (B)
  267. prev_states: Not to be used now.
  268. Returns:
  269. position embedded tensor and mask
  270. """
  271. masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
  272. xs_pad *= self.output_size()**0.5
  273. if self.embed is None:
  274. xs_pad = xs_pad
  275. elif (
  276. isinstance(self.embed, Conv2dSubsampling)
  277. or isinstance(self.embed, Conv2dSubsampling2)
  278. or isinstance(self.embed, Conv2dSubsampling6)
  279. or isinstance(self.embed, Conv2dSubsampling8)
  280. ):
  281. short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
  282. if short_status:
  283. raise TooShortUttError(
  284. f"has {xs_pad.size(1)} frames and is too short for subsampling "
  285. + f"(it needs more than {limit_size} frames), return empty results",
  286. xs_pad.size(1),
  287. limit_size,
  288. )
  289. xs_pad, masks = self.embed(xs_pad, masks)
  290. else:
  291. xs_pad = self.embed(xs_pad)
  292. # xs_pad = self.dropout(xs_pad)
  293. encoder_outs = self.encoders0(xs_pad, masks)
  294. xs_pad, masks = encoder_outs[0], encoder_outs[1]
  295. intermediate_outs = []
  296. if len(self.interctc_layer_idx) == 0:
  297. encoder_outs = self.encoders(xs_pad, masks)
  298. xs_pad, masks = encoder_outs[0], encoder_outs[1]
  299. else:
  300. for layer_idx, encoder_layer in enumerate(self.encoders):
  301. encoder_outs = encoder_layer(xs_pad, masks)
  302. xs_pad, masks = encoder_outs[0], encoder_outs[1]
  303. if layer_idx + 1 in self.interctc_layer_idx:
  304. encoder_out = xs_pad
  305. # intermediate outputs are also normalized
  306. if self.normalize_before:
  307. encoder_out = self.after_norm(encoder_out)
  308. intermediate_outs.append((layer_idx + 1, encoder_out))
  309. if self.interctc_use_conditioning:
  310. ctc_out = ctc.softmax(encoder_out)
  311. xs_pad = xs_pad + self.conditioning_layer(ctc_out)
  312. if self.normalize_before:
  313. xs_pad = self.after_norm(xs_pad)
  314. olens = masks.squeeze(1).sum(1)
  315. if len(intermediate_outs) > 0:
  316. return (xs_pad, intermediate_outs), olens, None
  317. return xs_pad, olens, None
  318. def gen_tf2torch_map_dict(self):
  319. tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
  320. tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
  321. map_dict_local = {
  322. ## encoder
  323. # cicd
  324. "{}.encoders.layeridx.norm1.weight".format(tensor_name_prefix_torch):
  325. {"name": "{}/layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf),
  326. "squeeze": None,
  327. "transpose": None,
  328. }, # (256,),(256,)
  329. "{}.encoders.layeridx.norm1.bias".format(tensor_name_prefix_torch):
  330. {"name": "{}/layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf),
  331. "squeeze": None,
  332. "transpose": None,
  333. }, # (256,),(256,)
  334. "{}.encoders.layeridx.self_attn.linear_q_k_v.weight".format(tensor_name_prefix_torch):
  335. {"name": "{}/layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf),
  336. "squeeze": 0,
  337. "transpose": (1, 0),
  338. }, # (768,256),(1,256,768)
  339. "{}.encoders.layeridx.self_attn.linear_q_k_v.bias".format(tensor_name_prefix_torch):
  340. {"name": "{}/layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf),
  341. "squeeze": None,
  342. "transpose": None,
  343. }, # (768,),(768,)
  344. "{}.encoders.layeridx.self_attn.fsmn_block.weight".format(tensor_name_prefix_torch):
  345. {"name": "{}/layer_layeridx/multi_head/depth_conv_w".format(tensor_name_prefix_tf),
  346. "squeeze": 0,
  347. "transpose": (1, 2, 0),
  348. }, # (256,1,31),(1,31,256,1)
  349. "{}.encoders.layeridx.self_attn.linear_out.weight".format(tensor_name_prefix_torch):
  350. {"name": "{}/layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf),
  351. "squeeze": 0,
  352. "transpose": (1, 0),
  353. }, # (256,256),(1,256,256)
  354. "{}.encoders.layeridx.self_attn.linear_out.bias".format(tensor_name_prefix_torch):
  355. {"name": "{}/layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf),
  356. "squeeze": None,
  357. "transpose": None,
  358. }, # (256,),(256,)
  359. # ffn
  360. "{}.encoders.layeridx.norm2.weight".format(tensor_name_prefix_torch):
  361. {"name": "{}/layer_layeridx/ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
  362. "squeeze": None,
  363. "transpose": None,
  364. }, # (256,),(256,)
  365. "{}.encoders.layeridx.norm2.bias".format(tensor_name_prefix_torch):
  366. {"name": "{}/layer_layeridx/ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
  367. "squeeze": None,
  368. "transpose": None,
  369. }, # (256,),(256,)
  370. "{}.encoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
  371. {"name": "{}/layer_layeridx/ffn/conv1d/kernel".format(tensor_name_prefix_tf),
  372. "squeeze": 0,
  373. "transpose": (1, 0),
  374. }, # (1024,256),(1,256,1024)
  375. "{}.encoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
  376. {"name": "{}/layer_layeridx/ffn/conv1d/bias".format(tensor_name_prefix_tf),
  377. "squeeze": None,
  378. "transpose": None,
  379. }, # (1024,),(1024,)
  380. "{}.encoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
  381. {"name": "{}/layer_layeridx/ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
  382. "squeeze": 0,
  383. "transpose": (1, 0),
  384. }, # (256,1024),(1,1024,256)
  385. "{}.encoders.layeridx.feed_forward.w_2.bias".format(tensor_name_prefix_torch):
  386. {"name": "{}/layer_layeridx/ffn/conv1d_1/bias".format(tensor_name_prefix_tf),
  387. "squeeze": None,
  388. "transpose": None,
  389. }, # (256,),(256,)
  390. # out norm
  391. "{}.after_norm.weight".format(tensor_name_prefix_torch):
  392. {"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf),
  393. "squeeze": None,
  394. "transpose": None,
  395. }, # (256,),(256,)
  396. "{}.after_norm.bias".format(tensor_name_prefix_torch):
  397. {"name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf),
  398. "squeeze": None,
  399. "transpose": None,
  400. }, # (256,),(256,)
  401. }
  402. return map_dict_local
  403. def convert_tf2torch(self,
  404. var_dict_tf,
  405. var_dict_torch,
  406. ):
  407. map_dict = self.gen_tf2torch_map_dict()
  408. var_dict_torch_update = dict()
  409. for name in sorted(var_dict_torch.keys(), reverse=False):
  410. names = name.split('.')
  411. if names[0] == self.tf2torch_tensor_name_prefix_torch:
  412. if names[1] == "encoders0":
  413. layeridx = int(names[2])
  414. name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
  415. name_q = name_q.replace("encoders0", "encoders")
  416. layeridx_bias = 0
  417. layeridx += layeridx_bias
  418. if name_q in map_dict.keys():
  419. name_v = map_dict[name_q]["name"]
  420. name_tf = name_v.replace("layeridx", "{}".format(layeridx))
  421. data_tf = var_dict_tf[name_tf]
  422. if map_dict[name_q]["squeeze"] is not None:
  423. data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
  424. if map_dict[name_q]["transpose"] is not None:
  425. data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
  426. data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
  427. assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
  428. var_dict_torch[
  429. name].size(),
  430. data_tf.size())
  431. var_dict_torch_update[name] = data_tf
  432. logging.info(
  433. "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
  434. var_dict_tf[name_tf].shape))
  435. elif names[1] == "encoders":
  436. layeridx = int(names[2])
  437. name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
  438. layeridx_bias = 1
  439. layeridx += layeridx_bias
  440. if name_q in map_dict.keys():
  441. name_v = map_dict[name_q]["name"]
  442. name_tf = name_v.replace("layeridx", "{}".format(layeridx))
  443. data_tf = var_dict_tf[name_tf]
  444. if map_dict[name_q]["squeeze"] is not None:
  445. data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
  446. if map_dict[name_q]["transpose"] is not None:
  447. data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
  448. data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
  449. assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
  450. var_dict_torch[
  451. name].size(),
  452. data_tf.size())
  453. var_dict_torch_update[name] = data_tf
  454. logging.info(
  455. "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
  456. var_dict_tf[name_tf].shape))
  457. elif names[1] == "after_norm":
  458. name_tf = map_dict[name]["name"]
  459. data_tf = var_dict_tf[name_tf]
  460. data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
  461. var_dict_torch_update[name] = data_tf
  462. logging.info(
  463. "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
  464. var_dict_tf[name_tf].shape))
  465. return var_dict_torch_update
  466. class SANMEncoderChunkOpt(AbsEncoder):
  467. """
  468. author: Speech Lab, Alibaba Group, China
  469. SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
  470. https://arxiv.org/abs/2006.01713
  471. """
  472. def __init__(
  473. self,
  474. input_size: int,
  475. output_size: int = 256,
  476. attention_heads: int = 4,
  477. linear_units: int = 2048,
  478. num_blocks: int = 6,
  479. dropout_rate: float = 0.1,
  480. positional_dropout_rate: float = 0.1,
  481. attention_dropout_rate: float = 0.0,
  482. input_layer: Optional[str] = "conv2d",
  483. pos_enc_class=SinusoidalPositionEncoder,
  484. normalize_before: bool = True,
  485. concat_after: bool = False,
  486. positionwise_layer_type: str = "linear",
  487. positionwise_conv_kernel_size: int = 1,
  488. padding_idx: int = -1,
  489. interctc_layer_idx: List[int] = [],
  490. interctc_use_conditioning: bool = False,
  491. kernel_size: int = 11,
  492. sanm_shfit: int = 0,
  493. selfattention_layer_type: str = "sanm",
  494. chunk_size: Union[int, Sequence[int]] = (16,),
  495. stride: Union[int, Sequence[int]] = (10,),
  496. pad_left: Union[int, Sequence[int]] = (0,),
  497. encoder_att_look_back_factor: Union[int, Sequence[int]] = (1,),
  498. decoder_att_look_back_factor: Union[int, Sequence[int]] = (1,),
  499. tf2torch_tensor_name_prefix_torch: str = "encoder",
  500. tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
  501. ):
  502. assert check_argument_types()
  503. super().__init__()
  504. self._output_size = output_size
  505. if input_layer == "linear":
  506. self.embed = torch.nn.Sequential(
  507. torch.nn.Linear(input_size, output_size),
  508. torch.nn.LayerNorm(output_size),
  509. torch.nn.Dropout(dropout_rate),
  510. torch.nn.ReLU(),
  511. pos_enc_class(output_size, positional_dropout_rate),
  512. )
  513. elif input_layer == "conv2d":
  514. self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
  515. elif input_layer == "conv2d2":
  516. self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
  517. elif input_layer == "conv2d6":
  518. self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
  519. elif input_layer == "conv2d8":
  520. self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
  521. elif input_layer == "embed":
  522. self.embed = torch.nn.Sequential(
  523. torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
  524. pos_enc_class(output_size, positional_dropout_rate),
  525. )
  526. elif input_layer is None:
  527. if input_size == output_size:
  528. self.embed = None
  529. else:
  530. self.embed = torch.nn.Linear(input_size, output_size)
  531. elif input_layer == "pe":
  532. self.embed = SinusoidalPositionEncoder()
  533. else:
  534. raise ValueError("unknown input_layer: " + input_layer)
  535. self.normalize_before = normalize_before
  536. if positionwise_layer_type == "linear":
  537. positionwise_layer = PositionwiseFeedForward
  538. positionwise_layer_args = (
  539. output_size,
  540. linear_units,
  541. dropout_rate,
  542. )
  543. elif positionwise_layer_type == "conv1d":
  544. positionwise_layer = MultiLayeredConv1d
  545. positionwise_layer_args = (
  546. output_size,
  547. linear_units,
  548. positionwise_conv_kernel_size,
  549. dropout_rate,
  550. )
  551. elif positionwise_layer_type == "conv1d-linear":
  552. positionwise_layer = Conv1dLinear
  553. positionwise_layer_args = (
  554. output_size,
  555. linear_units,
  556. positionwise_conv_kernel_size,
  557. dropout_rate,
  558. )
  559. else:
  560. raise NotImplementedError("Support only linear or conv1d.")
  561. if selfattention_layer_type == "selfattn":
  562. encoder_selfattn_layer = MultiHeadedAttention
  563. encoder_selfattn_layer_args = (
  564. attention_heads,
  565. output_size,
  566. attention_dropout_rate,
  567. )
  568. elif selfattention_layer_type == "sanm":
  569. encoder_selfattn_layer = MultiHeadedAttentionSANM
  570. encoder_selfattn_layer_args0 = (
  571. attention_heads,
  572. input_size,
  573. output_size,
  574. attention_dropout_rate,
  575. kernel_size,
  576. sanm_shfit,
  577. )
  578. encoder_selfattn_layer_args = (
  579. attention_heads,
  580. output_size,
  581. output_size,
  582. attention_dropout_rate,
  583. kernel_size,
  584. sanm_shfit,
  585. )
  586. self.encoders0 = repeat(
  587. 1,
  588. lambda lnum: EncoderLayerSANM(
  589. input_size,
  590. output_size,
  591. encoder_selfattn_layer(*encoder_selfattn_layer_args0),
  592. positionwise_layer(*positionwise_layer_args),
  593. dropout_rate,
  594. normalize_before,
  595. concat_after,
  596. ),
  597. )
  598. self.encoders = repeat(
  599. num_blocks - 1,
  600. lambda lnum: EncoderLayerSANM(
  601. output_size,
  602. output_size,
  603. encoder_selfattn_layer(*encoder_selfattn_layer_args),
  604. positionwise_layer(*positionwise_layer_args),
  605. dropout_rate,
  606. normalize_before,
  607. concat_after,
  608. ),
  609. )
  610. if self.normalize_before:
  611. self.after_norm = LayerNorm(output_size)
  612. self.interctc_layer_idx = interctc_layer_idx
  613. if len(interctc_layer_idx) > 0:
  614. assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
  615. self.interctc_use_conditioning = interctc_use_conditioning
  616. self.conditioning_layer = None
  617. shfit_fsmn = (kernel_size - 1) // 2
  618. self.overlap_chunk_cls = overlap_chunk(
  619. chunk_size=chunk_size,
  620. stride=stride,
  621. pad_left=pad_left,
  622. shfit_fsmn=shfit_fsmn,
  623. encoder_att_look_back_factor=encoder_att_look_back_factor,
  624. decoder_att_look_back_factor=decoder_att_look_back_factor,
  625. )
  626. self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
  627. self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
  628. def output_size(self) -> int:
  629. return self._output_size
  630. def forward(
  631. self,
  632. xs_pad: torch.Tensor,
  633. ilens: torch.Tensor,
  634. prev_states: torch.Tensor = None,
  635. ctc: CTC = None,
  636. ind: int = 0,
  637. ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
  638. """Embed positions in tensor.
  639. Args:
  640. xs_pad: input tensor (B, L, D)
  641. ilens: input length (B)
  642. prev_states: Not to be used now.
  643. Returns:
  644. position embedded tensor and mask
  645. """
  646. masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
  647. xs_pad *= self.output_size() ** 0.5
  648. if self.embed is None:
  649. xs_pad = xs_pad
  650. elif (
  651. isinstance(self.embed, Conv2dSubsampling)
  652. or isinstance(self.embed, Conv2dSubsampling2)
  653. or isinstance(self.embed, Conv2dSubsampling6)
  654. or isinstance(self.embed, Conv2dSubsampling8)
  655. ):
  656. short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
  657. if short_status:
  658. raise TooShortUttError(
  659. f"has {xs_pad.size(1)} frames and is too short for subsampling "
  660. + f"(it needs more than {limit_size} frames), return empty results",
  661. xs_pad.size(1),
  662. limit_size,
  663. )
  664. xs_pad, masks = self.embed(xs_pad, masks)
  665. else:
  666. xs_pad = self.embed(xs_pad)
  667. mask_shfit_chunk, mask_att_chunk_encoder = None, None
  668. if self.overlap_chunk_cls is not None:
  669. ilens = masks.squeeze(1).sum(1)
  670. chunk_outs = self.overlap_chunk_cls.gen_chunk_mask(ilens, ind)
  671. xs_pad, ilens = self.overlap_chunk_cls.split_chunk(xs_pad, ilens, chunk_outs=chunk_outs)
  672. masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
  673. mask_shfit_chunk = self.overlap_chunk_cls.get_mask_shfit_chunk(chunk_outs, xs_pad.device, xs_pad.size(0),
  674. dtype=xs_pad.dtype)
  675. mask_att_chunk_encoder = self.overlap_chunk_cls.get_mask_att_chunk_encoder(chunk_outs, xs_pad.device,
  676. xs_pad.size(0),
  677. dtype=xs_pad.dtype)
  678. encoder_outs = self.encoders0(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
  679. xs_pad, masks = encoder_outs[0], encoder_outs[1]
  680. intermediate_outs = []
  681. if len(self.interctc_layer_idx) == 0:
  682. encoder_outs = self.encoders(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
  683. xs_pad, masks = encoder_outs[0], encoder_outs[1]
  684. else:
  685. for layer_idx, encoder_layer in enumerate(self.encoders):
  686. encoder_outs = encoder_layer(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
  687. xs_pad, masks = encoder_outs[0], encoder_outs[1]
  688. if layer_idx + 1 in self.interctc_layer_idx:
  689. encoder_out = xs_pad
  690. # intermediate outputs are also normalized
  691. if self.normalize_before:
  692. encoder_out = self.after_norm(encoder_out)
  693. intermediate_outs.append((layer_idx + 1, encoder_out))
  694. if self.interctc_use_conditioning:
  695. ctc_out = ctc.softmax(encoder_out)
  696. xs_pad = xs_pad + self.conditioning_layer(ctc_out)
  697. if self.normalize_before:
  698. xs_pad = self.after_norm(xs_pad)
  699. olens = masks.squeeze(1).sum(1)
  700. if len(intermediate_outs) > 0:
  701. return (xs_pad, intermediate_outs), olens, None
  702. return xs_pad, olens, None
  def gen_tf2torch_map_dict(self):
      """Build the torch-name -> TF-variable conversion table for this encoder.

      Keys are torch parameter names in which the layer number is replaced by
      the literal placeholder ``layeridx``.  Each value is a dict with:

      * ``"name"``: the TF variable name (per-layer entries carry the same
        ``layeridx`` placeholder),
      * ``"squeeze"``: axis to squeeze out of the TF array, or ``None``,
      * ``"transpose"``: axes permutation applied after squeezing, or ``None``.

      Returns:
          dict: the mapping described above, prefixed with
          ``self.tf2torch_tensor_name_prefix_torch`` (keys) and
          ``self.tf2torch_tensor_name_prefix_tf`` (TF names).
      """
      torch_prefix = self.tf2torch_tensor_name_prefix_torch
      tf_prefix = self.tf2torch_tensor_name_prefix_tf

      # One row per per-layer parameter:
      #   (torch suffix, tf suffix, squeeze axis, transpose axes)
      # The trailing comment on each row is a shape hint: (torch shape),(tf shape).
      per_layer_rows = (
          ## encoder
          # cicd
          ("norm1.weight", "multi_head/LayerNorm/gamma", None, None),  # (256,),(256,)
          ("norm1.bias", "multi_head/LayerNorm/beta", None, None),  # (256,),(256,)
          ("self_attn.linear_q_k_v.weight", "multi_head/conv1d/kernel", 0, (1, 0)),  # (768,256),(1,256,768)
          ("self_attn.linear_q_k_v.bias", "multi_head/conv1d/bias", None, None),  # (768,),(768,)
          ("self_attn.fsmn_block.weight", "multi_head/depth_conv_w", 0, (1, 2, 0)),  # (256,1,31),(1,31,256,1)
          ("self_attn.linear_out.weight", "multi_head/conv1d_1/kernel", 0, (1, 0)),  # (256,256),(1,256,256)
          ("self_attn.linear_out.bias", "multi_head/conv1d_1/bias", None, None),  # (256,),(256,)
          # ffn
          ("norm2.weight", "ffn/LayerNorm/gamma", None, None),  # (256,),(256,)
          ("norm2.bias", "ffn/LayerNorm/beta", None, None),  # (256,),(256,)
          ("feed_forward.w_1.weight", "ffn/conv1d/kernel", 0, (1, 0)),  # (1024,256),(1,256,1024)
          ("feed_forward.w_1.bias", "ffn/conv1d/bias", None, None),  # (1024,),(1024,)
          ("feed_forward.w_2.weight", "ffn/conv1d_1/kernel", 0, (1, 0)),  # (256,1024),(1,1024,256)
          ("feed_forward.w_2.bias", "ffn/conv1d_1/bias", None, None),  # (256,),(256,)
      )
      # out norm: not tied to a layer, copied without squeeze/transpose.
      out_norm_rows = (
          ("after_norm.weight", "LayerNorm/gamma"),  # (256,),(256,)
          ("after_norm.bias", "LayerNorm/beta"),  # (256,),(256,)
      )

      map_dict_local = {}
      for torch_suffix, tf_suffix, squeeze_axis, transpose_axes in per_layer_rows:
          torch_key = "{}.encoders.layeridx.{}".format(torch_prefix, torch_suffix)
          map_dict_local[torch_key] = {
              "name": "{}/layer_layeridx/{}".format(tf_prefix, tf_suffix),
              "squeeze": squeeze_axis,
              "transpose": transpose_axes,
          }
      for torch_suffix, tf_suffix in out_norm_rows:
          torch_key = "{}.{}".format(torch_prefix, torch_suffix)
          map_dict_local[torch_key] = {
              "name": "{}/{}".format(tf_prefix, tf_suffix),
              "squeeze": None,
              "transpose": None,
          }
      return map_dict_local
  def convert_tf2torch(self,
                       var_dict_tf,
                       var_dict_torch,
                       ):
      """Convert TensorFlow encoder variables into tensors keyed by torch names.

      Only torch names whose first dotted component equals
      ``self.tf2torch_tensor_name_prefix_torch`` are considered.  Within those:

      * ``<prefix>.encoders0.<i>.*`` is loaded from TF ``layer_<i>``,
      * ``<prefix>.encoders.<i>.*`` is loaded from TF ``layer_<i + 1>``,
      * ``<prefix>.after_norm.*`` is loaded from its fixed TF variable.

      Per-layer names that have no entry in the table returned by
      ``gen_tf2torch_map_dict`` are skipped.  The table's ``squeeze`` and
      ``transpose`` fields are applied (in that order) before the array is
      turned into a float32 CPU tensor.

      Args:
          var_dict_tf: mapping from TF variable name to a numpy array.
          var_dict_torch: mapping from torch parameter name to a tensor; only
              its keys and each tensor's ``size()`` are used.

      Returns:
          dict: torch parameter name -> converted ``torch.float32`` CPU tensor,
          for every parameter that was matched.

      Raises:
          KeyError: if a required TF variable is absent from ``var_dict_tf``,
              or an ``after_norm`` parameter has no entry in the mapping table.
          AssertionError: if a converted tensor's shape differs from the
              corresponding torch tensor's shape.
      """
      map_dict = self.gen_tf2torch_map_dict()
      # Offset added to the torch layer number to obtain the TF layer number.
      layer_offsets = {"encoders0": 0, "encoders": 1}
      var_dict_torch_update = dict()
      for name in sorted(var_dict_torch.keys(), reverse=False):
          names = name.split('.')
          if names[0] != self.tf2torch_tensor_name_prefix_torch:
              continue
          module = names[1]
          if module in layer_offsets:
              layeridx = int(names[2])
              # The mapping table is keyed with a "layeridx" placeholder and
              # always uses the "encoders" module name.
              name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
              if module == "encoders0":
                  name_q = name_q.replace("encoders0", "encoders")
              if name_q not in map_dict:
                  continue
              rule = map_dict[name_q]
              name_tf = rule["name"].replace(
                  "layeridx", "{}".format(layeridx + layer_offsets[module]))
          elif module == "after_norm":
              rule = map_dict[name]
              name_tf = rule["name"]
          else:
              continue

          data_tf = var_dict_tf[name_tf]
          if rule["squeeze"] is not None:
              data_tf = np.squeeze(data_tf, axis=rule["squeeze"])
          if rule["transpose"] is not None:
              data_tf = np.transpose(data_tf, rule["transpose"])
          data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
          # Checked for every parameter (including after_norm) so a layout
          # mismatch is reported here with both names rather than at load time.
          assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(
              name, name_tf, var_dict_torch[name].size(), data_tf.size())
          var_dict_torch_update[name] = data_tf
          # Log the resolved TF variable name, not the "layeridx" template.
          logging.info(
              "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                  name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape))
      return var_dict_torch_update