decoder.py 40 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777
  1. #!/usr/bin/env python3
  2. # -*- encoding: utf-8 -*-
  3. # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
  4. # MIT License (https://opensource.org/licenses/MIT)
  5. import torch
  6. import logging
  7. import numpy as np
  8. from typing import Tuple
  9. from funasr.register import tables
  10. from funasr.models.scama import utils as myutils
  11. from funasr.models.transformer.utils.repeat import repeat
  12. from funasr.models.transformer.layer_norm import LayerNorm
  13. from funasr.models.transformer.embedding import PositionalEncoding
  14. from funasr.models.paraformer.decoder import DecoderLayerSANM, ParaformerSANMDecoder
  15. from funasr.models.sanm.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
  16. from funasr.models.sanm.attention import MultiHeadedAttentionSANMDecoder, MultiHeadedAttentionCrossAtt
  17. class ContextualDecoderLayer(torch.nn.Module):
  18. def __init__(
  19. self,
  20. size,
  21. self_attn,
  22. src_attn,
  23. feed_forward,
  24. dropout_rate,
  25. normalize_before=True,
  26. concat_after=False,
  27. ):
  28. """Construct an DecoderLayer object."""
  29. super(ContextualDecoderLayer, self).__init__()
  30. self.size = size
  31. self.self_attn = self_attn
  32. self.src_attn = src_attn
  33. self.feed_forward = feed_forward
  34. self.norm1 = LayerNorm(size)
  35. if self_attn is not None:
  36. self.norm2 = LayerNorm(size)
  37. if src_attn is not None:
  38. self.norm3 = LayerNorm(size)
  39. self.dropout = torch.nn.Dropout(dropout_rate)
  40. self.normalize_before = normalize_before
  41. self.concat_after = concat_after
  42. if self.concat_after:
  43. self.concat_linear1 = torch.nn.Linear(size + size, size)
  44. self.concat_linear2 = torch.nn.Linear(size + size, size)
  45. def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None,):
  46. # tgt = self.dropout(tgt)
  47. if isinstance(tgt, Tuple):
  48. tgt, _ = tgt
  49. residual = tgt
  50. if self.normalize_before:
  51. tgt = self.norm1(tgt)
  52. tgt = self.feed_forward(tgt)
  53. x = tgt
  54. if self.normalize_before:
  55. tgt = self.norm2(tgt)
  56. if self.training:
  57. cache = None
  58. x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
  59. x = residual + self.dropout(x)
  60. x_self_attn = x
  61. residual = x
  62. if self.normalize_before:
  63. x = self.norm3(x)
  64. x = self.src_attn(x, memory, memory_mask)
  65. x_src_attn = x
  66. x = residual + self.dropout(x)
  67. return x, tgt_mask, x_self_attn, x_src_attn
  68. class ContextualBiasDecoder(torch.nn.Module):
  69. def __init__(
  70. self,
  71. size,
  72. src_attn,
  73. dropout_rate,
  74. normalize_before=True,
  75. ):
  76. """Construct an DecoderLayer object."""
  77. super(ContextualBiasDecoder, self).__init__()
  78. self.size = size
  79. self.src_attn = src_attn
  80. if src_attn is not None:
  81. self.norm3 = LayerNorm(size)
  82. self.dropout = torch.nn.Dropout(dropout_rate)
  83. self.normalize_before = normalize_before
  84. def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
  85. x = tgt
  86. if self.src_attn is not None:
  87. if self.normalize_before:
  88. x = self.norm3(x)
  89. x = self.dropout(self.src_attn(x, memory, memory_mask))
  90. return x, tgt_mask, memory, memory_mask, cache
