decoder.py 41 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856
  1. #!/usr/bin/env python3
  2. # -*- encoding: utf-8 -*-
  3. # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
  4. # MIT License (https://opensource.org/licenses/MIT)
  5. from typing import List
  6. from typing import Tuple
  7. import logging
  8. import torch
  9. import torch.nn as nn
  10. import numpy as np
  11. from funasr.models.scama import utils as myutils
  12. from funasr.models.transformer.decoder import BaseTransformerDecoder
  13. from funasr.models.sanm.attention import MultiHeadedAttentionSANMDecoder, MultiHeadedAttentionCrossAtt
  14. from funasr.models.transformer.embedding import PositionalEncoding
  15. from funasr.models.transformer.layer_norm import LayerNorm
  16. from funasr.models.sanm.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
  17. from funasr.models.transformer.utils.repeat import repeat
  18. from funasr.register import tables
  19. class DecoderLayerSANM(nn.Module):
  20. """Single decoder layer module.
  21. Args:
  22. size (int): Input dimension.
  23. self_attn (torch.nn.Module): Self-attention module instance.
  24. `MultiHeadedAttention` instance can be used as the argument.
  25. src_attn (torch.nn.Module): Self-attention module instance.
  26. `MultiHeadedAttention` instance can be used as the argument.
  27. feed_forward (torch.nn.Module): Feed-forward module instance.
  28. `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
  29. can be used as the argument.
  30. dropout_rate (float): Dropout rate.
  31. normalize_before (bool): Whether to use layer_norm before the first block.
  32. concat_after (bool): Whether to concat attention layer's input and output.
  33. if True, additional linear will be applied.
  34. i.e. x -> x + linear(concat(x, att(x)))
  35. if False, no additional linear will be applied. i.e. x -> x + att(x)
  36. """
  37. def __init__(
  38. self,
  39. size,
  40. self_attn,
  41. src_attn,
  42. feed_forward,
  43. dropout_rate,
  44. normalize_before=True,
  45. concat_after=False,
  46. ):
  47. """Construct an DecoderLayer object."""
  48. super(DecoderLayerSANM, self).__init__()
  49. self.size = size
  50. self.self_attn = self_attn
  51. self.src_attn = src_attn
  52. self.feed_forward = feed_forward
  53. self.norm1 = LayerNorm(size)
  54. if self_attn is not None:
  55. self.norm2 = LayerNorm(size)
  56. if src_attn is not None:
  57. self.norm3 = LayerNorm(size)
  58. self.dropout = nn.Dropout(dropout_rate)
  59. self.normalize_before = normalize_before
  60. self.concat_after = concat_after
  61. if self.concat_after:
  62. self.concat_linear1 = nn.Linear(size + size, size)
  63. self.concat_linear2 = nn.Linear(size + size, size)
  64. def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
  65. """Compute decoded features.
  66. Args:
  67. tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
  68. tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
  69. memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
  70. memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
  71. cache (List[torch.Tensor]): List of cached tensors.
  72. Each tensor shape should be (#batch, maxlen_out - 1, size).
  73. Returns:
  74. torch.Tensor: Output tensor(#batch, maxlen_out, size).
  75. torch.Tensor: Mask for output tensor (#batch, maxlen_out).
  76. torch.Tensor: Encoded memory (#batch, maxlen_in, size).
  77. torch.Tensor: Encoded memory mask (#batch, maxlen_in).
  78. """
  79. # tgt = self.dropout(tgt)
  80. residual = tgt
  81. if self.normalize_before:
  82. tgt = self.norm1(tgt)
  83. tgt = self.feed_forward(tgt)
  84. x = tgt
  85. if self.self_attn:
  86. if self.normalize_before:
  87. tgt = self.norm2(tgt)
  88. x, _ = self.self_attn(tgt, tgt_mask)
  89. x = residual + self.dropout(x)
  90. if self.src_attn is not None:
  91. residual = x
  92. if self.normalize_before:
  93. x = self.norm3(x)
  94. x = residual + self.dropout(self.src_attn(x, memory, memory_mask))
  95. return x, tgt_mask, memory, memory_mask, cache
  96. def forward_one_step(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
  97. """Compute decoded features.
  98. Args:
  99. tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
  100. tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
  101. memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
  102. memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
  103. cache (List[torch.Tensor]): List of cached tensors.
  104. Each tensor shape should be (#batch, maxlen_out - 1, size).
  105. Returns:
  106. torch.Tensor: Output tensor(#batch, maxlen_out, size).
  107. torch.Tensor: Mask for output tensor (#batch, maxlen_out).
  108. torch.Tensor: Encoded memory (#batch, maxlen_in, size).
  109. torch.Tensor: Encoded memory mask (#batch, maxlen_in).
  110. """
  111. # tgt = self.dropout(tgt)
  112. residual = tgt
  113. if self.normalize_before:
  114. tgt = self.norm1(tgt)
  115. tgt = self.feed_forward(tgt)
  116. x = tgt
  117. if self.self_attn:
  118. if self.normalize_before:
  119. tgt = self.norm2(tgt)
  120. if self.training:
  121. cache = None
  122. x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
  123. x = residual + self.dropout(x)
  124. if self.src_attn is not None:
  125. residual = x
  126. if self.normalize_before:
  127. x = self.norm3(x)
  128. x = residual + self.dropout(self.src_attn(x, memory, memory_mask))
  129. return x, tgt_mask, memory, memory_mask, cache
  130. def forward_chunk(self, tgt, memory, fsmn_cache=None, opt_cache=None, chunk_size=None, look_back=0):
  131. """Compute decoded features.
  132. Args:
  133. tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
  134. tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
  135. memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
  136. memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
  137. cache (List[torch.Tensor]): List of cached tensors.
  138. Each tensor shape should be (#batch, maxlen_out - 1, size).
  139. Returns:
  140. torch.Tensor: Output tensor(#batch, maxlen_out, size).
  141. torch.Tensor: Mask for output tensor (#batch, maxlen_out).
  142. torch.Tensor: Encoded memory (#batch, maxlen_in, size).
  143. torch.Tensor: Encoded memory mask (#batch, maxlen_in).
  144. """
  145. residual = tgt
  146. if self.normalize_before:
  147. tgt = self.norm1(tgt)
  148. tgt = self.feed_forward(tgt)
  149. x = tgt
  150. if self.self_attn:
  151. if self.normalize_before:
  152. tgt = self.norm2(tgt)
  153. x, fsmn_cache = self.self_attn(tgt, None, fsmn_cache)
  154. x = residual + self.dropout(x)
  155. if self.src_attn is not None:
  156. residual = x
  157. if self.normalize_before:
  158. x = self.norm3(x)
  159. x, opt_cache = self.src_attn.forward_chunk(x, memory, opt_cache, chunk_size, look_back)
  160. x = residual + x
  161. return x, memory, fsmn_cache, opt_cache
  162. @tables.register("decoder_classes", "FsmnDecoderSCAMAOpt")
  163. class FsmnDecoderSCAMAOpt(BaseTransformerDecoder):
  164. """
  165. Author: Shiliang Zhang, Zhifu Gao, Haoneng Luo, Ming Lei, Jie Gao, Zhijie Yan, Lei Xie
  166. SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
  167. https://arxiv.org/abs/2006.01712
  168. """
  169. def __init__(
  170. self,
  171. vocab_size: int,
  172. encoder_output_size: int,
  173. attention_heads: int = 4,
  174. linear_units: int = 2048,
  175. num_blocks: int = 6,
  176. dropout_rate: float = 0.1,
  177. positional_dropout_rate: float = 0.1,
  178. self_attention_dropout_rate: float = 0.0,
  179. src_attention_dropout_rate: float = 0.0,
  180. input_layer: str = "embed",
  181. use_output_layer: bool = True,
  182. pos_enc_class=PositionalEncoding,
  183. normalize_before: bool = True,
  184. concat_after: bool = False,
  185. att_layer_num: int = 6,
  186. kernel_size: int = 21,
  187. sanm_shfit: int = None,
  188. concat_embeds: bool = False,
  189. attention_dim: int = None,
  190. tf2torch_tensor_name_prefix_torch: str = "decoder",
  191. tf2torch_tensor_name_prefix_tf: str = "seq2seq/decoder",
  192. embed_tensor_name_prefix_tf: str = None,
  193. ):
  194. super().__init__(
  195. vocab_size=vocab_size,
  196. encoder_output_size=encoder_output_size,
  197. dropout_rate=dropout_rate,
  198. positional_dropout_rate=positional_dropout_rate,
  199. input_layer=input_layer,
  200. use_output_layer=use_output_layer,
  201. pos_enc_class=pos_enc_class,
  202. normalize_before=normalize_before,
  203. )
  204. if attention_dim is None:
  205. attention_dim = encoder_output_size
  206. if input_layer == "embed":
  207. self.embed = torch.nn.Sequential(
  208. torch.nn.Embedding(vocab_size, attention_dim),
  209. )
  210. elif input_layer == "linear":
  211. self.embed = torch.nn.Sequential(
  212. torch.nn.Linear(vocab_size, attention_dim),
  213. torch.nn.LayerNorm(attention_dim),
  214. torch.nn.Dropout(dropout_rate),
  215. torch.nn.ReLU(),
  216. pos_enc_class(attention_dim, positional_dropout_rate),
  217. )
  218. else:
  219. raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")
  220. self.normalize_before = normalize_before
  221. if self.normalize_before:
  222. self.after_norm = LayerNorm(attention_dim)
  223. if use_output_layer:
  224. self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
  225. else:
  226. self.output_layer = None
  227. self.att_layer_num = att_layer_num
  228. self.num_blocks = num_blocks
  229. if sanm_shfit is None:
  230. sanm_shfit = (kernel_size - 1) // 2
  231. self.decoders = repeat(
  232. att_layer_num,
  233. lambda lnum: DecoderLayerSANM(
  234. attention_dim,
  235. MultiHeadedAttentionSANMDecoder(
  236. attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
  237. ),
  238. MultiHeadedAttentionCrossAtt(
  239. attention_heads, attention_dim, src_attention_dropout_rate, encoder_output_size=encoder_output_size
  240. ),
  241. PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
  242. dropout_rate,
  243. normalize_before,
  244. concat_after,
  245. ),
  246. )
  247. if num_blocks - att_layer_num <= 0:
  248. self.decoders2 = None
  249. else:
  250. self.decoders2 = repeat(
  251. num_blocks - att_layer_num,
  252. lambda lnum: DecoderLayerSANM(
  253. attention_dim,
  254. MultiHeadedAttentionSANMDecoder(
  255. attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
  256. ),
  257. None,
  258. PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
  259. dropout_rate,
  260. normalize_before,
  261. concat_after,
  262. ),
  263. )
  264. self.decoders3 = repeat(
  265. 1,
  266. lambda lnum: DecoderLayerSANM(
  267. attention_dim,
  268. None,
  269. None,
  270. PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
  271. dropout_rate,
  272. normalize_before,
  273. concat_after,
  274. ),
  275. )
  276. if concat_embeds:
  277. self.embed_concat_ffn = repeat(
  278. 1,
  279. lambda lnum: DecoderLayerSANM(
  280. attention_dim + encoder_output_size,
  281. None,
  282. None,
  283. PositionwiseFeedForwardDecoderSANM(attention_dim + encoder_output_size, linear_units, dropout_rate,
  284. adim=attention_dim),
  285. dropout_rate,
  286. normalize_before,
  287. concat_after,
  288. ),
  289. )
  290. else:
  291. self.embed_concat_ffn = None
  292. self.concat_embeds = concat_embeds
  293. self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
  294. self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
  295. self.embed_tensor_name_prefix_tf = embed_tensor_name_prefix_tf
  296. def forward(
  297. self,
  298. hs_pad: torch.Tensor,
  299. hlens: torch.Tensor,
  300. ys_in_pad: torch.Tensor,
  301. ys_in_lens: torch.Tensor,
  302. chunk_mask: torch.Tensor = None,
  303. pre_acoustic_embeds: torch.Tensor = None,
  304. ) -> Tuple[torch.Tensor, torch.Tensor]:
  305. """Forward decoder.
  306. Args:
  307. hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
  308. hlens: (batch)
  309. ys_in_pad:
  310. input token ids, int64 (batch, maxlen_out)
  311. if input_layer == "embed"
  312. input tensor (batch, maxlen_out, #mels) in the other cases
  313. ys_in_lens: (batch)
  314. Returns:
  315. (tuple): tuple containing:
  316. x: decoded token score before softmax (batch, maxlen_out, token)
  317. if use_output_layer is True,
  318. olens: (batch, )
  319. """
  320. tgt = ys_in_pad
  321. tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
  322. memory = hs_pad
  323. memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
  324. if chunk_mask is not None:
  325. memory_mask = memory_mask * chunk_mask
  326. if tgt_mask.size(1) != memory_mask.size(1):
  327. memory_mask = torch.cat((memory_mask, memory_mask[:, -2:-1, :]), dim=1)
  328. x = self.embed(tgt)
  329. if pre_acoustic_embeds is not None and self.concat_embeds:
  330. x = torch.cat((x, pre_acoustic_embeds), dim=-1)
  331. x, _, _, _, _ = self.embed_concat_ffn(x, None, None, None, None)
  332. x, tgt_mask, memory, memory_mask, _ = self.decoders(
  333. x, tgt_mask, memory, memory_mask
  334. )
  335. if self.decoders2 is not None:
  336. x, tgt_mask, memory, memory_mask, _ = self.decoders2(
  337. x, tgt_mask, memory, memory_mask
  338. )
  339. x, tgt_mask, memory, memory_mask, _ = self.decoders3(
  340. x, tgt_mask, memory, memory_mask
  341. )
  342. if self.normalize_before:
  343. x = self.after_norm(x)
  344. if self.output_layer is not None:
  345. x = self.output_layer(x)
  346. olens = tgt_mask.sum(1)
  347. return x, olens
  348. def score(self, ys, state, x, x_mask=None, pre_acoustic_embeds: torch.Tensor = None, ):
  349. """Score."""
  350. ys_mask = myutils.sequence_mask(torch.tensor([len(ys)], dtype=torch.int32), device=x.device)[:, :, None]
  351. logp, state = self.forward_one_step(
  352. ys.unsqueeze(0), ys_mask, x.unsqueeze(0), memory_mask=x_mask, pre_acoustic_embeds=pre_acoustic_embeds,
  353. cache=state
  354. )
  355. return logp.squeeze(0), state
  356. def forward_one_step(
  357. self,
  358. tgt: torch.Tensor,
  359. tgt_mask: torch.Tensor,
  360. memory: torch.Tensor,
  361. memory_mask: torch.Tensor = None,
  362. pre_acoustic_embeds: torch.Tensor = None,
  363. cache: List[torch.Tensor] = None,
  364. ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
  365. """Forward one step.
  366. Args:
  367. tgt: input token ids, int64 (batch, maxlen_out)
  368. tgt_mask: input token mask, (batch, maxlen_out)
  369. dtype=torch.uint8 in PyTorch 1.2-
  370. dtype=torch.bool in PyTorch 1.2+ (include 1.2)
  371. memory: encoded memory, float32 (batch, maxlen_in, feat)
  372. cache: cached output list of (batch, max_time_out-1, size)
  373. Returns:
  374. y, cache: NN output value and cache per `self.decoders`.
  375. y.shape` is (batch, maxlen_out, token)
  376. """
  377. x = tgt[:, -1:]
  378. tgt_mask = None
  379. x = self.embed(x)
  380. if pre_acoustic_embeds is not None and self.concat_embeds:
  381. x = torch.cat((x, pre_acoustic_embeds), dim=-1)
  382. x, _, _, _, _ = self.embed_concat_ffn(x, None, None, None, None)
  383. if cache is None:
  384. cache_layer_num = len(self.decoders)
  385. if self.decoders2 is not None:
  386. cache_layer_num += len(self.decoders2)
  387. cache = [None] * cache_layer_num
  388. new_cache = []
  389. # for c, decoder in zip(cache, self.decoders):
  390. for i in range(self.att_layer_num):
  391. decoder = self.decoders[i]
  392. c = cache[i]
  393. x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_one_step(
  394. x, tgt_mask, memory, memory_mask, cache=c
  395. )
  396. new_cache.append(c_ret)
  397. if self.num_blocks - self.att_layer_num >= 1:
  398. for i in range(self.num_blocks - self.att_layer_num):
  399. j = i + self.att_layer_num
  400. decoder = self.decoders2[i]
  401. c = cache[j]
  402. x, tgt_mask, memory, memory_mask, c_ret = decoder.forward_one_step(
  403. x, tgt_mask, memory, memory_mask, cache=c
  404. )
  405. new_cache.append(c_ret)
  406. for decoder in self.decoders3:
  407. x, tgt_mask, memory, memory_mask, _ = decoder.forward_one_step(
  408. x, tgt_mask, memory, None, cache=None
  409. )
  410. if self.normalize_before:
  411. y = self.after_norm(x[:, -1])
  412. else:
  413. y = x[:, -1]
  414. if self.output_layer is not None:
  415. y = self.output_layer(y)
  416. y = torch.log_softmax(y, dim=-1)
  417. return y, new_cache
  418. def gen_tf2torch_map_dict(self):
  419. tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
  420. tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
  421. embed_tensor_name_prefix_tf = self.embed_tensor_name_prefix_tf if self.embed_tensor_name_prefix_tf is not None else tensor_name_prefix_tf
  422. map_dict_local = {
  423. ## decoder
  424. # ffn
  425. "{}.decoders.layeridx.norm1.weight".format(tensor_name_prefix_torch):
  426. {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
  427. "squeeze": None,
  428. "transpose": None,
  429. }, # (256,),(256,)
  430. "{}.decoders.layeridx.norm1.bias".format(tensor_name_prefix_torch):
  431. {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
  432. "squeeze": None,
  433. "transpose": None,
  434. }, # (256,),(256,)
  435. "{}.decoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
  436. {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d/kernel".format(tensor_name_prefix_tf),
  437. "squeeze": 0,
  438. "transpose": (1, 0),
  439. }, # (1024,256),(1,256,1024)
  440. "{}.decoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
  441. {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d/bias".format(tensor_name_prefix_tf),
  442. "squeeze": None,
  443. "transpose": None,
  444. }, # (1024,),(1024,)
  445. "{}.decoders.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
  446. {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
  447. "squeeze": None,
  448. "transpose": None,
  449. }, # (1024,),(1024,)
  450. "{}.decoders.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
  451. {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/LayerNorm_1/beta".format(tensor_name_prefix_tf),
  452. "squeeze": None,
  453. "transpose": None,
  454. }, # (1024,),(1024,)
  455. "{}.decoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
  456. {"name": "{}/decoder_fsmn_layer_layeridx/decoder_ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
  457. "squeeze": 0,
  458. "transpose": (1, 0),
  459. }, # (256,1024),(1,1024,256)
  460. # fsmn
  461. "{}.decoders.layeridx.norm2.weight".format(tensor_name_prefix_torch):
  462. {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/LayerNorm/gamma".format(
  463. tensor_name_prefix_tf),
  464. "squeeze": None,
  465. "transpose": None,
  466. }, # (256,),(256,)
  467. "{}.decoders.layeridx.norm2.bias".format(tensor_name_prefix_torch):
  468. {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/LayerNorm/beta".format(
  469. tensor_name_prefix_tf),
  470. "squeeze": None,
  471. "transpose": None,
  472. }, # (256,),(256,)
  473. "{}.decoders.layeridx.self_attn.fsmn_block.weight".format(tensor_name_prefix_torch):
  474. {"name": "{}/decoder_fsmn_layer_layeridx/decoder_memory_block/depth_conv_w".format(
  475. tensor_name_prefix_tf),
  476. "squeeze": 0,
  477. "transpose": (1, 2, 0),
  478. }, # (256,1,31),(1,31,256,1)
  479. # src att
  480. "{}.decoders.layeridx.norm3.weight".format(tensor_name_prefix_torch):
  481. {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf),
  482. "squeeze": None,
  483. "transpose": None,
  484. }, # (256,),(256,)
  485. "{}.decoders.layeridx.norm3.bias".format(tensor_name_prefix_torch):
  486. {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf),
  487. "squeeze": None,
  488. "transpose": None,
  489. }, # (256,),(256,)
  490. "{}.decoders.layeridx.src_attn.linear_q.weight".format(tensor_name_prefix_torch):
  491. {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf),
  492. "squeeze": 0,
  493. "transpose": (1, 0),
  494. }, # (256,256),(1,256,256)
  495. "{}.decoders.layeridx.src_attn.linear_q.bias".format(tensor_name_prefix_torch):
  496. {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf),
  497. "squeeze": None,
  498. "transpose": None,
  499. }, # (256,),(256,)
  500. "{}.decoders.layeridx.src_attn.linear_k_v.weight".format(tensor_name_prefix_torch):
  501. {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf),
  502. "squeeze": 0,
  503. "transpose": (1, 0),
  504. }, # (1024,256),(1,256,1024)
  505. "{}.decoders.layeridx.src_attn.linear_k_v.bias".format(tensor_name_prefix_torch):
  506. {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf),
  507. "squeeze": None,
  508. "transpose": None,
  509. }, # (1024,),(1024,)
  510. "{}.decoders.layeridx.src_attn.linear_out.weight".format(tensor_name_prefix_torch):
  511. {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_2/kernel".format(tensor_name_prefix_tf),
  512. "squeeze": 0,
  513. "transpose": (1, 0),
  514. }, # (256,256),(1,256,256)
  515. "{}.decoders.layeridx.src_attn.linear_out.bias".format(tensor_name_prefix_torch):
  516. {"name": "{}/decoder_fsmn_layer_layeridx/multi_head/conv1d_2/bias".format(tensor_name_prefix_tf),
  517. "squeeze": None,
  518. "transpose": None,
  519. }, # (256,),(256,)
  520. # dnn
  521. "{}.decoders3.layeridx.norm1.weight".format(tensor_name_prefix_torch):
  522. {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm/gamma".format(tensor_name_prefix_tf),
  523. "squeeze": None,
  524. "transpose": None,
  525. }, # (256,),(256,)
  526. "{}.decoders3.layeridx.norm1.bias".format(tensor_name_prefix_torch):
  527. {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm/beta".format(tensor_name_prefix_tf),
  528. "squeeze": None,
  529. "transpose": None,
  530. }, # (256,),(256,)
  531. "{}.decoders3.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
  532. {"name": "{}/decoder_dnn_layer_layeridx/conv1d/kernel".format(tensor_name_prefix_tf),
  533. "squeeze": 0,
  534. "transpose": (1, 0),
  535. }, # (1024,256),(1,256,1024)
  536. "{}.decoders3.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
  537. {"name": "{}/decoder_dnn_layer_layeridx/conv1d/bias".format(tensor_name_prefix_tf),
  538. "squeeze": None,
  539. "transpose": None,
  540. }, # (1024,),(1024,)
  541. "{}.decoders3.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
  542. {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
  543. "squeeze": None,
  544. "transpose": None,
  545. }, # (1024,),(1024,)
  546. "{}.decoders3.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
  547. {"name": "{}/decoder_dnn_layer_layeridx/LayerNorm_1/beta".format(tensor_name_prefix_tf),
  548. "squeeze": None,
  549. "transpose": None,
  550. }, # (1024,),(1024,)
  551. "{}.decoders3.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
  552. {"name": "{}/decoder_dnn_layer_layeridx/conv1d_1/kernel".format(tensor_name_prefix_tf),
  553. "squeeze": 0,
  554. "transpose": (1, 0),
  555. }, # (256,1024),(1,1024,256)
  556. # embed_concat_ffn
  557. "{}.embed_concat_ffn.layeridx.norm1.weight".format(tensor_name_prefix_torch):
  558. {"name": "{}/cif_concat/LayerNorm/gamma".format(tensor_name_prefix_tf),
  559. "squeeze": None,
  560. "transpose": None,
  561. }, # (256,),(256,)
  562. "{}.embed_concat_ffn.layeridx.norm1.bias".format(tensor_name_prefix_torch):
  563. {"name": "{}/cif_concat/LayerNorm/beta".format(tensor_name_prefix_tf),
  564. "squeeze": None,
  565. "transpose": None,
  566. }, # (256,),(256,)
  567. "{}.embed_concat_ffn.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
  568. {"name": "{}/cif_concat/conv1d/kernel".format(tensor_name_prefix_tf),
  569. "squeeze": 0,
  570. "transpose": (1, 0),
  571. }, # (1024,256),(1,256,1024)
  572. "{}.embed_concat_ffn.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
  573. {"name": "{}/cif_concat/conv1d/bias".format(tensor_name_prefix_tf),
  574. "squeeze": None,
  575. "transpose": None,
  576. }, # (1024,),(1024,)
  577. "{}.embed_concat_ffn.layeridx.feed_forward.norm.weight".format(tensor_name_prefix_torch):
  578. {"name": "{}/cif_concat/LayerNorm_1/gamma".format(tensor_name_prefix_tf),
  579. "squeeze": None,
  580. "transpose": None,
  581. }, # (1024,),(1024,)
  582. "{}.embed_concat_ffn.layeridx.feed_forward.norm.bias".format(tensor_name_prefix_torch):
  583. {"name": "{}/cif_concat/LayerNorm_1/beta".format(tensor_name_prefix_tf),
  584. "squeeze": None,
  585. "transpose": None,
  586. }, # (1024,),(1024,)
  587. "{}.embed_concat_ffn.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
  588. {"name": "{}/cif_concat/conv1d_1/kernel".format(tensor_name_prefix_tf),
  589. "squeeze": 0,
  590. "transpose": (1, 0),
  591. }, # (256,1024),(1,1024,256)
  592. # out norm
  593. "{}.after_norm.weight".format(tensor_name_prefix_torch):
  594. {"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf),
  595. "squeeze": None,
  596. "transpose": None,
  597. }, # (256,),(256,)
  598. "{}.after_norm.bias".format(tensor_name_prefix_torch):
  599. {"name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf),
  600. "squeeze": None,
  601. "transpose": None,
  602. }, # (256,),(256,)
  603. # in embed
  604. "{}.embed.0.weight".format(tensor_name_prefix_torch):
  605. {"name": "{}/w_embs".format(embed_tensor_name_prefix_tf),
  606. "squeeze": None,
  607. "transpose": None,
  608. }, # (4235,256),(4235,256)
  609. # out layer
  610. "{}.output_layer.weight".format(tensor_name_prefix_torch):
  611. {"name": ["{}/dense/kernel".format(tensor_name_prefix_tf),
  612. "{}/w_embs".format(embed_tensor_name_prefix_tf)],
  613. "squeeze": [None, None],
  614. "transpose": [(1, 0), None],
  615. }, # (4235,256),(256,4235)
  616. "{}.output_layer.bias".format(tensor_name_prefix_torch):
  617. {"name": ["{}/dense/bias".format(tensor_name_prefix_tf),
  618. "seq2seq/2bias" if tensor_name_prefix_tf == "seq2seq/decoder/inputter_1" else "seq2seq/bias"],
  619. "squeeze": [None, None],
  620. "transpose": [None, None],
  621. }, # (4235,),(4235,)
  622. }
  623. return map_dict_local
  624. def convert_tf2torch(self,
  625. var_dict_tf,
  626. var_dict_torch,
  627. ):
  628. map_dict = self.gen_tf2torch_map_dict()
  629. var_dict_torch_update = dict()
  630. decoder_layeridx_sets = set()
  631. for name in sorted(var_dict_torch.keys(), reverse=False):
  632. names = name.split('.')
  633. if names[0] == self.tf2torch_tensor_name_prefix_torch:
  634. if names[1] == "decoders":
  635. layeridx = int(names[2])
  636. name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
  637. layeridx_bias = 0
  638. layeridx += layeridx_bias
  639. decoder_layeridx_sets.add(layeridx)
  640. if name_q in map_dict.keys():
  641. name_v = map_dict[name_q]["name"]
  642. name_tf = name_v.replace("layeridx", "{}".format(layeridx))
  643. data_tf = var_dict_tf[name_tf]
  644. if map_dict[name_q]["squeeze"] is not None:
  645. data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
  646. if map_dict[name_q]["transpose"] is not None:
  647. data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
  648. data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
  649. assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
  650. var_dict_torch[
  651. name].size(),
  652. data_tf.size())
  653. var_dict_torch_update[name] = data_tf
  654. logging.info(
  655. "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
  656. var_dict_tf[name_tf].shape))
  657. elif names[1] == "decoders2":
  658. layeridx = int(names[2])
  659. name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
  660. name_q = name_q.replace("decoders2", "decoders")
  661. layeridx_bias = len(decoder_layeridx_sets)
  662. layeridx += layeridx_bias
  663. if "decoders." in name:
  664. decoder_layeridx_sets.add(layeridx)
  665. if name_q in map_dict.keys():
  666. name_v = map_dict[name_q]["name"]
  667. name_tf = name_v.replace("layeridx", "{}".format(layeridx))
  668. data_tf = var_dict_tf[name_tf]
  669. if map_dict[name_q]["squeeze"] is not None:
  670. data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
  671. if map_dict[name_q]["transpose"] is not None:
  672. data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
  673. data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
  674. assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
  675. var_dict_torch[
  676. name].size(),
  677. data_tf.size())
  678. var_dict_torch_update[name] = data_tf
  679. logging.info(
  680. "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
  681. var_dict_tf[name_tf].shape))
  682. elif names[1] == "decoders3":
  683. layeridx = int(names[2])
  684. name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
  685. layeridx_bias = 0
  686. layeridx += layeridx_bias
  687. if "decoders." in name:
  688. decoder_layeridx_sets.add(layeridx)
  689. if name_q in map_dict.keys():
  690. name_v = map_dict[name_q]["name"]
  691. name_tf = name_v.replace("layeridx", "{}".format(layeridx))
  692. data_tf = var_dict_tf[name_tf]
  693. if map_dict[name_q]["squeeze"] is not None:
  694. data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
  695. if map_dict[name_q]["transpose"] is not None:
  696. data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
  697. data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
  698. assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
  699. var_dict_torch[
  700. name].size(),
  701. data_tf.size())
  702. var_dict_torch_update[name] = data_tf
  703. logging.info(
  704. "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
  705. var_dict_tf[name_tf].shape))
  706. elif names[1] == "embed" or names[1] == "output_layer":
  707. name_tf = map_dict[name]["name"]
  708. if isinstance(name_tf, list):
  709. idx_list = 0
  710. if name_tf[idx_list] in var_dict_tf.keys():
  711. pass
  712. else:
  713. idx_list = 1
  714. data_tf = var_dict_tf[name_tf[idx_list]]
  715. if map_dict[name]["squeeze"][idx_list] is not None:
  716. data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"][idx_list])
  717. if map_dict[name]["transpose"][idx_list] is not None:
  718. data_tf = np.transpose(data_tf, map_dict[name]["transpose"][idx_list])
  719. data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
  720. assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
  721. var_dict_torch[
  722. name].size(),
  723. data_tf.size())
  724. var_dict_torch_update[name] = data_tf
  725. logging.info("torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(),
  726. name_tf[idx_list],
  727. var_dict_tf[name_tf[
  728. idx_list]].shape))
  729. else:
  730. data_tf = var_dict_tf[name_tf]
  731. if map_dict[name]["squeeze"] is not None:
  732. data_tf = np.squeeze(data_tf, axis=map_dict[name]["squeeze"])
  733. if map_dict[name]["transpose"] is not None:
  734. data_tf = np.transpose(data_tf, map_dict[name]["transpose"])
  735. data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
  736. assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
  737. var_dict_torch[
  738. name].size(),
  739. data_tf.size())
  740. var_dict_torch_update[name] = data_tf
  741. logging.info(
  742. "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
  743. var_dict_tf[name_tf].shape))
  744. elif names[1] == "after_norm":
  745. name_tf = map_dict[name]["name"]
  746. data_tf = var_dict_tf[name_tf]
  747. data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
  748. var_dict_torch_update[name] = data_tf
  749. logging.info(
  750. "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
  751. var_dict_tf[name_tf].shape))
  752. elif names[1] == "embed_concat_ffn":
  753. layeridx = int(names[2])
  754. name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
  755. layeridx_bias = 0
  756. layeridx += layeridx_bias
  757. if "decoders." in name:
  758. decoder_layeridx_sets.add(layeridx)
  759. if name_q in map_dict.keys():
  760. name_v = map_dict[name_q]["name"]
  761. name_tf = name_v.replace("layeridx", "{}".format(layeridx))
  762. data_tf = var_dict_tf[name_tf]
  763. if map_dict[name_q]["squeeze"] is not None:
  764. data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
  765. if map_dict[name_q]["transpose"] is not None:
  766. data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
  767. data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
  768. assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
  769. var_dict_torch[
  770. name].size(),
  771. data_tf.size())
  772. var_dict_torch_update[name] = data_tf
  773. logging.info(
  774. "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
  775. var_dict_tf[name_tf].shape))
  776. return var_dict_torch_update