# sanm_encoder.py
  1. from typing import List
  2. from typing import Optional
  3. from typing import Sequence
  4. from typing import Tuple
  5. from typing import Union
  6. import logging
  7. import torch
  8. import torch.nn as nn
  9. from funasr.modules.streaming_utils.chunk_utilis import overlap_chunk
  10. from typeguard import check_argument_types
  11. import numpy as np
  12. from funasr.modules.nets_utils import make_pad_mask
  13. from funasr.modules.attention import MultiHeadedAttention, MultiHeadedAttentionSANM
  14. from funasr.modules.embedding import SinusoidalPositionEncoder
  15. from funasr.modules.layer_norm import LayerNorm
  16. from funasr.modules.multi_layer_conv import Conv1dLinear
  17. from funasr.modules.multi_layer_conv import MultiLayeredConv1d
  18. from funasr.modules.positionwise_feed_forward import (
  19. PositionwiseFeedForward, # noqa: H301
  20. )
  21. from funasr.modules.repeat import repeat
  22. from funasr.modules.subsampling import Conv2dSubsampling
  23. from funasr.modules.subsampling import Conv2dSubsampling2
  24. from funasr.modules.subsampling import Conv2dSubsampling6
  25. from funasr.modules.subsampling import Conv2dSubsampling8
  26. from funasr.modules.subsampling import TooShortUttError
  27. from funasr.modules.subsampling import check_short_utt
  28. from funasr.models.ctc import CTC
  29. from funasr.models.encoder.abs_encoder import AbsEncoder
  30. class EncoderLayerSANM(nn.Module):
  31. def __init__(
  32. self,
  33. in_size,
  34. size,
  35. self_attn,
  36. feed_forward,
  37. dropout_rate,
  38. normalize_before=True,
  39. concat_after=False,
  40. stochastic_depth_rate=0.0,
  41. ):
  42. """Construct an EncoderLayer object."""
  43. super(EncoderLayerSANM, self).__init__()
  44. self.self_attn = self_attn
  45. self.feed_forward = feed_forward
  46. self.norm1 = LayerNorm(in_size)
  47. self.norm2 = LayerNorm(size)
  48. self.dropout = nn.Dropout(dropout_rate)
  49. self.in_size = in_size
  50. self.size = size
  51. self.normalize_before = normalize_before
  52. self.concat_after = concat_after
  53. if self.concat_after:
  54. self.concat_linear = nn.Linear(size + size, size)
  55. self.stochastic_depth_rate = stochastic_depth_rate
  56. self.dropout_rate = dropout_rate
  57. def forward(self, x, mask, cache=None, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
  58. """Compute encoded features.
  59. Args:
  60. x_input (torch.Tensor): Input tensor (#batch, time, size).
  61. mask (torch.Tensor): Mask tensor for the input (#batch, time).
  62. cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
  63. Returns:
  64. torch.Tensor: Output tensor (#batch, time, size).
  65. torch.Tensor: Mask tensor (#batch, time).
  66. """
  67. skip_layer = False
  68. # with stochastic depth, residual connection `x + f(x)` becomes
  69. # `x <- x + 1 / (1 - p) * f(x)` at training time.
  70. stoch_layer_coeff = 1.0
  71. if self.training and self.stochastic_depth_rate > 0:
  72. skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
  73. stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
  74. if skip_layer:
  75. if cache is not None:
  76. x = torch.cat([cache, x], dim=1)
  77. return x, mask
  78. residual = x
  79. if self.normalize_before:
  80. x = self.norm1(x)
  81. if self.concat_after:
  82. x_concat = torch.cat((x, self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)), dim=-1)
  83. if self.in_size == self.size:
  84. x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
  85. else:
  86. x = stoch_layer_coeff * self.concat_linear(x_concat)
  87. else:
  88. if self.in_size == self.size:
  89. x = residual + stoch_layer_coeff * self.dropout(
  90. self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
  91. )
  92. else:
  93. x = stoch_layer_coeff * self.dropout(
  94. self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
  95. )
  96. if not self.normalize_before:
  97. x = self.norm1(x)
  98. residual = x
  99. if self.normalize_before:
  100. x = self.norm2(x)
  101. x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
  102. if not self.normalize_before:
  103. x = self.norm2(x)
  104. return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder
  105. class SANMEncoder(AbsEncoder):
  106. """
  107. author: Speech Lab, Alibaba Group, China
  108. San-m: Memory equipped self-attention for end-to-end speech recognition
  109. https://arxiv.org/abs/2006.01713
  110. """
  111. def __init__(
  112. self,
  113. input_size: int,
  114. output_size: int = 256,
  115. attention_heads: int = 4,
  116. linear_units: int = 2048,
  117. num_blocks: int = 6,
  118. dropout_rate: float = 0.1,
  119. positional_dropout_rate: float = 0.1,
  120. attention_dropout_rate: float = 0.0,
  121. input_layer: Optional[str] = "conv2d",
  122. pos_enc_class=SinusoidalPositionEncoder,
  123. normalize_before: bool = True,
  124. concat_after: bool = False,
  125. positionwise_layer_type: str = "linear",
  126. positionwise_conv_kernel_size: int = 1,
  127. padding_idx: int = -1,
  128. interctc_layer_idx: List[int] = [],
  129. interctc_use_conditioning: bool = False,
  130. kernel_size : int = 11,
  131. sanm_shfit : int = 0,
  132. selfattention_layer_type: str = "sanm",
  133. tf2torch_tensor_name_prefix_torch: str = "encoder",
  134. tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
  135. ):
  136. assert check_argument_types()
  137. super().__init__()
  138. self._output_size = output_size
  139. if input_layer == "linear":
  140. self.embed = torch.nn.Sequential(
  141. torch.nn.Linear(input_size, output_size),
  142. torch.nn.LayerNorm(output_size),
  143. torch.nn.Dropout(dropout_rate),
  144. torch.nn.ReLU(),
  145. pos_enc_class(output_size, positional_dropout_rate),
  146. )
  147. elif input_layer == "conv2d":
  148. self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
  149. elif input_layer == "conv2d2":
  150. self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
  151. elif input_layer == "conv2d6":
  152. self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
  153. elif input_layer == "conv2d8":
  154. self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
  155. elif input_layer == "embed":
  156. self.embed = torch.nn.Sequential(
  157. torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
  158. SinusoidalPositionEncoder(),
  159. )
  160. elif input_layer is None:
  161. if input_size == output_size:
  162. self.embed = None
  163. else:
  164. self.embed = torch.nn.Linear(input_size, output_size)
  165. elif input_layer == "pe":
  166. self.embed = SinusoidalPositionEncoder()
  167. else:
  168. raise ValueError("unknown input_layer: " + input_layer)
  169. self.normalize_before = normalize_before
  170. if positionwise_layer_type == "linear":
  171. positionwise_layer = PositionwiseFeedForward
  172. positionwise_layer_args = (
  173. output_size,
  174. linear_units,
  175. dropout_rate,
  176. )
  177. elif positionwise_layer_type == "conv1d":
  178. positionwise_layer = MultiLayeredConv1d
  179. positionwise_layer_args = (
  180. output_size,
  181. linear_units,
  182. positionwise_conv_kernel_size,
  183. dropout_rate,
  184. )
  185. elif positionwise_layer_type == "conv1d-linear":
  186. positionwise_layer = Conv1dLinear
  187. positionwise_layer_args = (
  188. output_size,
  189. linear_units,
  190. positionwise_conv_kernel_size,
  191. dropout_rate,
  192. )
  193. else:
  194. raise NotImplementedError("Support only linear or conv1d.")
  195. if selfattention_layer_type == "selfattn":
  196. encoder_selfattn_layer = MultiHeadedAttention
  197. encoder_selfattn_layer_args = (
  198. attention_heads,
  199. output_size,
  200. attention_dropout_rate,
  201. )
  202. elif selfattention_layer_type == "sanm":
  203. encoder_selfattn_layer = MultiHeadedAttentionSANM
  204. encoder_selfattn_layer_args0 = (
  205. attention_heads,
  206. input_size,
  207. output_size,
  208. attention_dropout_rate,
  209. kernel_size,
  210. sanm_shfit,
  211. )
  212. encoder_selfattn_layer_args = (
  213. attention_heads,
  214. output_size,
  215. output_size,
  216. attention_dropout_rate,
  217. kernel_size,
  218. sanm_shfit,
  219. )
  220. self.encoders0 = repeat(
  221. 1,
  222. lambda lnum: EncoderLayerSANM(
  223. input_size,
  224. output_size,
  225. encoder_selfattn_layer(*encoder_selfattn_layer_args0),
  226. positionwise_layer(*positionwise_layer_args),
  227. dropout_rate,
  228. normalize_before,
  229. concat_after,
  230. ),
  231. )
  232. self.encoders = repeat(
  233. num_blocks-1,
  234. lambda lnum: EncoderLayerSANM(
  235. output_size,
  236. output_size,
  237. encoder_selfattn_layer(*encoder_selfattn_layer_args),
  238. positionwise_layer(*positionwise_layer_args),
  239. dropout_rate,
  240. normalize_before,
  241. concat_after,
  242. ),
  243. )
  244. if self.normalize_before:
  245. self.after_norm = LayerNorm(output_size)
  246. self.interctc_layer_idx = interctc_layer_idx
  247. if len(interctc_layer_idx) > 0:
  248. assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
  249. self.interctc_use_conditioning = interctc_use_conditioning
  250. self.conditioning_layer = None
  251. self.dropout = nn.Dropout(dropout_rate)
  252. self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
  253. self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
  254. def output_size(self) -> int:
  255. return self._output_size
  256. def forward(
  257. self,
  258. xs_pad: torch.Tensor,
  259. ilens: torch.Tensor,
  260. prev_states: torch.Tensor = None,
  261. ctc: CTC = None,
  262. ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
  263. """Embed positions in tensor.
  264. Args:
  265. xs_pad: input tensor (B, L, D)
  266. ilens: input length (B)
  267. prev_states: Not to be used now.
  268. Returns:
  269. position embedded tensor and mask
  270. """
  271. masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
  272. xs_pad = xs_pad * self.output_size()**0.5
  273. if self.embed is None:
  274. xs_pad = xs_pad
  275. elif (
  276. isinstance(self.embed, Conv2dSubsampling)
  277. or isinstance(self.embed, Conv2dSubsampling2)
  278. or isinstance(self.embed, Conv2dSubsampling6)
  279. or isinstance(self.embed, Conv2dSubsampling8)
  280. ):
  281. short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
  282. if short_status:
  283. raise TooShortUttError(
  284. f"has {xs_pad.size(1)} frames and is too short for subsampling "
  285. + f"(it needs more than {limit_size} frames), return empty results",
  286. xs_pad.size(1),
  287. limit_size,
  288. )
  289. xs_pad, masks = self.embed(xs_pad, masks)
  290. else:
  291. xs_pad = self.embed(xs_pad)
  292. # xs_pad = self.dropout(xs_pad)
  293. encoder_outs = self.encoders0(xs_pad, masks)
  294. xs_pad, masks = encoder_outs[0], encoder_outs[1]
  295. intermediate_outs = []
  296. if len(self.interctc_layer_idx) == 0:
  297. encoder_outs = self.encoders(xs_pad, masks)
  298. xs_pad, masks = encoder_outs[0], encoder_outs[1]
  299. else:
  300. for layer_idx, encoder_layer in enumerate(self.encoders):
  301. encoder_outs = encoder_layer(xs_pad, masks)
  302. xs_pad, masks = encoder_outs[0], encoder_outs[1]
  303. if layer_idx + 1 in self.interctc_layer_idx:
  304. encoder_out = xs_pad
  305. # intermediate outputs are also normalized
  306. if self.normalize_before:
  307. encoder_out = self.after_norm(encoder_out)
  308. intermediate_outs.append((layer_idx + 1, encoder_out))
  309. if self.interctc_use_conditioning:
  310. ctc_out = ctc.softmax(encoder_out)
  311. xs_pad = xs_pad + self.conditioning_layer(ctc_out)
  312. if self.normalize_before:
  313. xs_pad = self.after_norm(xs_pad)
  314. olens = masks.squeeze(1).sum(1)
  315. if len(intermediate_outs) > 0:
  316. return (xs_pad, intermediate_outs), olens, None
  317. return xs_pad, olens, None
  318. def forward_chunk(self,
  319. xs_pad: torch.Tensor,
  320. ilens: torch.Tensor,
  321. cache: dict = None,
  322. ctc: CTC = None,
  323. ):
  324. xs_pad *= self.output_size() ** 0.5
  325. if self.embed is None:
  326. xs_pad = xs_pad
  327. else:
  328. xs_pad = self.embed.forward_chunk(xs_pad, cache)
  329. encoder_outs = self.encoders0(xs_pad, None, None, None, None)
  330. xs_pad, masks = encoder_outs[0], encoder_outs[1]
  331. intermediate_outs = []
  332. if len(self.interctc_layer_idx) == 0:
  333. encoder_outs = self.encoders(xs_pad, None, None, None, None)
  334. xs_pad, masks = encoder_outs[0], encoder_outs[1]
  335. else:
  336. for layer_idx, encoder_layer in enumerate(self.encoders):
  337. encoder_outs = encoder_layer(xs_pad, None, None, None, None)
  338. xs_pad, masks = encoder_outs[0], encoder_outs[1]
  339. if layer_idx + 1 in self.interctc_layer_idx:
  340. encoder_out = xs_pad
  341. # intermediate outputs are also normalized
  342. if self.normalize_before:
  343. encoder_out = self.after_norm(encoder_out)
  344. intermediate_outs.append((layer_idx + 1, encoder_out))
  345. if self.interctc_use_conditioning:
  346. ctc_out = ctc.softmax(encoder_out)
  347. xs_pad = xs_pad + self.conditioning_layer(ctc_out)
  348. if self.normalize_before:
  349. xs_pad = self.after_norm(xs_pad)
  350. if len(intermediate_outs) > 0:
  351. return (xs_pad, intermediate_outs), None, None
  352. return xs_pad, ilens, None
  353. def gen_tf2torch_map_dict(self):
  354. tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
  355. tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
  356. map_dict_local = {
  357. ## encoder
  358. # cicd
  359. "{}.encoders.layeridx.norm1.weight".format(tensor_name_prefix_torch):
  360. {"name": "{}/layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf),
  361. "squeeze": None,
  362. "transpose": None,
  363. }, # (256,),(256,)
  364. "{}.encoders.layeridx.norm1.bias".format(tensor_name_prefix_torch):
  365. {"name": "{}/layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf),
  366. "squeeze": None,
  367. "transpose": None,
  368. }, # (256,),(256,)
  369. "{}.encoders.layeridx.self_attn.linear_q_k_v.weight".format(tensor_name_prefix_torch):
  370. {"name": "{}/layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf),
  371. "squeeze": 0,
  372. "transpose": (1, 0),
  373. }, # (768,256),(1,256,768)
  374. "{}.encoders.layeridx.self_attn.linear_q_k_v.bias".format(tensor_name_prefix_torch):
  375. {"name": "{}/layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf),
  376. "squeeze": None,
  377. "transpose": None,
  378. }, # (768,),(768,)
  379. "{}.encoders.layeridx.self_attn.fsmn_block.weight".format(tensor_name_prefix_torch):
  380. {"name": "{}/layer_layeridx/multi_head/depth_conv_w".format(tensor_name_prefix_tf),
  381. "squeeze": 0,
  382. "transpose": (1, 2, 0),
  383. }, # (256,1,31),(1,31,256,1)
  384. "{}.encoders.layeridx.self_attn.linear_out.weight".format(tensor_name_prefix_torch):
  385. {"name": "{}/layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf),
  386. "squeeze": 0,
  387. "transpose": (1, 0),
  388. }, # (256,256),(1,256,256)
  389. "{}.encoders.layeridx.self_attn.linear_out.bias".format(tensor_name_prefix_torch):
  390. {"name": "{}/layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf),
  391. "squeeze": None,
  392. "transpose": None,
  393. }, # (256,),(256,)
  394. # ffn
  395. "{}.encoders.layeridx.norm2.weight".format(tensor_name_prefix_torch):
  396. {"name": "{}/layer_layeridx/ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
  397. "squeeze": None,
  398. "transpose": None,
  399. }, # (256,),(256,)
  400. "{}.encoders.layeridx.norm2.bias".format(tensor_name_prefix_torch):
  401. {"name": "{}/layer_layeridx/ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
  402. "squeeze": None,
  403. "transpose": None,
  404. }, # (256,),(256,)
  405. "{}.encoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
  406. {"name": "{}/layer_layeridx/ffn/conv1d/kernel".format(tensor_name_prefix_tf),
  407. "squeeze": 0,
  408. "transpose": (1, 0),
  409. }, # (1024,256),(1,256,1024)
  410. "{}.encoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
  411. {"name": "{}/layer_layeridx/ffn/conv1d/bias".format(tensor_name_prefix_tf),
  412. "squeeze": None,
  413. "transpose": None,
  414. }, # (1024,),(1024,)
  415. "{}.encoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
  416. {"name": "{}/layer_layeridx/ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
  417. "squeeze": 0,
  418. "transpose": (1, 0),
  419. }, # (256,1024),(1,1024,256)
  420. "{}.encoders.layeridx.feed_forward.w_2.bias".format(tensor_name_prefix_torch):
  421. {"name": "{}/layer_layeridx/ffn/conv1d_1/bias".format(tensor_name_prefix_tf),
  422. "squeeze": None,
  423. "transpose": None,
  424. }, # (256,),(256,)
  425. # out norm
  426. "{}.after_norm.weight".format(tensor_name_prefix_torch):
  427. {"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf),
  428. "squeeze": None,
  429. "transpose": None,
  430. }, # (256,),(256,)
  431. "{}.after_norm.bias".format(tensor_name_prefix_torch):
  432. {"name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf),
  433. "squeeze": None,
  434. "transpose": None,
  435. }, # (256,),(256,)
  436. }
  437. return map_dict_local
  438. def convert_tf2torch(self,
  439. var_dict_tf,
  440. var_dict_torch,
  441. ):
  442. map_dict = self.gen_tf2torch_map_dict()
  443. var_dict_torch_update = dict()
  444. for name in sorted(var_dict_torch.keys(), reverse=False):
  445. names = name.split('.')
  446. if names[0] == self.tf2torch_tensor_name_prefix_torch:
  447. if names[1] == "encoders0":
  448. layeridx = int(names[2])
  449. name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
  450. name_q = name_q.replace("encoders0", "encoders")
  451. layeridx_bias = 0
  452. layeridx += layeridx_bias
  453. if name_q in map_dict.keys():
  454. name_v = map_dict[name_q]["name"]
  455. name_tf = name_v.replace("layeridx", "{}".format(layeridx))
  456. data_tf = var_dict_tf[name_tf]
  457. if map_dict[name_q]["squeeze"] is not None:
  458. data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
  459. if map_dict[name_q]["transpose"] is not None:
  460. data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
  461. data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
  462. assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
  463. var_dict_torch[
  464. name].size(),
  465. data_tf.size())
  466. var_dict_torch_update[name] = data_tf
  467. logging.info(
  468. "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
  469. var_dict_tf[name_tf].shape))
  470. elif names[1] == "encoders":
  471. layeridx = int(names[2])
  472. name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
  473. layeridx_bias = 1
  474. layeridx += layeridx_bias
  475. if name_q in map_dict.keys():
  476. name_v = map_dict[name_q]["name"]
  477. name_tf = name_v.replace("layeridx", "{}".format(layeridx))
  478. data_tf = var_dict_tf[name_tf]
  479. if map_dict[name_q]["squeeze"] is not None:
  480. data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
  481. if map_dict[name_q]["transpose"] is not None:
  482. data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
  483. data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
  484. assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
  485. var_dict_torch[
  486. name].size(),
  487. data_tf.size())
  488. var_dict_torch_update[name] = data_tf
  489. logging.info(
  490. "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
  491. var_dict_tf[name_tf].shape))
  492. elif names[1] == "after_norm":
  493. name_tf = map_dict[name]["name"]
  494. data_tf = var_dict_tf[name_tf]
  495. data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
  496. var_dict_torch_update[name] = data_tf
  497. logging.info(
  498. "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
  499. var_dict_tf[name_tf].shape))
  500. return var_dict_torch_update
  501. class SANMEncoderChunkOpt(AbsEncoder):
  502. """
  503. author: Speech Lab, Alibaba Group, China
  504. SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
  505. https://arxiv.org/abs/2006.01713
  506. """
  507. def __init__(
  508. self,
  509. input_size: int,
  510. output_size: int = 256,
  511. attention_heads: int = 4,
  512. linear_units: int = 2048,
  513. num_blocks: int = 6,
  514. dropout_rate: float = 0.1,
  515. positional_dropout_rate: float = 0.1,
  516. attention_dropout_rate: float = 0.0,
  517. input_layer: Optional[str] = "conv2d",
  518. pos_enc_class=SinusoidalPositionEncoder,
  519. normalize_before: bool = True,
  520. concat_after: bool = False,
  521. positionwise_layer_type: str = "linear",
  522. positionwise_conv_kernel_size: int = 1,
  523. padding_idx: int = -1,
  524. interctc_layer_idx: List[int] = [],
  525. interctc_use_conditioning: bool = False,
  526. kernel_size: int = 11,
  527. sanm_shfit: int = 0,
  528. selfattention_layer_type: str = "sanm",
  529. chunk_size: Union[int, Sequence[int]] = (16,),
  530. stride: Union[int, Sequence[int]] = (10,),
  531. pad_left: Union[int, Sequence[int]] = (0,),
  532. encoder_att_look_back_factor: Union[int, Sequence[int]] = (1,),
  533. decoder_att_look_back_factor: Union[int, Sequence[int]] = (1,),
  534. tf2torch_tensor_name_prefix_torch: str = "encoder",
  535. tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
  536. ):
  537. assert check_argument_types()
  538. super().__init__()
  539. self._output_size = output_size
  540. if input_layer == "linear":
  541. self.embed = torch.nn.Sequential(
  542. torch.nn.Linear(input_size, output_size),
  543. torch.nn.LayerNorm(output_size),
  544. torch.nn.Dropout(dropout_rate),
  545. torch.nn.ReLU(),
  546. pos_enc_class(output_size, positional_dropout_rate),
  547. )
  548. elif input_layer == "conv2d":
  549. self.embed = Conv2dSubsampling(input_size, output_size, dropout_rate)
  550. elif input_layer == "conv2d2":
  551. self.embed = Conv2dSubsampling2(input_size, output_size, dropout_rate)
  552. elif input_layer == "conv2d6":
  553. self.embed = Conv2dSubsampling6(input_size, output_size, dropout_rate)
  554. elif input_layer == "conv2d8":
  555. self.embed = Conv2dSubsampling8(input_size, output_size, dropout_rate)
  556. elif input_layer == "embed":
  557. self.embed = torch.nn.Sequential(
  558. torch.nn.Embedding(input_size, output_size, padding_idx=padding_idx),
  559. pos_enc_class(output_size, positional_dropout_rate),
  560. )
  561. elif input_layer is None:
  562. if input_size == output_size:
  563. self.embed = None
  564. else:
  565. self.embed = torch.nn.Linear(input_size, output_size)
  566. elif input_layer == "pe":
  567. self.embed = SinusoidalPositionEncoder()
  568. else:
  569. raise ValueError("unknown input_layer: " + input_layer)
  570. self.normalize_before = normalize_before
  571. if positionwise_layer_type == "linear":
  572. positionwise_layer = PositionwiseFeedForward
  573. positionwise_layer_args = (
  574. output_size,
  575. linear_units,
  576. dropout_rate,
  577. )
  578. elif positionwise_layer_type == "conv1d":
  579. positionwise_layer = MultiLayeredConv1d
  580. positionwise_layer_args = (
  581. output_size,
  582. linear_units,
  583. positionwise_conv_kernel_size,
  584. dropout_rate,
  585. )
  586. elif positionwise_layer_type == "conv1d-linear":
  587. positionwise_layer = Conv1dLinear
  588. positionwise_layer_args = (
  589. output_size,
  590. linear_units,
  591. positionwise_conv_kernel_size,
  592. dropout_rate,
  593. )
  594. else:
  595. raise NotImplementedError("Support only linear or conv1d.")
  596. if selfattention_layer_type == "selfattn":
  597. encoder_selfattn_layer = MultiHeadedAttention
  598. encoder_selfattn_layer_args = (
  599. attention_heads,
  600. output_size,
  601. attention_dropout_rate,
  602. )
  603. elif selfattention_layer_type == "sanm":
  604. encoder_selfattn_layer = MultiHeadedAttentionSANM
  605. encoder_selfattn_layer_args0 = (
  606. attention_heads,
  607. input_size,
  608. output_size,
  609. attention_dropout_rate,
  610. kernel_size,
  611. sanm_shfit,
  612. )
  613. encoder_selfattn_layer_args = (
  614. attention_heads,
  615. output_size,
  616. output_size,
  617. attention_dropout_rate,
  618. kernel_size,
  619. sanm_shfit,
  620. )
  621. self.encoders0 = repeat(
  622. 1,
  623. lambda lnum: EncoderLayerSANM(
  624. input_size,
  625. output_size,
  626. encoder_selfattn_layer(*encoder_selfattn_layer_args0),
  627. positionwise_layer(*positionwise_layer_args),
  628. dropout_rate,
  629. normalize_before,
  630. concat_after,
  631. ),
  632. )
  633. self.encoders = repeat(
  634. num_blocks - 1,
  635. lambda lnum: EncoderLayerSANM(
  636. output_size,
  637. output_size,
  638. encoder_selfattn_layer(*encoder_selfattn_layer_args),
  639. positionwise_layer(*positionwise_layer_args),
  640. dropout_rate,
  641. normalize_before,
  642. concat_after,
  643. ),
  644. )
  645. if self.normalize_before:
  646. self.after_norm = LayerNorm(output_size)
  647. self.interctc_layer_idx = interctc_layer_idx
  648. if len(interctc_layer_idx) > 0:
  649. assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
  650. self.interctc_use_conditioning = interctc_use_conditioning
  651. self.conditioning_layer = None
  652. shfit_fsmn = (kernel_size - 1) // 2
  653. self.overlap_chunk_cls = overlap_chunk(
  654. chunk_size=chunk_size,
  655. stride=stride,
  656. pad_left=pad_left,
  657. shfit_fsmn=shfit_fsmn,
  658. encoder_att_look_back_factor=encoder_att_look_back_factor,
  659. decoder_att_look_back_factor=decoder_att_look_back_factor,
  660. )
  661. self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
  662. self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
  663. def output_size(self) -> int:
  664. return self._output_size
  665. def forward(
  666. self,
  667. xs_pad: torch.Tensor,
  668. ilens: torch.Tensor,
  669. prev_states: torch.Tensor = None,
  670. ctc: CTC = None,
  671. ind: int = 0,
  672. ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
  673. """Embed positions in tensor.
  674. Args:
  675. xs_pad: input tensor (B, L, D)
  676. ilens: input length (B)
  677. prev_states: Not to be used now.
  678. Returns:
  679. position embedded tensor and mask
  680. """
  681. masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
  682. xs_pad *= self.output_size() ** 0.5
  683. if self.embed is None:
  684. xs_pad = xs_pad
  685. elif (
  686. isinstance(self.embed, Conv2dSubsampling)
  687. or isinstance(self.embed, Conv2dSubsampling2)
  688. or isinstance(self.embed, Conv2dSubsampling6)
  689. or isinstance(self.embed, Conv2dSubsampling8)
  690. ):
  691. short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
  692. if short_status:
  693. raise TooShortUttError(
  694. f"has {xs_pad.size(1)} frames and is too short for subsampling "
  695. + f"(it needs more than {limit_size} frames), return empty results",
  696. xs_pad.size(1),
  697. limit_size,
  698. )
  699. xs_pad, masks = self.embed(xs_pad, masks)
  700. else:
  701. xs_pad = self.embed(xs_pad)
  702. mask_shfit_chunk, mask_att_chunk_encoder = None, None
  703. if self.overlap_chunk_cls is not None:
  704. ilens = masks.squeeze(1).sum(1)
  705. chunk_outs = self.overlap_chunk_cls.gen_chunk_mask(ilens, ind)
  706. xs_pad, ilens = self.overlap_chunk_cls.split_chunk(xs_pad, ilens, chunk_outs=chunk_outs)
  707. masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
  708. mask_shfit_chunk = self.overlap_chunk_cls.get_mask_shfit_chunk(chunk_outs, xs_pad.device, xs_pad.size(0),
  709. dtype=xs_pad.dtype)
  710. mask_att_chunk_encoder = self.overlap_chunk_cls.get_mask_att_chunk_encoder(chunk_outs, xs_pad.device,
  711. xs_pad.size(0),
  712. dtype=xs_pad.dtype)
  713. encoder_outs = self.encoders0(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
  714. xs_pad, masks = encoder_outs[0], encoder_outs[1]
  715. intermediate_outs = []
  716. if len(self.interctc_layer_idx) == 0:
  717. encoder_outs = self.encoders(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
  718. xs_pad, masks = encoder_outs[0], encoder_outs[1]
  719. else:
  720. for layer_idx, encoder_layer in enumerate(self.encoders):
  721. encoder_outs = encoder_layer(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
  722. xs_pad, masks = encoder_outs[0], encoder_outs[1]
  723. if layer_idx + 1 in self.interctc_layer_idx:
  724. encoder_out = xs_pad
  725. # intermediate outputs are also normalized
  726. if self.normalize_before:
  727. encoder_out = self.after_norm(encoder_out)
  728. intermediate_outs.append((layer_idx + 1, encoder_out))
  729. if self.interctc_use_conditioning:
  730. ctc_out = ctc.softmax(encoder_out)
  731. xs_pad = xs_pad + self.conditioning_layer(ctc_out)
  732. if self.normalize_before:
  733. xs_pad = self.after_norm(xs_pad)
  734. olens = masks.squeeze(1).sum(1)
  735. if len(intermediate_outs) > 0:
  736. return (xs_pad, intermediate_outs), olens, None
  737. return xs_pad, olens, None
  738. def gen_tf2torch_map_dict(self):
  739. tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
  740. tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
  741. map_dict_local = {
  742. ## encoder
  743. # cicd
  744. "{}.encoders.layeridx.norm1.weight".format(tensor_name_prefix_torch):
  745. {"name": "{}/layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf),
  746. "squeeze": None,
  747. "transpose": None,
  748. }, # (256,),(256,)
  749. "{}.encoders.layeridx.norm1.bias".format(tensor_name_prefix_torch):
  750. {"name": "{}/layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf),
  751. "squeeze": None,
  752. "transpose": None,
  753. }, # (256,),(256,)
  754. "{}.encoders.layeridx.self_attn.linear_q_k_v.weight".format(tensor_name_prefix_torch):
  755. {"name": "{}/layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf),
  756. "squeeze": 0,
  757. "transpose": (1, 0),
  758. }, # (768,256),(1,256,768)
  759. "{}.encoders.layeridx.self_attn.linear_q_k_v.bias".format(tensor_name_prefix_torch):
  760. {"name": "{}/layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf),
  761. "squeeze": None,
  762. "transpose": None,
  763. }, # (768,),(768,)
  764. "{}.encoders.layeridx.self_attn.fsmn_block.weight".format(tensor_name_prefix_torch):
  765. {"name": "{}/layer_layeridx/multi_head/depth_conv_w".format(tensor_name_prefix_tf),
  766. "squeeze": 0,
  767. "transpose": (1, 2, 0),
  768. }, # (256,1,31),(1,31,256,1)
  769. "{}.encoders.layeridx.self_attn.linear_out.weight".format(tensor_name_prefix_torch):
  770. {"name": "{}/layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf),
  771. "squeeze": 0,
  772. "transpose": (1, 0),
  773. }, # (256,256),(1,256,256)
  774. "{}.encoders.layeridx.self_attn.linear_out.bias".format(tensor_name_prefix_torch):
  775. {"name": "{}/layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf),
  776. "squeeze": None,
  777. "transpose": None,
  778. }, # (256,),(256,)
  779. # ffn
  780. "{}.encoders.layeridx.norm2.weight".format(tensor_name_prefix_torch):
  781. {"name": "{}/layer_layeridx/ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
  782. "squeeze": None,
  783. "transpose": None,
  784. }, # (256,),(256,)
  785. "{}.encoders.layeridx.norm2.bias".format(tensor_name_prefix_torch):
  786. {"name": "{}/layer_layeridx/ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
  787. "squeeze": None,
  788. "transpose": None,
  789. }, # (256,),(256,)
  790. "{}.encoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
  791. {"name": "{}/layer_layeridx/ffn/conv1d/kernel".format(tensor_name_prefix_tf),
  792. "squeeze": 0,
  793. "transpose": (1, 0),
  794. }, # (1024,256),(1,256,1024)
  795. "{}.encoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
  796. {"name": "{}/layer_layeridx/ffn/conv1d/bias".format(tensor_name_prefix_tf),
  797. "squeeze": None,
  798. "transpose": None,
  799. }, # (1024,),(1024,)
  800. "{}.encoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
  801. {"name": "{}/layer_layeridx/ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
  802. "squeeze": 0,
  803. "transpose": (1, 0),
  804. }, # (256,1024),(1,1024,256)
  805. "{}.encoders.layeridx.feed_forward.w_2.bias".format(tensor_name_prefix_torch):
  806. {"name": "{}/layer_layeridx/ffn/conv1d_1/bias".format(tensor_name_prefix_tf),
  807. "squeeze": None,
  808. "transpose": None,
  809. }, # (256,),(256,)
  810. # out norm
  811. "{}.after_norm.weight".format(tensor_name_prefix_torch):
  812. {"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf),
  813. "squeeze": None,
  814. "transpose": None,
  815. }, # (256,),(256,)
  816. "{}.after_norm.bias".format(tensor_name_prefix_torch):
  817. {"name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf),
  818. "squeeze": None,
  819. "transpose": None,
  820. }, # (256,),(256,)
  821. }
  822. return map_dict_local
  823. def convert_tf2torch(self,
  824. var_dict_tf,
  825. var_dict_torch,
  826. ):
  827. map_dict = self.gen_tf2torch_map_dict()
  828. var_dict_torch_update = dict()
  829. for name in sorted(var_dict_torch.keys(), reverse=False):
  830. names = name.split('.')
  831. if names[0] == self.tf2torch_tensor_name_prefix_torch:
  832. if names[1] == "encoders0":
  833. layeridx = int(names[2])
  834. name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
  835. name_q = name_q.replace("encoders0", "encoders")
  836. layeridx_bias = 0
  837. layeridx += layeridx_bias
  838. if name_q in map_dict.keys():
  839. name_v = map_dict[name_q]["name"]
  840. name_tf = name_v.replace("layeridx", "{}".format(layeridx))
  841. data_tf = var_dict_tf[name_tf]
  842. if map_dict[name_q]["squeeze"] is not None:
  843. data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
  844. if map_dict[name_q]["transpose"] is not None:
  845. data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
  846. data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
  847. assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
  848. var_dict_torch[
  849. name].size(),
  850. data_tf.size())
  851. var_dict_torch_update[name] = data_tf
  852. logging.info(
  853. "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
  854. var_dict_tf[name_tf].shape))
  855. elif names[1] == "encoders":
  856. layeridx = int(names[2])
  857. name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
  858. layeridx_bias = 1
  859. layeridx += layeridx_bias
  860. if name_q in map_dict.keys():
  861. name_v = map_dict[name_q]["name"]
  862. name_tf = name_v.replace("layeridx", "{}".format(layeridx))
  863. data_tf = var_dict_tf[name_tf]
  864. if map_dict[name_q]["squeeze"] is not None:
  865. data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
  866. if map_dict[name_q]["transpose"] is not None:
  867. data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
  868. data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
  869. assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
  870. var_dict_torch[
  871. name].size(),
  872. data_tf.size())
  873. var_dict_torch_update[name] = data_tf
  874. logging.info(
  875. "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
  876. var_dict_tf[name_tf].shape))
  877. elif names[1] == "after_norm":
  878. name_tf = map_dict[name]["name"]
  879. data_tf = var_dict_tf[name_tf]
  880. data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
  881. var_dict_torch_update[name] = data_tf
  882. logging.info(
  883. "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
  884. var_dict_tf[name_tf].shape))
  885. return var_dict_torch_update