@tables.register("decoder_classes", "ContextualParaformerDecoder")
class ContextualParaformerDecoder(ParaformerSANMDecoder):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2006.01713
    """

    def __init__(
        self,
        vocab_size: int,
        encoder_output_size: int,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        self_attention_dropout_rate: float = 0.0,
        src_attention_dropout_rate: float = 0.0,
        input_layer: str = "embed",
        use_output_layer: bool = True,
        pos_enc_class=PositionalEncoding,
        normalize_before: bool = True,
        concat_after: bool = False,
        att_layer_num: int = 6,
        kernel_size: int = 21,
        sanm_shfit: int = 0,
    ):
        # Parent builds the base SANM decoder; the stacks are re-created below
        # so the last attention layer can be replaced by the contextual
        # (hotword-biasing) variant.
        super().__init__(
            vocab_size=vocab_size,
            encoder_output_size=encoder_output_size,
            dropout_rate=dropout_rate,
            positional_dropout_rate=positional_dropout_rate,
            input_layer=input_layer,
            use_output_layer=use_output_layer,
            pos_enc_class=pos_enc_class,
            normalize_before=normalize_before,
        )
        attention_dim = encoder_output_size
        if input_layer == 'none':
            self.embed = None
        # NOTE(review): 'none' is not re-checked by the chain below, so
        # input_layer='none' still reaches the ValueError branch — confirm
        # whether 'none' is actually a supported option for this subclass.
        if input_layer == "embed":
            self.embed = torch.nn.Sequential(
                torch.nn.Embedding(vocab_size, attention_dim),
                # pos_enc_class(attention_dim, positional_dropout_rate),
            )
        elif input_layer == "linear":
            self.embed = torch.nn.Sequential(
                torch.nn.Linear(vocab_size, attention_dim),
                torch.nn.LayerNorm(attention_dim),
                torch.nn.Dropout(dropout_rate),
                torch.nn.ReLU(),
                pos_enc_class(attention_dim, positional_dropout_rate),
            )
        else:
            raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")
        self.normalize_before = normalize_before
        if self.normalize_before:
            self.after_norm = LayerNorm(attention_dim)
        if use_output_layer:
            self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
        else:
            self.output_layer = None
        self.att_layer_num = att_layer_num
        self.num_blocks = num_blocks
        # Default FSMN look-back shift: center the kernel.
        if sanm_shfit is None:
            sanm_shfit = (kernel_size - 1) // 2
        # First (att_layer_num - 1) attention layers; the final attention layer
        # is self.last_decoder below, which exposes intermediate activations.
        self.decoders = repeat(
            att_layer_num - 1,
            lambda lnum: DecoderLayerSANM(
                attention_dim,
                MultiHeadedAttentionSANMDecoder(
                    attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
                ),
                MultiHeadedAttentionCrossAtt(
                    attention_heads, attention_dim, src_attention_dropout_rate
                ),
                PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        self.dropout = torch.nn.Dropout(dropout_rate)
        # Cross-attention over the hotword (contextual) embeddings.
        self.bias_decoder = ContextualBiasDecoder(
            size=attention_dim,
            src_attn=MultiHeadedAttentionCrossAtt(
                attention_heads, attention_dim, src_attention_dropout_rate
            ),
            dropout_rate=dropout_rate,
            normalize_before=True,
        )
        # 1x1 conv fusing [src-attn output ; bias-attn output] back to D dims.
        self.bias_output = torch.nn.Conv1d(attention_dim*2, attention_dim, 1, bias=False)
        self.last_decoder = ContextualDecoderLayer(
            attention_dim,
            MultiHeadedAttentionSANMDecoder(
                attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
            ),
            MultiHeadedAttentionCrossAtt(
                attention_heads, attention_dim, src_attention_dropout_rate
            ),
            PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
            dropout_rate,
            normalize_before,
            concat_after,
        )
        # Optional extra FSMN-only layers (no cross-attention) after the
        # attention stack.
        if num_blocks - att_layer_num <= 0:
            self.decoders2 = None
        else:
            self.decoders2 = repeat(
                num_blocks - att_layer_num,
                lambda lnum: DecoderLayerSANM(
                    attention_dim,
                    MultiHeadedAttentionSANMDecoder(
                        attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=0
                    ),
                    None,
                    PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
                    dropout_rate,
                    normalize_before,
                    concat_after,
                ),
            )
        # Final feed-forward-only layer.
        self.decoders3 = repeat(
            1,
            lambda lnum: DecoderLayerSANM(
                attention_dim,
                None,
                None,
                PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )

    def forward(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
        contextual_info: torch.Tensor,
        clas_scale: float = 1.0,
        return_hidden: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward decoder.
        Args:
            hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
            hlens: (batch)
            ys_in_pad:
                input token ids, int64 (batch, maxlen_out)
                if input_layer == "embed"
                input tensor (batch, maxlen_out, #mels) in the other cases
            ys_in_lens: (batch)
            contextual_info: hotword embeddings to bias decoding towards
                (assumed (batch, num_hotwords, feat) — TODO confirm with caller)
            clas_scale: scaling applied to the bias-attention output
            return_hidden: if True, skip the output projection and return the
                hidden states instead of token scores
        Returns:
            (tuple): tuple containing:
            x: decoded token score before softmax (batch, maxlen_out, token)
                if use_output_layer is True,
            olens: (batch, )
        """
        tgt = ys_in_pad
        tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
        memory = hs_pad
        memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
        x = tgt
        # Standard attention layers, then the contextual last layer which
        # exposes the post-self-attn state and the raw cross-attn output.
        x, tgt_mask, memory, memory_mask, _ = self.decoders(
            x, tgt_mask, memory, memory_mask
        )
        _, _, x_self_attn, x_src_attn = self.last_decoder(
            x, tgt_mask, memory, memory_mask
        )
        # contextual paraformer related
        # All batch items share the same hotword-list length.
        contextual_length = torch.Tensor([contextual_info.shape[1]]).int().repeat(hs_pad.shape[0])
        contextual_mask = myutils.sequence_mask(contextual_length, device=memory.device)[:, None, :]
        cx, tgt_mask, _, _, _ = self.bias_decoder(x_self_attn, tgt_mask, contextual_info, memory_mask=contextual_mask)
        if self.bias_output is not None:
            # Fuse encoder cross-attention with (scaled) hotword attention.
            x = torch.cat([x_src_attn, cx*clas_scale], dim=2)
            x = self.bias_output(x.transpose(1, 2)).transpose(1, 2)  # 2D -> D
            x = x_self_attn + self.dropout(x)
        if self.decoders2 is not None:
            x, tgt_mask, memory, memory_mask, _ = self.decoders2(
                x, tgt_mask, memory, memory_mask
            )
        x, tgt_mask, memory, memory_mask, _ = self.decoders3(
            x, tgt_mask, memory, memory_mask
        )
        if self.normalize_before:
            x = self.after_norm(x)
        olens = tgt_mask.sum(1)
        if self.output_layer is not None and return_hidden is False:
            x = self.output_layer(x)
        return x, olens

    def gen_tf2torch_map_dict(self):
        """Return the torch-parameter-name -> TF-variable-name mapping used by
        convert_tf2torch().

        Each value gives the TF variable name (with a 'layeridx' placeholder
        for layer-indexed entries) plus optional 'squeeze' axis and 'transpose'
        permutation to apply to the TF array before copying.
        """
        # NOTE(review): these prefixes are presumably set by the parent class
        # (not visible in this file) — confirm before relying on them.
        tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
        tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
        map_dict_local = {
            ## decoder
            # ffn
            "{}.decoders.layeridx.norm1.weight".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (256,),(256,)
            "{}.decoders.layeridx.norm1.bias".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (256,),(256,)
            "{}.decoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d/kernel".format(tensor_name_prefix_tf),
                "squeeze": 0,
                "transpose": (1, 0),
            },  # (1024,256),(1,256,1024)
            "{}.decoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d/bias".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (1024,),(1024,)
            "{}.decoders.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (1024,),(1024,)
            "{}.decoders.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm_1/beta".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (1024,),(1024,)
            "{}.decoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
                "squeeze": 0,
                "transpose": (1, 0),
            },  # (256,1024),(1,1024,256)
            # fsmn
            "{}.decoders.layeridx.norm2.weight".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/LayerNorm/gamma".format(
                    tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (256,),(256,)
            "{}.decoders.layeridx.norm2.bias".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/LayerNorm/beta".format(
                    tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (256,),(256,)
            "{}.decoders.layeridx.self_attn.fsmn_block.weight".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/depth_conv_w".format(
                    tensor_name_prefix_tf),
                "squeeze": 0,
                "transpose": (1, 2, 0),
            },  # (256,1,31),(1,31,256,1)
            # src att
            "{}.decoders.layeridx.norm3.weight".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (256,),(256,)
            "{}.decoders.layeridx.norm3.bias".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (256,),(256,)
            "{}.decoders.layeridx.src_attn.linear_q.weight".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf),
                "squeeze": 0,
                "transpose": (1, 0),
            },  # (256,256),(1,256,256)
            "{}.decoders.layeridx.src_attn.linear_q.bias".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (256,),(256,)
            "{}.decoders.layeridx.src_attn.linear_k_v.weight".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf),
                "squeeze": 0,
                "transpose": (1, 0),
            },  # (1024,256),(1,256,1024)
            "{}.decoders.layeridx.src_attn.linear_k_v.bias".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (1024,),(1024,)
            "{}.decoders.layeridx.src_attn.linear_out.weight".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_2/kernel".format(tensor_name_prefix_tf),
                "squeeze": 0,
                "transpose": (1, 0),
            },  # (256,256),(1,256,256)
            "{}.decoders.layeridx.src_attn.linear_out.bias".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_2/bias".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (256,),(256,)
            # dnn
            "{}.decoders3.layeridx.norm1.weight".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_dnn_layer_layeridx/LayerNorm/gamma".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (256,),(256,)
            "{}.decoders3.layeridx.norm1.bias".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_dnn_layer_layeridx/LayerNorm/beta".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (256,),(256,)
            "{}.decoders3.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_dnn_layer_layeridx/conv1d/kernel".format(tensor_name_prefix_tf),
                "squeeze": 0,
                "transpose": (1, 0),
            },  # (1024,256),(1,256,1024)
            "{}.decoders3.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_dnn_layer_layeridx/conv1d/bias".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (1024,),(1024,)
            "{}.decoders3.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_dnn_layer_layeridx/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (1024,),(1024,)
            "{}.decoders3.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_dnn_layer_layeridx/LayerNorm_1/beta".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (1024,),(1024,)
            "{}.decoders3.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_dnn_layer_layeridx/conv1d_1/kernel".format(tensor_name_prefix_tf),
                "squeeze": 0,
                "transpose": (1, 0),
            },  # (256,1024),(1,1024,256)
            # embed_concat_ffn
            # NOTE(review): no embed_concat_ffn module is created in this file;
            # presumably it comes from the parent class — verify.
            "{}.embed_concat_ffn.layeridx.norm1.weight".format(tensor_name_prefix_torch): {
                "name": "{}/cif_concat/LayerNorm/gamma".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (256,),(256,)
            "{}.embed_concat_ffn.layeridx.norm1.bias".format(tensor_name_prefix_torch): {
                "name": "{}/cif_concat/LayerNorm/beta".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (256,),(256,)
            "{}.embed_concat_ffn.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch): {
                "name": "{}/cif_concat/conv1d/kernel".format(tensor_name_prefix_tf),
                "squeeze": 0,
                "transpose": (1, 0),
            },  # (1024,256),(1,256,1024)
            "{}.embed_concat_ffn.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch): {
                "name": "{}/cif_concat/conv1d/bias".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (1024,),(1024,)
            "{}.embed_concat_ffn.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch): {
                "name": "{}/cif_concat/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (1024,),(1024,)
            "{}.embed_concat_ffn.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch): {
                "name": "{}/cif_concat/LayerNorm_1/beta".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (1024,),(1024,)
            "{}.embed_concat_ffn.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch): {
                "name": "{}/cif_concat/conv1d_1/kernel".format(tensor_name_prefix_tf),
                "squeeze": 0,
                "transpose": (1, 0),
            },  # (256,1024),(1,1024,256)
            # out norm
            "{}.after_norm.weight".format(tensor_name_prefix_torch): {
                "name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (256,),(256,)
            "{}.after_norm.bias".format(tensor_name_prefix_torch): {
                "name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (256,),(256,)
            # in embed
            "{}.embed.0.weight".format(tensor_name_prefix_torch): {
                "name": "{}/w_embs".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (4235,256),(4235,256)
            # out layer
            # Two candidate TF names: a dedicated dense layer, or tied input
            # embeddings; convert_tf2torch picks whichever exists.
            "{}.output_layer.weight".format(tensor_name_prefix_torch): {
                "name": ["{}/dense/kernel".format(tensor_name_prefix_tf), "{}/w_embs".format(tensor_name_prefix_tf)],
                "squeeze": [None, None],
                "transpose": [(1, 0), None],
            },  # (4235,256),(256,4235)
            "{}.output_layer.bias".format(tensor_name_prefix_torch): {
                "name": ["{}/dense/bias".format(tensor_name_prefix_tf),
                         "seq2seq/2bias" if tensor_name_prefix_tf == "seq2seq/decoder/inputter_1" else "seq2seq/bias"],
                "squeeze": [None, None],
                "transpose": [None, None],
            },  # (4235,),(4235,)
            ## clas decoder
            # src att
            # The bias decoder's weights live under the (fixed) 16th TF layer.
            "{}.bias_decoder.norm3.weight".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_15/multi_head_1/LayerNorm/gamma".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (256,),(256,)
            "{}.bias_decoder.norm3.bias".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_15/multi_head_1/LayerNorm/beta".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (256,),(256,)
            "{}.bias_decoder.src_attn.linear_q.weight".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d/kernel".format(tensor_name_prefix_tf),
                "squeeze": 0,
                "transpose": (1, 0),
            },  # (256,256),(1,256,256)
            "{}.bias_decoder.src_attn.linear_q.bias".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d/bias".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (256,),(256,)
            "{}.bias_decoder.src_attn.linear_k_v.weight".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d_1/kernel".format(tensor_name_prefix_tf),
                "squeeze": 0,
                "transpose": (1, 0),
            },  # (1024,256),(1,256,1024)
            "{}.bias_decoder.src_attn.linear_k_v.bias".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d_1/bias".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (1024,),(1024,)
            "{}.bias_decoder.src_attn.linear_out.weight".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d_2/kernel".format(tensor_name_prefix_tf),
                "squeeze": 0,
                "transpose": (1, 0),
            },  # (256,256),(1,256,256)
            "{}.bias_decoder.src_attn.linear_out.bias".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_15/multi_head_1/conv1d_2/bias".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": None,
            },  # (256,),(256,)
            # dnn
            "{}.bias_output.weight".format(tensor_name_prefix_torch): {
                "name": "{}/decoder_fsmn_layer_15/conv1d/kernel".format(tensor_name_prefix_tf),
                "squeeze": None,
                "transpose": (2, 1, 0),
            },  # (1024,256),(1,256,1024)
        }
        return map_dict_local

    def convert_tf2torch(self,
                         var_dict_tf,
                         var_dict_torch,
                         ):
        """Copy TF checkpoint variables into a dict of torch tensors keyed by
        this model's parameter names, using gen_tf2torch_map_dict().

        Args:
            var_dict_tf: TF variable name -> numpy array.
            var_dict_torch: torch parameter name -> torch tensor (for shape
                checks and to enumerate what needs converting).
        Returns:
            dict of torch parameter name -> converted torch.float32 CPU tensor.
        """
        # NOTE(review): the per-branch bodies below are near-identical
        # (lookup, squeeze, transpose, shape-assert, log) and differ only in
        # how name_q/layeridx are derived — a candidate for extraction.
        map_dict = self.gen_tf2torch_map_dict()
        var_dict_torch_update = dict()
        decoder_layeridx_sets = set()
        for name in sorted(var_dict_torch.keys(), reverse=False):
            names = name.split('.')
            if names[0] == self.tf2torch_tensor_name_prefix_torch:
                if names[1] == "decoders":
                    # Layer-indexed params; TF layer index equals torch index.
                    layeridx = int(names[2])
                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
                    layeridx_bias = 0
                    layeridx += layeridx_bias
                    decoder_layeridx_sets.add(layeridx)
                    if name_q in map_dict.keys():
                        name_v = map_dict[name_q]["name"]
                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name_q]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
                        if map_dict[name_q]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(
                            name, name_tf, var_dict_torch[name].size(), data_tf.size())
                        var_dict_torch_update[name] = data_tf
                        logging.info(
                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                                name, data_tf.size(), name_v, var_dict_tf[name_tf].shape))
                elif names[1] == "last_decoder":
                    # The contextual last layer maps to TF decoder layer 15,
                    # reusing the generic "decoders.layeridx" map entries.
                    layeridx = 15
                    name_q = name.replace("last_decoder", "decoders.layeridx")
                    layeridx_bias = 0
                    layeridx += layeridx_bias
                    decoder_layeridx_sets.add(layeridx)
                    if name_q in map_dict.keys():
                        name_v = map_dict[name_q]["name"]
                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name_q]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
                        if map_dict[name_q]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(
                            name, name_tf, var_dict_torch[name].size(), data_tf.size())
                        var_dict_torch_update[name] = data_tf
                        logging.info(
                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                                name, data_tf.size(), name_v, var_dict_tf[name_tf].shape))
                elif names[1] == "decoders2":
                    # decoders2 continues the TF layer numbering after all
                    # layers seen so far (offset by how many were converted).
                    layeridx = int(names[2])
                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
                    name_q = name_q.replace("decoders2", "decoders")
                    layeridx_bias = len(decoder_layeridx_sets)
                    layeridx += layeridx_bias
                    if "decoders." in name:
                        decoder_layeridx_sets.add(layeridx)
                    if name_q in map_dict.keys():
                        name_v = map_dict[name_q]["name"]
                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name_q]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
                        if map_dict[name_q]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(
                            name, name_tf, var_dict_torch[name].size(), data_tf.size())
                        var_dict_torch_update[name] = data_tf
                        logging.info(
                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                                name, data_tf.size(), name_v, var_dict_tf[name_tf].shape))
                elif names[1] == "decoders3":
                    layeridx = int(names[2])
                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
                    layeridx_bias = 0
                    layeridx += layeridx_bias
                    if "decoders." in name:
                        decoder_layeridx_sets.add(layeridx)
                    if name_q in map_dict.keys():
                        name_v = map_dict[name_q]["name"]
                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name_q]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
                        if map_dict[name_q]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(
                            name, name_tf, var_dict_torch[name].size(), data_tf.size())
                        var_dict_torch_update[name] = data_tf
                        logging.info(
                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                                name, data_tf.size(), name_v, var_dict_tf[name_tf].shape))
                elif names[1] == "bias_decoder":
                    # Not layer-indexed: the full torch name is the map key.
                    name_q = name
                    if name_q in map_dict.keys():
                        name_v = map_dict[name_q]["name"]
                        name_tf = name_v
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name_q]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
                        if map_dict[name_q]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(
                            name, name_tf, var_dict_torch[name].size(), data_tf.size())
                        var_dict_torch_update[name] = data_tf
                        logging.info(
                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                                name, data_tf.size(), name_v, var_dict_tf[name_tf].shape))
                elif names[1] == "embed" or names[1] == "output_layer" or names[1] == "bias_output":
                    name_tf = map_dict[name]["name"]
                    if isinstance(name_tf, list):
                        # Two candidate TF names (e.g. dense layer vs tied
                        # embeddings): use the first one present.
                        idx_list = 0
                        if name_tf[idx_list] in var_dict_tf.keys():
                            pass
                        else:
                            idx_list = 1
                        data_tf = var_dict_tf[name_tf[idx_list]]
                        if map_dict[name]["squeeze"][idx_list] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"][idx_list])
                        if map_dict[name]["transpose"][idx_list] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name]["transpose"][idx_list])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(
                            name, name_tf, var_dict_torch[name].size(), data_tf.size())
                        var_dict_torch_update[name] = data_tf
                        logging.info(
                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                                name, data_tf.size(), name_tf[idx_list], var_dict_tf[name_tf[idx_list]].shape))
                    else:
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
                        if map_dict[name]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(
                            name, name_tf, var_dict_torch[name].size(), data_tf.size())
                        var_dict_torch_update[name] = data_tf
                        logging.info(
                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                                name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape))
                elif names[1] == "after_norm":
                    # LayerNorm params copy straight across (no reshaping, and
                    # no shape assert in this branch).
                    name_tf = map_dict[name]["name"]
                    data_tf = var_dict_tf[name_tf]
                    data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                    var_dict_torch_update[name] = data_tf
                    logging.info(
                        "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                            name, data_tf.size(), name_tf, var_dict_tf[name_tf].shape))
                elif names[1] == "embed_concat_ffn":
                    layeridx = int(names[2])
                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
                    layeridx_bias = 0
                    layeridx += layeridx_bias
                    if "decoders." in name:
                        decoder_layeridx_sets.add(layeridx)
                    if name_q in map_dict.keys():
                        name_v = map_dict[name_q]["name"]
                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
                        data_tf = var_dict_tf[name_tf]
                        if map_dict[name_q]["squeeze"] is not None:
                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
                        if map_dict[name_q]["transpose"] is not None:
                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(
                            name, name_tf, var_dict_torch[name].size(), data_tf.size())
                        var_dict_torch_update[name] = data_tf
                        logging.info(
                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(
                                name, data_tf.size(), name_v, var_dict_tf[name_tf].shape))
        return var_dict_torch_update